mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-03 06:08:15 +00:00
refactor: remove cli/ from crewai package and relocate to proper modules
Move framework infrastructure out of crewai/cli/ to dedicated modules: - cli/authentication/ → crewai/auth/ - cli/config.py → crewai/settings.py - cli/constants.py → crewai/constants.py - cli/plus_api.py → crewai/plus_api.py - cli/version.py → crewai/version.py - cli/crew_chat.py → crewai/utilities/crew_chat.py - cli/reset_memories_command.py → crewai/utilities/reset_memories.py - cli/utils.py (framework parts) → crewai/utilities/project_utils.py Delete CLI-only duplicates (command.py, git.py, provider.py) already present in crewai_cli. Replace _login_to_tool_repository with a _post_login() hook in AuthenticationCommand. Update all imports and mock.patch paths across both packages and tests.
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
"""Wrapper for the crew chat command.
|
||||
|
||||
Delegates to ``crewai.cli.crew_chat.run_chat`` when the full crewai package is
|
||||
installed, otherwise prints a helpful error message.
|
||||
Delegates to ``crewai.utilities.crew_chat.run_chat`` when the full crewai
|
||||
package is installed, otherwise prints a helpful error message.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -11,7 +11,7 @@ import click
|
||||
|
||||
def run_chat() -> None:
|
||||
try:
|
||||
from crewai.cli.crew_chat import run_chat as _run_chat
|
||||
from crewai.utilities.crew_chat import run_chat as _run_chat
|
||||
except ImportError:
|
||||
click.secho(
|
||||
"The 'chat' command requires the full crewai package.\n"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Wrapper for the reset-memories command.
|
||||
|
||||
Delegates to ``crewai.cli.reset_memories_command`` when the full crewai
|
||||
Delegates to ``crewai.utilities.reset_memories`` when the full crewai
|
||||
package is installed, otherwise prints a helpful error message.
|
||||
"""
|
||||
|
||||
@@ -17,7 +17,7 @@ def reset_memories_command(
|
||||
all: bool,
|
||||
) -> None:
|
||||
try:
|
||||
from crewai.cli.reset_memories_command import (
|
||||
from crewai.utilities.reset_memories import (
|
||||
reset_memories_command as _reset,
|
||||
)
|
||||
except ImportError:
|
||||
|
||||
@@ -246,7 +246,7 @@ def is_valid_tool(obj: Any) -> bool:
|
||||
Falls back to crewai's ``is_valid_tool`` when available.
|
||||
"""
|
||||
try:
|
||||
from crewai.cli.utils import is_valid_tool as _core_is_valid_tool
|
||||
from crewai.utilities.project_utils import is_valid_tool as _core_is_valid_tool
|
||||
|
||||
return _core_is_valid_tool(obj)
|
||||
except ImportError:
|
||||
|
||||
7
lib/crewai/src/crewai/auth/__init__.py
Normal file
7
lib/crewai/src/crewai/auth/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""Authentication utilities for the CrewAI platform."""
|
||||
|
||||
from crewai.auth.oauth2 import AuthenticationCommand
|
||||
from crewai.auth.token import AuthError, get_auth_token
|
||||
|
||||
|
||||
__all__ = ["AuthError", "AuthenticationCommand", "get_auth_token"]
|
||||
3
lib/crewai/src/crewai/auth/constants.py
Normal file
3
lib/crewai/src/crewai/auth/constants.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""Authentication constants."""
|
||||
|
||||
ALGORITHMS = ["RS256"]
|
||||
@@ -1,3 +1,5 @@
|
||||
"""OAuth2 authentication for the CrewAI platform."""
|
||||
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, TypeVar, cast
|
||||
import webbrowser
|
||||
@@ -6,9 +8,9 @@ import httpx
|
||||
from pydantic import BaseModel, Field
|
||||
from rich.console import Console
|
||||
|
||||
from crewai.cli.authentication.utils import validate_jwt_token
|
||||
from crewai.cli.config import Settings
|
||||
from crewai.cli.shared.token_manager import TokenManager
|
||||
from crewai.auth.token_manager import TokenManager
|
||||
from crewai.auth.utils import validate_jwt_token
|
||||
from crewai.settings import Settings
|
||||
|
||||
|
||||
console = Console()
|
||||
@@ -17,6 +19,8 @@ TOauth2Settings = TypeVar("TOauth2Settings", bound="Oauth2Settings")
|
||||
|
||||
|
||||
class Oauth2Settings(BaseModel):
|
||||
"""OAuth2 provider configuration."""
|
||||
|
||||
provider: str = Field(
|
||||
description="OAuth2 provider used for authentication (e.g., workos, okta, auth0)."
|
||||
)
|
||||
@@ -38,7 +42,6 @@ class Oauth2Settings(BaseModel):
|
||||
@classmethod
|
||||
def from_settings(cls: type[TOauth2Settings]) -> TOauth2Settings:
|
||||
"""Create an Oauth2Settings instance from the CLI settings."""
|
||||
|
||||
settings = Settings()
|
||||
|
||||
return cls(
|
||||
@@ -51,23 +54,25 @@ class Oauth2Settings(BaseModel):
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.cli.authentication.providers.base_provider import BaseProvider
|
||||
from crewai.auth.providers.base_provider import BaseProvider
|
||||
|
||||
|
||||
class ProviderFactory:
|
||||
"""Factory for creating OAuth2 providers from settings."""
|
||||
|
||||
@classmethod
|
||||
def from_settings(
|
||||
cls: type["ProviderFactory"], # noqa: UP037
|
||||
settings: Oauth2Settings | None = None,
|
||||
) -> "BaseProvider": # noqa: UP037
|
||||
"""Create a provider instance from settings."""
|
||||
settings = settings or Oauth2Settings.from_settings()
|
||||
|
||||
import importlib
|
||||
|
||||
module = importlib.import_module(
|
||||
f"crewai.cli.authentication.providers.{settings.provider.lower()}"
|
||||
f"crewai.auth.providers.{settings.provider.lower()}"
|
||||
)
|
||||
# Converts from snake_case to CamelCase to obtain the provider class name.
|
||||
provider = getattr(
|
||||
module,
|
||||
f"{''.join(word.capitalize() for word in settings.provider.split('_'))}Provider",
|
||||
@@ -77,6 +82,8 @@ class ProviderFactory:
|
||||
|
||||
|
||||
class AuthenticationCommand:
|
||||
"""Handles authentication with the CrewAI platform."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.token_manager = TokenManager()
|
||||
self.oauth2_provider = ProviderFactory.from_settings()
|
||||
@@ -92,7 +99,6 @@ class AuthenticationCommand:
|
||||
|
||||
def _get_device_code(self) -> dict[str, Any]:
|
||||
"""Get the device code to authenticate the user."""
|
||||
|
||||
device_code_payload = {
|
||||
"client_id": self.oauth2_provider.get_client_id(),
|
||||
"scope": " ".join(self.oauth2_provider.get_oauth_scopes()),
|
||||
@@ -108,7 +114,6 @@ class AuthenticationCommand:
|
||||
|
||||
def _display_auth_instructions(self, device_code_data: dict[str, str]) -> None:
|
||||
"""Display the authentication instructions to the user."""
|
||||
|
||||
verification_uri = device_code_data.get(
|
||||
"verification_uri_complete", device_code_data.get("verification_uri", "")
|
||||
)
|
||||
@@ -119,7 +124,6 @@ class AuthenticationCommand:
|
||||
|
||||
def _poll_for_token(self, device_code_data: dict[str, Any]) -> None:
|
||||
"""Polls the server for the token until it is received, or max attempts are reached."""
|
||||
|
||||
token_payload = {
|
||||
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
|
||||
"device_code": device_code_data["device_code"],
|
||||
@@ -143,7 +147,7 @@ class AuthenticationCommand:
|
||||
style="bold green",
|
||||
)
|
||||
|
||||
self._login_to_tool_repository()
|
||||
self._post_login()
|
||||
|
||||
console.print("\n[bold green]Welcome to CrewAI AMP![/bold green]\n")
|
||||
return
|
||||
@@ -162,7 +166,6 @@ class AuthenticationCommand:
|
||||
|
||||
def _validate_and_save_token(self, token_data: dict[str, Any]) -> None:
|
||||
"""Validates the JWT token and saves the token to the token manager."""
|
||||
|
||||
jwt_token = token_data["access_token"]
|
||||
issuer = self.oauth2_provider.get_issuer()
|
||||
jwt_token_data = {
|
||||
@@ -177,39 +180,5 @@ class AuthenticationCommand:
|
||||
expires_at = decoded_token.get("exp", 0)
|
||||
self.token_manager.save_tokens(jwt_token, expires_at)
|
||||
|
||||
def _login_to_tool_repository(self) -> None:
|
||||
"""Login to the tool repository."""
|
||||
|
||||
from crewai_cli.tools.main import ToolCommand
|
||||
|
||||
try:
|
||||
console.print(
|
||||
"Now logging you in to the Tool Repository... ",
|
||||
style="bold blue",
|
||||
end="",
|
||||
)
|
||||
|
||||
ToolCommand().login()
|
||||
|
||||
console.print(
|
||||
"Success!\n",
|
||||
style="bold green",
|
||||
)
|
||||
|
||||
settings = Settings()
|
||||
|
||||
console.print(
|
||||
f"You are now authenticated to the tool repository for organization [bold cyan]'{settings.org_name if settings.org_name else settings.org_uuid}'[/bold cyan]",
|
||||
style="green",
|
||||
)
|
||||
except Exception:
|
||||
console.print(
|
||||
"\n[bold yellow]Warning:[/bold yellow] Authentication with the Tool Repository failed.",
|
||||
style="yellow",
|
||||
)
|
||||
console.print(
|
||||
"Other features will work normally, but you may experience limitations "
|
||||
"with downloading and publishing tools."
|
||||
"\nRun [bold]crewai login[/bold] to try logging in again.\n",
|
||||
style="yellow",
|
||||
)
|
||||
def _post_login(self) -> None:
|
||||
"""Hook called after successful login. Override in subclasses for additional behavior."""
|
||||
1
lib/crewai/src/crewai/auth/providers/__init__.py
Normal file
1
lib/crewai/src/crewai/auth/providers/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""OAuth2 authentication providers."""
|
||||
@@ -1,7 +1,11 @@
|
||||
from crewai.cli.authentication.providers.base_provider import BaseProvider
|
||||
"""Auth0 OAuth2 provider."""
|
||||
|
||||
from crewai.auth.providers.base_provider import BaseProvider
|
||||
|
||||
|
||||
class Auth0Provider(BaseProvider):
|
||||
"""Auth0 OAuth2 provider implementation."""
|
||||
|
||||
def get_authorize_url(self) -> str:
|
||||
return f"https://{self._get_domain()}/oauth/device/code"
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
"""Base OAuth2 provider interface."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from crewai.cli.authentication.main import Oauth2Settings
|
||||
from crewai.auth.oauth2 import Oauth2Settings
|
||||
|
||||
|
||||
class BaseProvider(ABC):
|
||||
"""Abstract base class for OAuth2 providers."""
|
||||
|
||||
def __init__(self, settings: Oauth2Settings):
|
||||
self.settings = settings
|
||||
|
||||
@@ -26,8 +30,9 @@ class BaseProvider(ABC):
|
||||
def get_client_id(self) -> str: ...
|
||||
|
||||
def get_required_fields(self) -> list[str]:
|
||||
"""Returns which provider-specific fields inside the "extra" dict will be required"""
|
||||
"""Returns which provider-specific fields inside the "extra" dict will be required."""
|
||||
return []
|
||||
|
||||
def get_oauth_scopes(self) -> list[str]:
|
||||
"""Returns the OAuth scopes to request."""
|
||||
return ["openid", "profile", "email"]
|
||||
@@ -1,9 +1,13 @@
|
||||
"""Entra ID (Azure AD) OAuth2 provider."""
|
||||
|
||||
from typing import cast
|
||||
|
||||
from crewai.cli.authentication.providers.base_provider import BaseProvider
|
||||
from crewai.auth.providers.base_provider import BaseProvider
|
||||
|
||||
|
||||
class EntraIdProvider(BaseProvider):
|
||||
"""Entra ID (Azure AD) OAuth2 provider implementation."""
|
||||
|
||||
def get_authorize_url(self) -> str:
|
||||
return f"{self._base_url()}/oauth2/v2.0/devicecode"
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
from crewai.cli.authentication.providers.base_provider import BaseProvider
|
||||
"""Keycloak OAuth2 provider."""
|
||||
|
||||
from crewai.auth.providers.base_provider import BaseProvider
|
||||
|
||||
|
||||
class KeycloakProvider(BaseProvider):
|
||||
"""Keycloak OAuth2 provider implementation."""
|
||||
|
||||
def get_authorize_url(self) -> str:
|
||||
return f"{self._oauth2_base_url()}/realms/{self.settings.extra.get('realm')}/protocol/openid-connect/auth/device"
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
from crewai.cli.authentication.providers.base_provider import BaseProvider
|
||||
"""Okta OAuth2 provider."""
|
||||
|
||||
from crewai.auth.providers.base_provider import BaseProvider
|
||||
|
||||
|
||||
class OktaProvider(BaseProvider):
|
||||
"""Okta OAuth2 provider implementation."""
|
||||
|
||||
def get_authorize_url(self) -> str:
|
||||
return f"{self._oauth2_base_url()}/v1/device/authorize"
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
from crewai.cli.authentication.providers.base_provider import BaseProvider
|
||||
"""WorkOS OAuth2 provider."""
|
||||
|
||||
from crewai.auth.providers.base_provider import BaseProvider
|
||||
|
||||
|
||||
class WorkosProvider(BaseProvider):
|
||||
"""WorkOS OAuth2 provider implementation."""
|
||||
|
||||
def get_authorize_url(self) -> str:
|
||||
return f"https://{self._get_domain()}/oauth2/device_authorization"
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from crewai.cli.shared.token_manager import TokenManager
|
||||
"""Authentication token retrieval."""
|
||||
|
||||
from crewai.auth.token_manager import TokenManager
|
||||
|
||||
|
||||
class AuthError(Exception):
|
||||
pass
|
||||
"""Raised when authentication fails."""
|
||||
|
||||
|
||||
def get_auth_token() -> str:
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Manages encrypted token storage."""
|
||||
|
||||
from datetime import datetime
|
||||
import json
|
||||
import os
|
||||
@@ -1,3 +1,5 @@
|
||||
"""JWT token validation utilities."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import jwt
|
||||
@@ -7,18 +9,20 @@ from jwt import PyJWKClient
|
||||
def validate_jwt_token(
|
||||
jwt_token: str, jwks_url: str, issuer: str, audience: str
|
||||
) -> Any:
|
||||
"""
|
||||
Verify the token's signature and claims using PyJWT.
|
||||
:param jwt_token: The JWT (JWS) string to validate.
|
||||
:param jwks_url: The URL of the JWKS endpoint.
|
||||
:param issuer: The expected issuer of the token.
|
||||
:param audience: The expected audience of the token.
|
||||
:return: The decoded token.
|
||||
:raises Exception: If the token is invalid for any reason (e.g., signature mismatch,
|
||||
expired, incorrect issuer/audience, JWKS fetching error,
|
||||
missing required claims).
|
||||
"""
|
||||
"""Verify the token's signature and claims using PyJWT.
|
||||
|
||||
Args:
|
||||
jwt_token: The JWT (JWS) string to validate.
|
||||
jwks_url: The URL of the JWKS endpoint.
|
||||
issuer: The expected issuer of the token.
|
||||
audience: The expected audience of the token.
|
||||
|
||||
Returns:
|
||||
The decoded token.
|
||||
|
||||
Raises:
|
||||
Exception: If the token is invalid for any reason.
|
||||
"""
|
||||
try:
|
||||
jwk_client = PyJWKClient(jwks_url)
|
||||
signing_key = jwk_client.get_signing_key_from_jwt(jwt_token)
|
||||
@@ -1,4 +0,0 @@
|
||||
from crewai.cli.authentication.main import AuthenticationCommand
|
||||
|
||||
|
||||
__all__ = ["AuthenticationCommand"]
|
||||
@@ -1 +0,0 @@
|
||||
ALGORITHMS = ["RS256"]
|
||||
@@ -1,76 +0,0 @@
|
||||
import json
|
||||
|
||||
import httpx
|
||||
from rich.console import Console
|
||||
|
||||
from crewai.cli.authentication.token import get_auth_token
|
||||
from crewai.cli.plus_api import PlusAPI
|
||||
from crewai.telemetry.telemetry import Telemetry
|
||||
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
class BaseCommand:
|
||||
def __init__(self) -> None:
|
||||
self._telemetry = Telemetry()
|
||||
self._telemetry.set_tracer()
|
||||
|
||||
|
||||
class PlusAPIMixin:
|
||||
def __init__(self, telemetry: Telemetry) -> None:
|
||||
try:
|
||||
telemetry.set_tracer()
|
||||
self.plus_api_client = PlusAPI(api_key=get_auth_token())
|
||||
except Exception:
|
||||
telemetry.deploy_signup_error_span()
|
||||
console.print(
|
||||
"Please sign up/login to CrewAI+ before using the CLI.",
|
||||
style="bold red",
|
||||
)
|
||||
console.print("Run 'crewai login' to sign up/login.", style="bold green")
|
||||
raise SystemExit from None
|
||||
|
||||
def _validate_response(self, response: httpx.Response) -> None:
|
||||
"""
|
||||
Handle and display error messages from API responses.
|
||||
|
||||
Args:
|
||||
response (httpx.Response): The response from the Plus API
|
||||
"""
|
||||
try:
|
||||
json_response = response.json()
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
console.print(
|
||||
"Failed to parse response from Enterprise API failed. Details:",
|
||||
style="bold red",
|
||||
)
|
||||
console.print(f"Status Code: {response.status_code}")
|
||||
console.print(
|
||||
f"Response:\n{response.content.decode('utf-8', errors='replace')}"
|
||||
)
|
||||
raise SystemExit from None
|
||||
|
||||
if response.status_code == 422:
|
||||
console.print(
|
||||
"Failed to complete operation. Please fix the following errors:",
|
||||
style="bold red",
|
||||
)
|
||||
for field, messages in json_response.items():
|
||||
for message in messages:
|
||||
console.print(
|
||||
f"* [bold red]{field.capitalize()}[/bold red] {message}"
|
||||
)
|
||||
raise SystemExit
|
||||
|
||||
if not response.is_success:
|
||||
console.print(
|
||||
"Request to Enterprise API failed. Details:", style="bold red"
|
||||
)
|
||||
details = (
|
||||
json_response.get("error")
|
||||
or json_response.get("message")
|
||||
or response.content.decode("utf-8", errors="replace")
|
||||
)
|
||||
console.print(f"{details}")
|
||||
raise SystemExit
|
||||
@@ -1,89 +0,0 @@
|
||||
from functools import lru_cache
|
||||
import subprocess
|
||||
|
||||
|
||||
class Repository:
|
||||
def __init__(self, path: str = ".") -> None:
|
||||
self.path = path
|
||||
|
||||
if not self.is_git_installed():
|
||||
raise ValueError("Git is not installed or not found in your PATH.")
|
||||
|
||||
if not self.is_git_repo():
|
||||
raise ValueError(f"{self.path} is not a Git repository.")
|
||||
|
||||
self.fetch()
|
||||
|
||||
@staticmethod
|
||||
def is_git_installed() -> bool:
|
||||
"""Check if Git is installed and available in the system."""
|
||||
try:
|
||||
subprocess.run(
|
||||
["git", "--version"], # noqa: S607
|
||||
capture_output=True,
|
||||
check=True,
|
||||
text=True,
|
||||
)
|
||||
return True
|
||||
except (subprocess.CalledProcessError, FileNotFoundError):
|
||||
return False
|
||||
|
||||
def fetch(self) -> None:
|
||||
"""Fetch latest updates from the remote."""
|
||||
subprocess.run(["git", "fetch"], cwd=self.path, check=True) # noqa: S607
|
||||
|
||||
def status(self) -> str:
|
||||
"""Get the git status in porcelain format."""
|
||||
return subprocess.check_output(
|
||||
["git", "status", "--branch", "--porcelain"], # noqa: S607
|
||||
cwd=self.path,
|
||||
encoding="utf-8",
|
||||
).strip()
|
||||
|
||||
@lru_cache(maxsize=None) # noqa: B019
|
||||
def is_git_repo(self) -> bool:
|
||||
"""Check if the current directory is a git repository.
|
||||
|
||||
Notes:
|
||||
- TODO: This method is cached to avoid redundant checks, but using lru_cache on methods can lead to memory leaks
|
||||
"""
|
||||
try:
|
||||
subprocess.check_output(
|
||||
["git", "rev-parse", "--is-inside-work-tree"], # noqa: S607
|
||||
cwd=self.path,
|
||||
encoding="utf-8",
|
||||
)
|
||||
return True
|
||||
except subprocess.CalledProcessError:
|
||||
return False
|
||||
|
||||
def has_uncommitted_changes(self) -> bool:
|
||||
"""Check if the repository has uncommitted changes."""
|
||||
return len(self.status().splitlines()) > 1
|
||||
|
||||
def is_ahead_or_behind(self) -> bool:
|
||||
"""Check if the repository is ahead or behind the remote."""
|
||||
for line in self.status().splitlines():
|
||||
if line.startswith("##") and ("ahead" in line or "behind" in line):
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_synced(self) -> bool:
|
||||
"""Return True if the Git repository is fully synced with the remote, False otherwise."""
|
||||
if self.has_uncommitted_changes() or self.is_ahead_or_behind():
|
||||
return False
|
||||
return True
|
||||
|
||||
def origin_url(self) -> str | None:
|
||||
"""Get the Git repository's remote URL."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", "remote", "get-url", "origin"], # noqa: S607
|
||||
cwd=self.path,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
return result.stdout.strip()
|
||||
except subprocess.CalledProcessError:
|
||||
return None
|
||||
@@ -1,231 +0,0 @@
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import certifi
|
||||
import click
|
||||
import httpx
|
||||
|
||||
from crewai.cli.constants import JSON_URL, MODELS, PROVIDERS
|
||||
|
||||
|
||||
def select_choice(prompt_message: str, choices: Sequence[str]) -> str | None:
|
||||
"""Presents a list of choices to the user and prompts them to select one.
|
||||
|
||||
Args:
|
||||
prompt_message: The message to display to the user before presenting the choices.
|
||||
choices: A list of options to present to the user.
|
||||
|
||||
Returns:
|
||||
The selected choice from the list, or None if the user chooses to quit.
|
||||
"""
|
||||
|
||||
provider_models = get_provider_data()
|
||||
if not provider_models:
|
||||
return None
|
||||
click.secho(prompt_message, fg="cyan")
|
||||
for idx, choice in enumerate(choices, start=1):
|
||||
click.secho(f"{idx}. {choice}", fg="cyan")
|
||||
click.secho("q. Quit", fg="cyan")
|
||||
|
||||
while True:
|
||||
choice = click.prompt(
|
||||
"Enter the number of your choice or 'q' to quit", type=str
|
||||
)
|
||||
|
||||
if choice.lower() == "q":
|
||||
return None
|
||||
|
||||
try:
|
||||
selected_index = int(choice) - 1
|
||||
if 0 <= selected_index < len(choices):
|
||||
return choices[selected_index]
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
click.secho(
|
||||
"Invalid selection. Please select a number between 1 and 6 or 'q' to quit.",
|
||||
fg="red",
|
||||
)
|
||||
|
||||
|
||||
def select_provider(provider_models: dict[str, list[str]]) -> str | None | bool:
|
||||
"""Presents a list of providers to the user and prompts them to select one.
|
||||
|
||||
Args:
|
||||
provider_models: A dictionary of provider models.
|
||||
|
||||
Returns:
|
||||
The selected provider, None if user explicitly quits, or False if no selection.
|
||||
"""
|
||||
predefined_providers = [p.lower() for p in PROVIDERS]
|
||||
all_providers = sorted(set(predefined_providers + list(provider_models.keys())))
|
||||
|
||||
provider = select_choice(
|
||||
"Select a provider to set up:", [*predefined_providers, "other"]
|
||||
)
|
||||
if provider is None: # User typed 'q'
|
||||
return None
|
||||
|
||||
if provider == "other":
|
||||
provider = select_choice("Select a provider from the full list:", all_providers)
|
||||
if provider is None: # User typed 'q'
|
||||
return None
|
||||
|
||||
return provider.lower() if provider else False
|
||||
|
||||
|
||||
def select_model(provider: str, provider_models: dict[str, list[str]]) -> str | None:
|
||||
"""Presents a list of models for a given provider to the user and prompts them to select one.
|
||||
|
||||
Args:
|
||||
provider: The provider for which to select a model.
|
||||
provider_models: A dictionary of provider models.
|
||||
|
||||
Returns:
|
||||
The selected model, or None if the operation is aborted or an invalid selection is made.
|
||||
"""
|
||||
predefined_providers = [p.lower() for p in PROVIDERS]
|
||||
|
||||
if provider in predefined_providers:
|
||||
available_models = MODELS.get(provider, [])
|
||||
else:
|
||||
available_models = provider_models.get(provider, [])
|
||||
|
||||
if not available_models:
|
||||
click.secho(f"No models available for provider '{provider}'.", fg="red")
|
||||
return None
|
||||
|
||||
return select_choice(
|
||||
f"Select a model to use for {provider.capitalize()}:", available_models
|
||||
)
|
||||
|
||||
|
||||
def load_provider_data(cache_file: Path, cache_expiry: int) -> dict[str, Any] | None:
|
||||
"""Loads provider data from a cache file if it exists and is not expired.
|
||||
|
||||
If the cache is expired or corrupted, it fetches the data from the web.
|
||||
|
||||
Args:
|
||||
cache_file: The path to the cache file.
|
||||
cache_expiry: The cache expiry time in seconds.
|
||||
|
||||
Returns:
|
||||
The loaded provider data or None if the operation fails.
|
||||
"""
|
||||
current_time = time.time()
|
||||
if (
|
||||
cache_file.exists()
|
||||
and (current_time - cache_file.stat().st_mtime) < cache_expiry
|
||||
):
|
||||
data = read_cache_file(cache_file)
|
||||
if data:
|
||||
return data
|
||||
click.secho(
|
||||
"Cache is corrupted. Fetching provider data from the web...", fg="yellow"
|
||||
)
|
||||
else:
|
||||
click.secho(
|
||||
"Cache expired or not found. Fetching provider data from the web...",
|
||||
fg="cyan",
|
||||
)
|
||||
return fetch_provider_data(cache_file)
|
||||
|
||||
|
||||
def read_cache_file(cache_file: Path) -> dict[str, Any] | None:
|
||||
"""Reads and returns the JSON content from a cache file.
|
||||
|
||||
Args:
|
||||
cache_file: The path to the cache file.
|
||||
|
||||
Returns:
|
||||
The JSON content of the cache file or None if the JSON is invalid.
|
||||
"""
|
||||
try:
|
||||
with open(cache_file, "r") as f:
|
||||
data: dict[str, Any] = json.load(f)
|
||||
return data
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
|
||||
def fetch_provider_data(cache_file: Path) -> dict[str, Any] | None:
|
||||
"""Fetches provider data from a specified URL and caches it to a file.
|
||||
|
||||
Args:
|
||||
cache_file: The path to the cache file.
|
||||
|
||||
Returns:
|
||||
The fetched provider data or None if the operation fails.
|
||||
"""
|
||||
ssl_config = os.environ["SSL_CERT_FILE"] = certifi.where()
|
||||
|
||||
try:
|
||||
with httpx.stream("GET", JSON_URL, timeout=60, verify=ssl_config) as response:
|
||||
response.raise_for_status()
|
||||
data = download_data(response)
|
||||
with open(cache_file, "w") as f:
|
||||
json.dump(data, f)
|
||||
return data
|
||||
except httpx.HTTPError as e:
|
||||
click.secho(f"Error fetching provider data: {e}", fg="red")
|
||||
except json.JSONDecodeError:
|
||||
click.secho("Error parsing provider data. Invalid JSON format.", fg="red")
|
||||
return None
|
||||
|
||||
|
||||
def download_data(response: httpx.Response) -> dict[str, Any]:
|
||||
"""Downloads data from a given HTTP response and returns the JSON content.
|
||||
|
||||
Args:
|
||||
response: The HTTP response object.
|
||||
|
||||
Returns:
|
||||
The JSON content of the response.
|
||||
"""
|
||||
total_size = int(response.headers.get("content-length", 0))
|
||||
block_size = 8192
|
||||
data_chunks: list[bytes] = []
|
||||
bar: Any
|
||||
with click.progressbar(
|
||||
length=total_size, label="Downloading", show_pos=True
|
||||
) as bar:
|
||||
for chunk in response.iter_bytes(block_size):
|
||||
if chunk:
|
||||
data_chunks.append(chunk)
|
||||
bar.update(len(chunk))
|
||||
data_content = b"".join(data_chunks)
|
||||
result: dict[str, Any] = json.loads(data_content.decode("utf-8"))
|
||||
return result
|
||||
|
||||
|
||||
def get_provider_data() -> dict[str, list[str]] | None:
|
||||
"""Retrieves provider data from a cache file.
|
||||
|
||||
Filters out models based on provider criteria, and returns a dictionary of providers
|
||||
mapped to their models.
|
||||
|
||||
Returns:
|
||||
A dictionary of providers mapped to their models or None if the operation fails.
|
||||
"""
|
||||
cache_dir = Path.home() / ".crewai"
|
||||
cache_dir.mkdir(exist_ok=True)
|
||||
cache_file = cache_dir / "provider_cache.json"
|
||||
cache_expiry = 24 * 3600
|
||||
|
||||
data = load_provider_data(cache_file, cache_expiry)
|
||||
if not data:
|
||||
return None
|
||||
|
||||
provider_models = defaultdict(list)
|
||||
for model_name, properties in data.items():
|
||||
provider = properties.get("litellm_provider", "").strip().lower()
|
||||
if "http" in provider or provider == "other":
|
||||
continue
|
||||
if provider:
|
||||
provider_models[provider].append(model_name)
|
||||
return provider_models
|
||||
@@ -1,3 +1,5 @@
|
||||
"""CrewAI constants."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
@@ -8,17 +8,17 @@ import uuid
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
|
||||
from crewai.cli.authentication.token import AuthError, get_auth_token
|
||||
from crewai.cli.config import Settings
|
||||
from crewai.cli.constants import DEFAULT_CREWAI_ENTERPRISE_URL
|
||||
from crewai.cli.plus_api import PlusAPI
|
||||
from crewai.cli.version import get_crewai_version
|
||||
from crewai.auth.token import AuthError, get_auth_token
|
||||
from crewai.constants import DEFAULT_CREWAI_ENTERPRISE_URL
|
||||
from crewai.events.listeners.tracing.types import TraceEvent
|
||||
from crewai.events.listeners.tracing.utils import (
|
||||
get_user_id,
|
||||
is_tracing_enabled_in_context,
|
||||
should_auto_collect_first_time_traces,
|
||||
)
|
||||
from crewai.plus_api import PlusAPI
|
||||
from crewai.settings import Settings
|
||||
from crewai.version import get_crewai_version
|
||||
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
@@ -6,8 +6,7 @@ import uuid
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.cli.authentication.token import AuthError, get_auth_token
|
||||
from crewai.cli.version import get_crewai_version
|
||||
from crewai.auth.token import AuthError, get_auth_token
|
||||
from crewai.events.base_event_listener import BaseEventListener
|
||||
from crewai.events.base_events import BaseEvent
|
||||
from crewai.events.event_bus import CrewAIEventsBus
|
||||
@@ -110,6 +109,7 @@ from crewai.events.types.tool_usage_events import (
|
||||
ToolUsageStartedEvent,
|
||||
)
|
||||
from crewai.events.utils.console_formatter import ConsoleFormatter
|
||||
from crewai.version import get_crewai_version
|
||||
|
||||
|
||||
class TraceCollectionListener(BaseEventListener):
|
||||
|
||||
@@ -8,7 +8,7 @@ from rich.live import Live
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
|
||||
from crewai.cli.version import is_current_version_yanked, is_newer_version_available
|
||||
from crewai.version import is_current_version_yanked, is_newer_version_available
|
||||
|
||||
|
||||
_disable_version_check: ContextVar[bool] = ContextVar(
|
||||
|
||||
@@ -195,7 +195,7 @@ class MCPToolResolver:
|
||||
get_platform_integration_token,
|
||||
)
|
||||
|
||||
from crewai.cli.plus_api import PlusAPI
|
||||
from crewai.plus_api import PlusAPI
|
||||
|
||||
plus_api = PlusAPI(api_key=get_platform_integration_token())
|
||||
response = plus_api.get_mcp_configs(slugs)
|
||||
@@ -285,6 +285,7 @@ class MCPToolResolver:
|
||||
independent transport so that parallel tool executions never share
|
||||
state.
|
||||
"""
|
||||
transport: StdioTransport | HTTPTransport | SSETransport
|
||||
if isinstance(mcp_config, MCPServerStdio):
|
||||
transport = StdioTransport(
|
||||
command=mcp_config.command,
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
"""CrewAI+ API client."""
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import httpx
|
||||
|
||||
from crewai.cli.config import Settings
|
||||
from crewai.cli.constants import DEFAULT_CREWAI_ENTERPRISE_URL
|
||||
from crewai.cli.version import get_crewai_version
|
||||
from crewai.constants import DEFAULT_CREWAI_ENTERPRISE_URL
|
||||
from crewai.settings import Settings
|
||||
from crewai.version import get_crewai_version
|
||||
|
||||
|
||||
class PlusAPI:
|
||||
"""
|
||||
This class exposes methods for working with the CrewAI+ API.
|
||||
"""
|
||||
"""Client for working with the CrewAI+ API."""
|
||||
|
||||
TOOLS_RESOURCE = "/crewai_plus/api/v1/tools"
|
||||
ORGANIZATIONS_RESOURCE = "/crewai_plus/api/v1/me/organizations"
|
||||
@@ -1,3 +1,5 @@
|
||||
"""CrewAI platform settings management."""
|
||||
|
||||
import json
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
@@ -6,14 +8,14 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.cli.constants import (
|
||||
from crewai.auth.token_manager import TokenManager
|
||||
from crewai.constants import (
|
||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
|
||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
|
||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
|
||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER,
|
||||
DEFAULT_CREWAI_ENTERPRISE_URL,
|
||||
)
|
||||
from crewai.cli.shared.token_manager import TokenManager
|
||||
|
||||
|
||||
logger = getLogger(__name__)
|
||||
@@ -22,8 +24,7 @@ DEFAULT_CONFIG_PATH = Path.home() / ".config" / "crewai" / "settings.json"
|
||||
|
||||
|
||||
def get_writable_config_path() -> Path | None:
|
||||
"""
|
||||
Find a writable location for the config file with fallback options.
|
||||
"""Find a writable location for the config file with fallback options.
|
||||
|
||||
Tries in order:
|
||||
1. Default: ~/.config/crewai/settings.json
|
||||
@@ -32,12 +33,12 @@ def get_writable_config_path() -> Path | None:
|
||||
4. In-memory only (returns None)
|
||||
|
||||
Returns:
|
||||
Path object for writable config location, or None if no writable location found
|
||||
Path object for writable config location, or None if no writable location found.
|
||||
"""
|
||||
fallback_paths = [
|
||||
DEFAULT_CONFIG_PATH, # Default location
|
||||
Path(tempfile.gettempdir()) / "crewai_settings.json", # Temporary directory
|
||||
Path.cwd() / "crewai_settings.json", # Current working directory
|
||||
DEFAULT_CONFIG_PATH,
|
||||
Path(tempfile.gettempdir()) / "crewai_settings.json",
|
||||
Path.cwd() / "crewai_settings.json",
|
||||
]
|
||||
|
||||
for config_path in fallback_paths:
|
||||
@@ -46,7 +47,7 @@ def get_writable_config_path() -> Path | None:
|
||||
test_file = config_path.parent / ".crewai_write_test"
|
||||
try:
|
||||
test_file.write_text("test")
|
||||
test_file.unlink() # Clean up test file
|
||||
test_file.unlink()
|
||||
logger.info(f"Using config path: {config_path}")
|
||||
return config_path
|
||||
except Exception: # noqa: S112
|
||||
@@ -58,7 +59,6 @@ def get_writable_config_path() -> Path | None:
|
||||
return None
|
||||
|
||||
|
||||
# Settings that are related to the user's account
|
||||
USER_SETTINGS_KEYS = [
|
||||
"tool_repository_username",
|
||||
"tool_repository_password",
|
||||
@@ -66,7 +66,6 @@ USER_SETTINGS_KEYS = [
|
||||
"org_uuid",
|
||||
]
|
||||
|
||||
# Settings that are related to the CLI
|
||||
CLI_SETTINGS_KEYS = [
|
||||
"enterprise_base_url",
|
||||
"oauth2_provider",
|
||||
@@ -76,7 +75,6 @@ CLI_SETTINGS_KEYS = [
|
||||
"oauth2_extra",
|
||||
]
|
||||
|
||||
# Default values for CLI settings
|
||||
DEFAULT_CLI_SETTINGS = {
|
||||
"enterprise_base_url": DEFAULT_CREWAI_ENTERPRISE_URL,
|
||||
"oauth2_provider": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER,
|
||||
@@ -86,13 +84,11 @@ DEFAULT_CLI_SETTINGS = {
|
||||
"oauth2_extra": {},
|
||||
}
|
||||
|
||||
# Readonly settings - cannot be set by the user
|
||||
READONLY_SETTINGS_KEYS = [
|
||||
"org_name",
|
||||
"org_uuid",
|
||||
]
|
||||
|
||||
# Hidden settings - not displayed by the 'list' command and cannot be set by the user
|
||||
HIDDEN_SETTINGS_KEYS = [
|
||||
"config_path",
|
||||
"tool_repository_username",
|
||||
@@ -101,8 +97,10 @@ HIDDEN_SETTINGS_KEYS = [
|
||||
|
||||
|
||||
class Settings(BaseModel):
|
||||
"""CrewAI platform settings."""
|
||||
|
||||
enterprise_base_url: str | None = Field(
|
||||
default=DEFAULT_CLI_SETTINGS["enterprise_base_url"],
|
||||
default=DEFAULT_CREWAI_ENTERPRISE_URL,
|
||||
description="Base URL of the CrewAI AMP instance",
|
||||
)
|
||||
tool_repository_username: str | None = Field(
|
||||
@@ -121,22 +119,22 @@ class Settings(BaseModel):
|
||||
|
||||
oauth2_provider: str = Field(
|
||||
description="OAuth2 provider used for authentication (e.g., workos, okta, auth0).",
|
||||
default=DEFAULT_CLI_SETTINGS["oauth2_provider"],
|
||||
default=CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER,
|
||||
)
|
||||
|
||||
oauth2_audience: str | None = Field(
|
||||
description="OAuth2 audience value, typically used to identify the target API or resource.",
|
||||
default=DEFAULT_CLI_SETTINGS["oauth2_audience"],
|
||||
default=CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
|
||||
)
|
||||
|
||||
oauth2_client_id: str = Field(
|
||||
default=DEFAULT_CLI_SETTINGS["oauth2_client_id"],
|
||||
default=CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
|
||||
description="OAuth2 client ID issued by the provider, used during authentication requests.",
|
||||
)
|
||||
|
||||
oauth2_domain: str = Field(
|
||||
description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens.",
|
||||
default=DEFAULT_CLI_SETTINGS["oauth2_domain"],
|
||||
default=CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
|
||||
)
|
||||
|
||||
oauth2_extra: dict[str, Any] = Field(
|
||||
@@ -145,14 +143,12 @@ class Settings(BaseModel):
|
||||
)
|
||||
|
||||
def __init__(self, config_path: Path | None = None, **data: dict[str, Any]) -> None:
|
||||
"""Load Settings from config path with fallback support"""
|
||||
"""Load Settings from config path with fallback support."""
|
||||
if config_path is None:
|
||||
config_path = get_writable_config_path()
|
||||
|
||||
# If config_path is None, we're in memory-only mode
|
||||
if config_path is None:
|
||||
merged_data = {**data}
|
||||
# Dummy path for memory-only mode
|
||||
super().__init__(config_path=Path("/dev/null"), **merged_data)
|
||||
return
|
||||
|
||||
@@ -160,7 +156,6 @@ class Settings(BaseModel):
|
||||
config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
except Exception:
|
||||
merged_data = {**data}
|
||||
# Dummy path for memory-only mode
|
||||
super().__init__(config_path=Path("/dev/null"), **merged_data)
|
||||
return
|
||||
|
||||
@@ -176,19 +171,19 @@ class Settings(BaseModel):
|
||||
super().__init__(config_path=config_path, **merged_data)
|
||||
|
||||
def clear_user_settings(self) -> None:
|
||||
"""Clear all user settings"""
|
||||
"""Clear all user settings."""
|
||||
self._reset_user_settings()
|
||||
self.dump()
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset all settings to default values"""
|
||||
"""Reset all settings to default values."""
|
||||
self._reset_user_settings()
|
||||
self._reset_cli_settings()
|
||||
self._clear_auth_tokens()
|
||||
self.dump()
|
||||
|
||||
def dump(self) -> None:
|
||||
"""Save current settings to settings.json"""
|
||||
"""Save current settings to settings.json."""
|
||||
if str(self.config_path) == "/dev/null":
|
||||
return
|
||||
|
||||
@@ -207,15 +202,15 @@ class Settings(BaseModel):
|
||||
pass
|
||||
|
||||
def _reset_user_settings(self) -> None:
|
||||
"""Reset all user settings to default values"""
|
||||
"""Reset all user settings to default values."""
|
||||
for key in USER_SETTINGS_KEYS:
|
||||
setattr(self, key, None)
|
||||
|
||||
def _reset_cli_settings(self) -> None:
|
||||
"""Reset all CLI settings to default values"""
|
||||
"""Reset all CLI settings to default values."""
|
||||
for key in CLI_SETTINGS_KEYS:
|
||||
setattr(self, key, DEFAULT_CLI_SETTINGS.get(key))
|
||||
|
||||
def _clear_auth_tokens(self) -> None:
|
||||
"""Clear all authentication tokens"""
|
||||
"""Clear all authentication tokens."""
|
||||
TokenManager().clear_tokens()
|
||||
@@ -19,8 +19,8 @@ from crewai.agents.parser import (
|
||||
OutputParserError,
|
||||
parse,
|
||||
)
|
||||
from crewai.cli.config import Settings
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.settings import Settings
|
||||
from crewai.tools import BaseTool as CrewAITool
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.tools.structured_tool import CrewStructuredTool
|
||||
@@ -1047,8 +1047,8 @@ def load_agent_from_repository(from_repository: str) -> dict[str, Any]:
|
||||
if callable(_create_plus_client_hook):
|
||||
client = _create_plus_client_hook()
|
||||
else:
|
||||
from crewai.cli.authentication.token import get_auth_token
|
||||
from crewai.cli.plus_api import PlusAPI
|
||||
from crewai.auth.token import get_auth_token
|
||||
from crewai.plus_api import PlusAPI
|
||||
|
||||
client = PlusAPI(api_key=get_auth_token())
|
||||
_print_current_organization()
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Interactive chat interface for CrewAI crews."""
|
||||
|
||||
import contextvars
|
||||
import json
|
||||
from pathlib import Path
|
||||
@@ -12,15 +14,15 @@ import click
|
||||
from packaging import version
|
||||
import tomli
|
||||
|
||||
from crewai.cli.utils import read_toml
|
||||
from crewai.cli.version import get_crewai_version
|
||||
from crewai.crew import Crew
|
||||
from crewai.llm import LLM
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.types.crew_chat import ChatInputField, ChatInputs
|
||||
from crewai.utilities.llm_utils import create_llm
|
||||
from crewai.utilities.printer import Printer
|
||||
from crewai.utilities.project_utils import read_toml
|
||||
from crewai.utilities.types import LLMMessage
|
||||
from crewai.version import get_crewai_version
|
||||
|
||||
|
||||
_printer = Printer()
|
||||
@@ -31,15 +33,14 @@ MIN_REQUIRED_VERSION: Final[Literal["0.98.0"]] = "0.98.0"
|
||||
def check_conversational_crews_version(
|
||||
crewai_version: str, pyproject_data: dict[str, Any]
|
||||
) -> bool:
|
||||
"""
|
||||
Check if the installed crewAI version supports conversational crews.
|
||||
"""Check if the installed crewAI version supports conversational crews.
|
||||
|
||||
Args:
|
||||
crewai_version: The current version of crewAI.
|
||||
pyproject_data: Dictionary containing pyproject.toml data.
|
||||
|
||||
Returns:
|
||||
bool: True if version check passes, False otherwise.
|
||||
True if version check passes, False otherwise.
|
||||
"""
|
||||
try:
|
||||
if version.parse(crewai_version) < version.parse(MIN_REQUIRED_VERSION):
|
||||
@@ -56,8 +57,8 @@ def check_conversational_crews_version(
|
||||
|
||||
|
||||
def run_chat() -> None:
|
||||
"""
|
||||
Runs an interactive chat loop using the Crew's chat LLM with function calling.
|
||||
"""Run an interactive chat loop using the Crew's chat LLM with function calling.
|
||||
|
||||
Incorporates crew_name, crew_description, and input fields to build a tool schema.
|
||||
Exits if crew_name or crew_description are missing.
|
||||
"""
|
||||
@@ -72,14 +73,12 @@ def run_chat() -> None:
|
||||
if not chat_llm:
|
||||
return
|
||||
|
||||
# Indicate that the crew is being analyzed
|
||||
click.secho(
|
||||
"\nAnalyzing crew and required inputs - this may take 3 to 30 seconds "
|
||||
"depending on the complexity of your crew.",
|
||||
fg="white",
|
||||
)
|
||||
|
||||
# Start loading indicator
|
||||
loading_complete = threading.Event()
|
||||
ctx = contextvars.copy_context()
|
||||
loading_thread = threading.Thread(
|
||||
@@ -92,16 +91,13 @@ def run_chat() -> None:
|
||||
crew_tool_schema = generate_crew_tool_schema(crew_chat_inputs)
|
||||
system_message = build_system_message(crew_chat_inputs)
|
||||
|
||||
# Call the LLM to generate the introductory message
|
||||
introductory_message = chat_llm.call(
|
||||
messages=[{"role": "system", "content": system_message}]
|
||||
)
|
||||
finally:
|
||||
# Stop loading indicator
|
||||
loading_complete.set()
|
||||
loading_thread.join()
|
||||
|
||||
# Indicate that the analysis is complete
|
||||
click.secho("\nFinished analyzing crew.\n", fg="white")
|
||||
|
||||
click.secho(f"Assistant: {introductory_message}\n", fg="green")
|
||||
@@ -127,7 +123,7 @@ def show_loading(event: threading.Event) -> None:
|
||||
|
||||
|
||||
def initialize_chat_llm(crew: Crew) -> LLM | BaseLLM | None:
|
||||
"""Initializes the chat LLM and handles exceptions."""
|
||||
"""Initialize the chat LLM and handle exceptions."""
|
||||
try:
|
||||
return create_llm(crew.chat_llm)
|
||||
except Exception as e:
|
||||
@@ -139,7 +135,7 @@ def initialize_chat_llm(crew: Crew) -> LLM | BaseLLM | None:
|
||||
|
||||
|
||||
def build_system_message(crew_chat_inputs: ChatInputs) -> str:
|
||||
"""Builds the initial system message for the chat."""
|
||||
"""Build the initial system message for the chat."""
|
||||
required_fields_str = (
|
||||
", ".join(
|
||||
f"{field.name} (desc: {field.description or 'n/a'})"
|
||||
@@ -168,7 +164,7 @@ def build_system_message(crew_chat_inputs: ChatInputs) -> str:
|
||||
|
||||
|
||||
def create_tool_function(crew: Crew, messages: list[LLMMessage]) -> Any:
|
||||
"""Creates a wrapper function for running the crew tool with messages."""
|
||||
"""Create a wrapper function for running the crew tool with messages."""
|
||||
|
||||
def run_crew_tool_with_messages(**kwargs: Any) -> str:
|
||||
return run_crew_tool(crew, messages, **kwargs)
|
||||
@@ -179,13 +175,11 @@ def create_tool_function(crew: Crew, messages: list[LLMMessage]) -> Any:
|
||||
def flush_input() -> None:
|
||||
"""Flush any pending input from the user."""
|
||||
if platform.system() == "Windows":
|
||||
# Windows platform
|
||||
import msvcrt
|
||||
|
||||
while msvcrt.kbhit(): # type: ignore[attr-defined]
|
||||
msvcrt.getch() # type: ignore[attr-defined]
|
||||
else:
|
||||
# Unix-like platforms (Linux, macOS)
|
||||
import termios
|
||||
|
||||
termios.tcflush(sys.stdin, termios.TCIFLUSH)
|
||||
@@ -200,7 +194,6 @@ def chat_loop(
|
||||
"""Main chat loop for interacting with the user."""
|
||||
while True:
|
||||
try:
|
||||
# Flush any pending input before accepting new input
|
||||
flush_input()
|
||||
|
||||
user_input = get_user_input()
|
||||
@@ -250,11 +243,9 @@ def handle_user_input(
|
||||
|
||||
messages.append({"role": "user", "content": user_input})
|
||||
|
||||
# Indicate that assistant is processing
|
||||
click.echo()
|
||||
click.secho("Assistant is processing your input. Please wait...", fg="green")
|
||||
|
||||
# Process assistant's response
|
||||
final_response = chat_llm.call(
|
||||
messages=messages,
|
||||
tools=[crew_tool_schema],
|
||||
@@ -266,12 +257,11 @@ def handle_user_input(
|
||||
|
||||
|
||||
def generate_crew_tool_schema(crew_inputs: ChatInputs) -> dict[str, Any]:
|
||||
"""
|
||||
Dynamically build a Littellm 'function' schema for the given crew.
|
||||
"""Dynamically build a Littellm 'function' schema for the given crew.
|
||||
|
||||
crew_name: The name of the crew (used for the function 'name').
|
||||
crew_inputs: A ChatInputs object containing crew_description
|
||||
and a list of input fields (each with a name & description).
|
||||
Args:
|
||||
crew_inputs: A ChatInputs object containing crew_description
|
||||
and a list of input fields (each with a name & description).
|
||||
"""
|
||||
properties = {}
|
||||
for field in crew_inputs.inputs:
|
||||
@@ -297,70 +287,51 @@ def generate_crew_tool_schema(crew_inputs: ChatInputs) -> dict[str, Any]:
|
||||
|
||||
|
||||
def run_crew_tool(crew: Crew, messages: list[LLMMessage], **kwargs: Any) -> str:
|
||||
"""
|
||||
Runs the crew using crew.kickoff(inputs=kwargs) and returns the output.
|
||||
"""Run the crew using crew.kickoff(inputs=kwargs) and return the output.
|
||||
|
||||
Args:
|
||||
crew (Crew): The crew instance to run.
|
||||
messages (List[Dict[str, str]]): The chat messages up to this point.
|
||||
crew: The crew instance to run.
|
||||
messages: The chat messages up to this point.
|
||||
**kwargs: The inputs collected from the user.
|
||||
|
||||
Returns:
|
||||
str: The output from the crew's execution.
|
||||
|
||||
Raises:
|
||||
SystemExit: Exits the chat if an error occurs during crew execution.
|
||||
The output from the crew's execution.
|
||||
"""
|
||||
try:
|
||||
# Serialize 'messages' to JSON string before adding to kwargs
|
||||
kwargs["crew_chat_messages"] = json.dumps(messages)
|
||||
|
||||
# Run the crew with the provided inputs
|
||||
crew_output = crew.kickoff(inputs=kwargs)
|
||||
|
||||
# Convert CrewOutput to a string to send back to the user
|
||||
return str(crew_output)
|
||||
|
||||
except Exception as e:
|
||||
# Exit the chat and show the error message
|
||||
click.secho("An error occurred while running the crew:", fg="red")
|
||||
click.secho(str(e), fg="red")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def load_crew_and_name() -> tuple[Crew, str]:
|
||||
"""
|
||||
Loads the crew by importing the crew class from the user's project.
|
||||
"""Load the crew by importing the crew class from the user's project.
|
||||
|
||||
Returns:
|
||||
Tuple[Crew, str]: A tuple containing the Crew instance and the name of the crew.
|
||||
A tuple containing the Crew instance and the name of the crew.
|
||||
"""
|
||||
# Get the current working directory
|
||||
cwd = Path.cwd()
|
||||
|
||||
# Path to the pyproject.toml file
|
||||
pyproject_path = cwd / "pyproject.toml"
|
||||
if not pyproject_path.exists():
|
||||
raise FileNotFoundError("pyproject.toml not found in the current directory.")
|
||||
|
||||
# Load the pyproject.toml file using 'tomli'
|
||||
with pyproject_path.open("rb") as f:
|
||||
pyproject_data = tomli.load(f)
|
||||
|
||||
# Get the project name from the 'project' section
|
||||
project_name = pyproject_data["project"]["name"]
|
||||
folder_name = project_name
|
||||
|
||||
# Derive the crew class name from the project name
|
||||
# E.g., if project_name is 'my_project', crew_class_name is 'MyProject'
|
||||
crew_class_name = project_name.replace("_", " ").title().replace(" ", "")
|
||||
|
||||
# Add the 'src' directory to sys.path
|
||||
src_path = cwd / "src"
|
||||
if str(src_path) not in sys.path:
|
||||
sys.path.insert(0, str(src_path))
|
||||
|
||||
# Import the crew module
|
||||
crew_module_name = f"{folder_name}.crew"
|
||||
try:
|
||||
crew_module = __import__(crew_module_name, fromlist=[crew_class_name])
|
||||
@@ -369,7 +340,6 @@ def load_crew_and_name() -> tuple[Crew, str]:
|
||||
f"Failed to import crew module {crew_module_name}: {e}"
|
||||
) from e
|
||||
|
||||
# Get the crew class from the module
|
||||
try:
|
||||
crew_class = getattr(crew_module, crew_class_name)
|
||||
except AttributeError as e:
|
||||
@@ -377,7 +347,6 @@ def load_crew_and_name() -> tuple[Crew, str]:
|
||||
f"Crew class {crew_class_name} not found in module {crew_module_name}"
|
||||
) from e
|
||||
|
||||
# Instantiate the crew
|
||||
crew_instance = crew_class().crew()
|
||||
return crew_instance, crew_class_name
|
||||
|
||||
@@ -385,27 +354,23 @@ def load_crew_and_name() -> tuple[Crew, str]:
|
||||
def generate_crew_chat_inputs(
|
||||
crew: Crew, crew_name: str, chat_llm: LLM | BaseLLM
|
||||
) -> ChatInputs:
|
||||
"""
|
||||
Generates the ChatInputs required for the crew by analyzing the tasks and agents.
|
||||
"""Generate the ChatInputs required for the crew by analyzing the tasks and agents.
|
||||
|
||||
Args:
|
||||
crew (Crew): The crew object containing tasks and agents.
|
||||
crew_name (str): The name of the crew.
|
||||
crew: The crew object containing tasks and agents.
|
||||
crew_name: The name of the crew.
|
||||
chat_llm: The chat language model to use for AI calls.
|
||||
|
||||
Returns:
|
||||
ChatInputs: An object containing the crew's name, description, and input fields.
|
||||
An object containing the crew's name, description, and input fields.
|
||||
"""
|
||||
# Extract placeholders from tasks and agents
|
||||
required_inputs = fetch_required_inputs(crew)
|
||||
|
||||
# Generate descriptions for each input using AI
|
||||
input_fields = []
|
||||
for input_name in required_inputs:
|
||||
description = generate_input_description_with_ai(input_name, crew, chat_llm)
|
||||
input_fields.append(ChatInputField(name=input_name, description=description))
|
||||
|
||||
# Generate crew description using AI
|
||||
crew_description = generate_crew_description_with_ai(crew, chat_llm)
|
||||
|
||||
return ChatInputs(
|
||||
@@ -414,13 +379,13 @@ def generate_crew_chat_inputs(
|
||||
|
||||
|
||||
def fetch_required_inputs(crew: Crew) -> set[str]:
|
||||
"""Extracts placeholders from the crew's tasks and agents.
|
||||
"""Extract placeholders from the crew's tasks and agents.
|
||||
|
||||
Args:
|
||||
crew (Crew): The crew object.
|
||||
crew: The crew object.
|
||||
|
||||
Returns:
|
||||
Set[str]: A set of placeholder names.
|
||||
A set of placeholder names.
|
||||
"""
|
||||
return crew.fetch_inputs()
|
||||
|
||||
@@ -428,18 +393,16 @@ def fetch_required_inputs(crew: Crew) -> set[str]:
|
||||
def generate_input_description_with_ai(
|
||||
input_name: str, crew: Crew, chat_llm: LLM | BaseLLM
|
||||
) -> str:
|
||||
"""
|
||||
Generates an input description using AI based on the context of the crew.
|
||||
"""Generate an input description using AI based on the context of the crew.
|
||||
|
||||
Args:
|
||||
input_name (str): The name of the input placeholder.
|
||||
crew (Crew): The crew object.
|
||||
input_name: The name of the input placeholder.
|
||||
crew: The crew object.
|
||||
chat_llm: The chat language model to use for AI calls.
|
||||
|
||||
Returns:
|
||||
str: A concise description of the input.
|
||||
A concise description of the input.
|
||||
"""
|
||||
# Gather context from tasks and agents where the input is used
|
||||
context_texts = []
|
||||
placeholder_pattern = re.compile(r"\{(.+?)}")
|
||||
|
||||
@@ -448,7 +411,6 @@ def generate_input_description_with_ai(
|
||||
f"{{{input_name}}}" in task.description
|
||||
or f"{{{input_name}}}" in task.expected_output
|
||||
):
|
||||
# Replace placeholders with input names
|
||||
task_description = placeholder_pattern.sub(
|
||||
lambda m: m.group(1), task.description or ""
|
||||
)
|
||||
@@ -463,7 +425,6 @@ def generate_input_description_with_ai(
|
||||
or f"{{{input_name}}}" in agent.goal
|
||||
or f"{{{input_name}}}" in agent.backstory
|
||||
):
|
||||
# Replace placeholders with input names
|
||||
agent_role = placeholder_pattern.sub(lambda m: m.group(1), agent.role or "")
|
||||
agent_goal = placeholder_pattern.sub(lambda m: m.group(1), agent.goal or "")
|
||||
agent_backstory = placeholder_pattern.sub(
|
||||
@@ -475,7 +436,6 @@ def generate_input_description_with_ai(
|
||||
|
||||
context = "\n".join(context_texts)
|
||||
if not context:
|
||||
# If no context is found for the input, raise an exception as per instruction
|
||||
raise ValueError(f"No context found for input '{input_name}'.")
|
||||
|
||||
prompt = (
|
||||
@@ -489,22 +449,19 @@ def generate_input_description_with_ai(
|
||||
|
||||
|
||||
def generate_crew_description_with_ai(crew: Crew, chat_llm: LLM | BaseLLM) -> str:
|
||||
"""
|
||||
Generates a brief description of the crew using AI.
|
||||
"""Generate a brief description of the crew using AI.
|
||||
|
||||
Args:
|
||||
crew (Crew): The crew object.
|
||||
crew: The crew object.
|
||||
chat_llm: The chat language model to use for AI calls.
|
||||
|
||||
Returns:
|
||||
str: A concise description of the crew's purpose (15 words or less).
|
||||
A concise description of the crew's purpose (15 words or less).
|
||||
"""
|
||||
# Gather context from tasks and agents
|
||||
context_texts = []
|
||||
placeholder_pattern = re.compile(r"\{(.+?)}")
|
||||
|
||||
for task in crew.tasks:
|
||||
# Replace placeholders with input names
|
||||
task_description = placeholder_pattern.sub(
|
||||
lambda m: m.group(1), task.description or ""
|
||||
)
|
||||
@@ -514,7 +471,6 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm: LLM | BaseLLM) -> st
|
||||
context_texts.append(f"Task Description: {task_description}")
|
||||
context_texts.append(f"Expected Output: {expected_output}")
|
||||
for agent in crew.agents:
|
||||
# Replace placeholders with input names
|
||||
agent_role = placeholder_pattern.sub(lambda m: m.group(1), agent.role or "")
|
||||
agent_goal = placeholder_pattern.sub(lambda m: m.group(1), agent.goal or "")
|
||||
agent_backstory = placeholder_pattern.sub(
|
||||
@@ -2,7 +2,7 @@ import logging
|
||||
import os
|
||||
from typing import Any, Final
|
||||
|
||||
from crewai.cli.constants import DEFAULT_LLM_MODEL, ENV_VARS, LITELLM_PARAMS
|
||||
from crewai.constants import DEFAULT_LLM_MODEL, ENV_VARS, LITELLM_PARAMS
|
||||
from crewai.llm import LLM
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
|
||||
|
||||
@@ -1,20 +1,19 @@
|
||||
"""Project utility functions for discovering crews, flows, and tools."""
|
||||
|
||||
from functools import reduce
|
||||
import importlib.util
|
||||
from inspect import getmro, isclass, isfunction, ismethod
|
||||
import os
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
import sys
|
||||
from typing import Any, cast, get_type_hints
|
||||
|
||||
import click
|
||||
from rich.console import Console
|
||||
import tomli
|
||||
|
||||
from crewai.cli.config import Settings
|
||||
from crewai.cli.constants import ENV_VARS
|
||||
from crewai.crew import Crew
|
||||
from crewai.flow import Flow
|
||||
from crewai.settings import Settings
|
||||
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
@@ -23,25 +22,6 @@ if sys.version_info >= (3, 11):
|
||||
console = Console()
|
||||
|
||||
|
||||
def copy_template(
|
||||
src: Path, dst: Path, name: str, class_name: str, folder_name: str
|
||||
) -> None:
|
||||
"""Copy a file from src to dst."""
|
||||
with open(src, "r") as file:
|
||||
content = file.read()
|
||||
|
||||
# Interpolate the content
|
||||
content = content.replace("{{name}}", name)
|
||||
content = content.replace("{{crew_name}}", class_name)
|
||||
content = content.replace("{{folder_name}}", folder_name)
|
||||
|
||||
# Write the interpolated content to the new file
|
||||
with open(dst, "w") as file:
|
||||
file.write(content)
|
||||
|
||||
click.secho(f" - Created {dst}", fg="green")
|
||||
|
||||
|
||||
def read_toml(file_path: str = "pyproject.toml") -> dict[str, Any]:
|
||||
"""Read the content of a TOML file and return it as a dictionary."""
|
||||
with open(file_path, "rb") as f:
|
||||
@@ -49,6 +29,7 @@ def read_toml(file_path: str = "pyproject.toml") -> dict[str, Any]:
|
||||
|
||||
|
||||
def parse_toml(content: str) -> dict[str, Any]:
|
||||
"""Parse a TOML string and return it as a dictionary."""
|
||||
if sys.version_info >= (3, 11):
|
||||
return tomllib.loads(content)
|
||||
return tomli.loads(content)
|
||||
@@ -104,7 +85,6 @@ def _get_project_attribute(
|
||||
style="bold red",
|
||||
)
|
||||
except Exception as e:
|
||||
# Handle TOML decode errors for Python 3.11+
|
||||
if sys.version_info >= (3, 11) and isinstance(e, tomllib.TOMLDecodeError):
|
||||
console.print(
|
||||
f"Error: {pyproject_path} is not a valid TOML file.", style="bold red"
|
||||
@@ -128,164 +108,18 @@ def _get_nested_value(data: dict[str, Any], keys: list[str]) -> Any:
|
||||
return reduce(dict.__getitem__, keys, data)
|
||||
|
||||
|
||||
def fetch_and_json_env_file(env_file_path: str = ".env") -> dict[str, Any]:
|
||||
"""Fetch the environment variables from a .env file and return them as a dictionary."""
|
||||
try:
|
||||
# Read the .env file
|
||||
with open(env_file_path, "r") as f:
|
||||
env_content = f.read()
|
||||
|
||||
# Parse the .env file content to a dictionary
|
||||
env_dict = {}
|
||||
for line in env_content.splitlines():
|
||||
if line.strip() and not line.strip().startswith("#"):
|
||||
key, value = line.split("=", 1)
|
||||
env_dict[key.strip()] = value.strip()
|
||||
|
||||
return env_dict
|
||||
|
||||
except FileNotFoundError:
|
||||
console.print(f"Error: {env_file_path} not found.", style="bold red")
|
||||
except Exception as e:
|
||||
console.print(f"Error reading the .env file: {e}", style="bold red")
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
def tree_copy(source: Path, destination: Path) -> None:
|
||||
"""Copies the entire directory structure from the source to the destination."""
|
||||
for item in os.listdir(source):
|
||||
source_item = os.path.join(source, item)
|
||||
destination_item = os.path.join(destination, item)
|
||||
if os.path.isdir(source_item):
|
||||
shutil.copytree(source_item, destination_item)
|
||||
else:
|
||||
shutil.copy2(source_item, destination_item)
|
||||
|
||||
|
||||
def tree_find_and_replace(directory: Path, find: str, replace: str) -> None:
|
||||
"""Recursively searches through a directory, replacing a target string in
|
||||
both file contents and filenames with a specified replacement string.
|
||||
"""
|
||||
for path, dirs, files in os.walk(os.path.abspath(directory), topdown=False):
|
||||
for filename in files:
|
||||
filepath = os.path.join(path, filename)
|
||||
|
||||
with open(filepath, "r", encoding="utf-8", errors="ignore") as file:
|
||||
contents = file.read()
|
||||
with open(filepath, "w") as file:
|
||||
file.write(contents.replace(find, replace))
|
||||
|
||||
if find in filename:
|
||||
new_filename = filename.replace(find, replace)
|
||||
new_filepath = os.path.join(path, new_filename)
|
||||
os.rename(filepath, new_filepath)
|
||||
|
||||
for dirname in dirs:
|
||||
if find in dirname:
|
||||
new_dirname = dirname.replace(find, replace)
|
||||
new_dirpath = os.path.join(path, new_dirname)
|
||||
old_dirpath = os.path.join(path, dirname)
|
||||
os.rename(old_dirpath, new_dirpath)
|
||||
|
||||
|
||||
def load_env_vars(folder_path: Path) -> dict[str, Any]:
|
||||
"""
|
||||
Loads environment variables from a .env file in the specified folder path.
|
||||
|
||||
Args:
|
||||
- folder_path (Path): The path to the folder containing the .env file.
|
||||
|
||||
Returns:
|
||||
- dict: A dictionary of environment variables.
|
||||
"""
|
||||
env_file_path = folder_path / ".env"
|
||||
env_vars = {}
|
||||
if env_file_path.exists():
|
||||
with open(env_file_path, "r") as file:
|
||||
for line in file:
|
||||
key, _, value = line.strip().partition("=")
|
||||
if key and value:
|
||||
env_vars[key] = value
|
||||
return env_vars
|
||||
|
||||
|
||||
def update_env_vars(
|
||||
env_vars: dict[str, Any], provider: str, model: str
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
Updates environment variables with the API key for the selected provider and model.
|
||||
|
||||
Args:
|
||||
- env_vars (dict): Environment variables dictionary.
|
||||
- provider (str): Selected provider.
|
||||
- model (str): Selected model.
|
||||
|
||||
Returns:
|
||||
- None
|
||||
"""
|
||||
provider_config = cast(
|
||||
list[str],
|
||||
ENV_VARS.get(
|
||||
provider,
|
||||
[
|
||||
click.prompt(
|
||||
f"Enter the environment variable name for your {provider.capitalize()} API key",
|
||||
type=str,
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
api_key_var = provider_config[0]
|
||||
|
||||
if api_key_var not in env_vars:
|
||||
try:
|
||||
env_vars[api_key_var] = click.prompt(
|
||||
f"Enter your {provider.capitalize()} API key", type=str, hide_input=True
|
||||
)
|
||||
except click.exceptions.Abort:
|
||||
click.secho("Operation aborted by the user.", fg="red")
|
||||
return None
|
||||
else:
|
||||
click.secho(f"API key already exists for {provider.capitalize()}.", fg="yellow")
|
||||
|
||||
env_vars["MODEL"] = model
|
||||
click.secho(f"Selected model: {model}", fg="green")
|
||||
return env_vars
|
||||
|
||||
|
||||
def write_env_file(folder_path: Path, env_vars: dict[str, Any]) -> None:
|
||||
"""
|
||||
Writes environment variables to a .env file in the specified folder.
|
||||
|
||||
Args:
|
||||
- folder_path (Path): The path to the folder where the .env file will be written.
|
||||
- env_vars (dict): A dictionary of environment variables to write.
|
||||
"""
|
||||
env_file_path = folder_path / ".env"
|
||||
with open(env_file_path, "w") as file:
|
||||
for key, value in env_vars.items():
|
||||
file.write(f"{key.upper()}={value}\n")
|
||||
|
||||
|
||||
def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
|
||||
"""Get the crew instances from a file."""
|
||||
crew_instances = []
|
||||
try:
|
||||
import importlib.util
|
||||
|
||||
# Add the current directory to sys.path to ensure imports resolve correctly
|
||||
current_dir = os.getcwd()
|
||||
if current_dir not in sys.path:
|
||||
sys.path.insert(0, current_dir)
|
||||
|
||||
# If we're not in src directory but there's a src directory, add it to path
|
||||
src_dir = os.path.join(current_dir, "src")
|
||||
if os.path.isdir(src_dir) and src_dir not in sys.path:
|
||||
sys.path.insert(0, src_dir)
|
||||
|
||||
# Search in both current directory and src directory if it exists
|
||||
search_paths = [".", "src"] if os.path.isdir("src") else ["."]
|
||||
|
||||
for search_path in search_paths:
|
||||
@@ -316,7 +150,6 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
|
||||
)
|
||||
continue
|
||||
|
||||
# If we found crew instances, break out of the loop
|
||||
if crew_instances:
|
||||
break
|
||||
|
||||
@@ -334,7 +167,6 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
|
||||
)
|
||||
continue
|
||||
|
||||
# If we found crew instances in this search path, break out of the search paths loop
|
||||
if crew_instances:
|
||||
break
|
||||
|
||||
@@ -352,6 +184,7 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
|
||||
|
||||
|
||||
def get_crew_instance(module_attr: Any) -> Crew | None:
|
||||
"""Get a Crew instance from a module attribute."""
|
||||
if (
|
||||
callable(module_attr)
|
||||
and hasattr(module_attr, "is_crew_class")
|
||||
@@ -372,6 +205,7 @@ def get_crew_instance(module_attr: Any) -> Crew | None:
|
||||
|
||||
|
||||
def fetch_crews(module_attr: Any) -> list[Crew]:
|
||||
"""Fetch crew instances from a module attribute."""
|
||||
crew_instances: list[Crew] = []
|
||||
|
||||
if crew_instance := get_crew_instance(module_attr):
|
||||
@@ -386,7 +220,7 @@ def fetch_crews(module_attr: Any) -> list[Crew]:
|
||||
return crew_instances
|
||||
|
||||
|
||||
def get_flow_instance(module_attr: Any) -> Flow | None:
|
||||
def get_flow_instance(module_attr: Any) -> Flow[Any] | None:
|
||||
"""Check if a module attribute is a user-defined Flow subclass and return an instance.
|
||||
|
||||
Args:
|
||||
@@ -413,13 +247,12 @@ _SKIP_DIRS = frozenset(
|
||||
)
|
||||
|
||||
|
||||
def get_flows(flow_path: str = "main.py") -> list[Flow]:
|
||||
def get_flows(flow_path: str = "main.py") -> list[Flow[Any]]:
|
||||
"""Get the flow instances from project files.
|
||||
|
||||
Walks the project directory looking for files matching ``flow_path``
|
||||
(default ``main.py``), loads each module, and extracts Flow subclass
|
||||
instances. Directories that are clearly not user source code (virtual
|
||||
environments, ``.git``, etc.) are pruned to avoid noisy import errors.
|
||||
instances.
|
||||
|
||||
Args:
|
||||
flow_path: Filename to search for (default ``main.py``).
|
||||
@@ -427,7 +260,7 @@ def get_flows(flow_path: str = "main.py") -> list[Flow]:
|
||||
Returns:
|
||||
A list of discovered Flow instances.
|
||||
"""
|
||||
flow_instances: list[Flow] = []
|
||||
flow_instances: list[Flow[Any]] = []
|
||||
try:
|
||||
current_dir = os.getcwd()
|
||||
if current_dir not in sys.path:
|
||||
@@ -486,6 +319,7 @@ def get_flows(flow_path: str = "main.py") -> list[Flow]:
|
||||
|
||||
|
||||
def is_valid_tool(obj: Any) -> bool:
|
||||
"""Check if an object is a valid CrewAI tool."""
|
||||
from crewai.tools.base_tool import Tool
|
||||
|
||||
if isclass(obj):
|
||||
@@ -498,12 +332,12 @@ def is_valid_tool(obj: Any) -> bool:
|
||||
|
||||
|
||||
def extract_available_exports(dir_path: str = "src") -> list[dict[str, Any]]:
|
||||
"""
|
||||
Extract available tool classes from the project's __init__.py files.
|
||||
"""Extract available tool classes from the project's __init__.py files.
|
||||
|
||||
Only includes classes that inherit from BaseTool or functions decorated with @tool.
|
||||
|
||||
Returns:
|
||||
list: A list of valid tool class names or ["BaseTool"] if none found
|
||||
A list of valid tool class names or ["BaseTool"] if none found.
|
||||
"""
|
||||
try:
|
||||
init_files = Path(dir_path).glob("**/__init__.py")
|
||||
@@ -530,6 +364,7 @@ def extract_available_exports(dir_path: str = "src") -> list[dict[str, Any]]:
|
||||
def build_env_with_tool_repository_credentials(
|
||||
repository_handle: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Build environment variables with tool repository credentials."""
|
||||
repository_handle = repository_handle.upper().replace("-", "_")
|
||||
settings = Settings()
|
||||
|
||||
@@ -545,9 +380,7 @@ def build_env_with_tool_repository_credentials(
|
||||
|
||||
|
||||
def _load_tools_from_init(init_file: Path) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Load and validate tools from a given __init__.py file.
|
||||
"""
|
||||
"""Load and validate tools from a given __init__.py file."""
|
||||
spec = importlib.util.spec_from_file_location("temp_module", init_file)
|
||||
|
||||
if not spec or not spec.loader:
|
||||
@@ -583,9 +416,7 @@ def _load_tools_from_init(init_file: Path) -> list[dict[str, Any]]:
|
||||
|
||||
|
||||
def _print_no_tools_warning() -> None:
|
||||
"""
|
||||
Display warning and usage instructions if no tools were found.
|
||||
"""
|
||||
"""Display warning and usage instructions if no tools were found."""
|
||||
console.print(
|
||||
"\n[bold yellow]Warning: No valid tools were exposed in your __init__.py file![/bold yellow]"
|
||||
)
|
||||
@@ -1,12 +1,15 @@
|
||||
"""Memory reset utilities for CrewAI crews and flows."""
|
||||
|
||||
import subprocess
|
||||
from typing import Any
|
||||
|
||||
import click
|
||||
|
||||
from crewai.cli.utils import get_crews, get_flows
|
||||
from crewai.flow import Flow
|
||||
from crewai.utilities.project_utils import get_crews, get_flows
|
||||
|
||||
|
||||
def _reset_flow_memory(flow: Flow) -> None:
|
||||
def _reset_flow_memory(flow: Flow[Any]) -> None:
|
||||
"""Reset memory for a single flow instance.
|
||||
|
||||
Handles Memory, MemoryScope (both have .reset()), and MemorySlice
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Version utilities for CrewAI CLI."""
|
||||
"""Version utilities for CrewAI."""
|
||||
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime, timedelta
|
||||
@@ -26,7 +26,7 @@ def _get_cache_file() -> Path:
|
||||
|
||||
|
||||
def get_crewai_version() -> str:
|
||||
"""Get the version number of CrewAI running the CLI."""
|
||||
"""Get the version number of the installed CrewAI package."""
|
||||
return importlib.metadata.version("crewai")
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from unittest import mock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from crewai.agents.crew_agent_executor import AgentFinish, CrewAgentExecutor
|
||||
from crewai.cli.constants import DEFAULT_LLM_MODEL
|
||||
from crewai.constants import DEFAULT_LLM_MODEL
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.tool_usage_events import ToolUsageFinishedEvent
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
@@ -2046,12 +2046,12 @@ def test_get_knowledge_search_query():
|
||||
@pytest.fixture
|
||||
def mock_get_auth_token():
|
||||
with patch(
|
||||
"crewai.cli.authentication.token.get_auth_token", return_value="test_token"
|
||||
"crewai.auth.token.get_auth_token", return_value="test_token"
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI.get_agent")
|
||||
@patch("crewai.plus_api.PlusAPI.get_agent")
|
||||
def test_agent_from_repository(mock_get_agent, mock_get_auth_token):
|
||||
from crewai_tools import (
|
||||
FileReadTool,
|
||||
@@ -2092,7 +2092,7 @@ def test_agent_from_repository(mock_get_agent, mock_get_auth_token):
|
||||
assert agent.tools[1].file_path == "test.txt"
|
||||
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI.get_agent")
|
||||
@patch("crewai.plus_api.PlusAPI.get_agent")
|
||||
def test_agent_from_repository_override_attributes(mock_get_agent, mock_get_auth_token):
|
||||
from crewai_tools import SerperDevTool
|
||||
|
||||
@@ -2116,7 +2116,7 @@ def test_agent_from_repository_override_attributes(mock_get_agent, mock_get_auth
|
||||
assert isinstance(agent.tools[0], SerperDevTool)
|
||||
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI.get_agent")
|
||||
@patch("crewai.plus_api.PlusAPI.get_agent")
|
||||
def test_agent_from_repository_with_invalid_tools(mock_get_agent, mock_get_auth_token):
|
||||
mock_get_response = MagicMock()
|
||||
mock_get_response.status_code = 200
|
||||
@@ -2139,7 +2139,7 @@ def test_agent_from_repository_with_invalid_tools(mock_get_agent, mock_get_auth_
|
||||
Agent(from_repository="test_agent")
|
||||
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI.get_agent")
|
||||
@patch("crewai.plus_api.PlusAPI.get_agent")
|
||||
def test_agent_from_repository_internal_error(mock_get_agent, mock_get_auth_token):
|
||||
mock_get_response = MagicMock()
|
||||
mock_get_response.status_code = 500
|
||||
@@ -2152,7 +2152,7 @@ def test_agent_from_repository_internal_error(mock_get_agent, mock_get_auth_toke
|
||||
Agent(from_repository="test_agent")
|
||||
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI.get_agent")
|
||||
@patch("crewai.plus_api.PlusAPI.get_agent")
|
||||
def test_agent_from_repository_agent_not_found(mock_get_agent, mock_get_auth_token):
|
||||
mock_get_response = MagicMock()
|
||||
mock_get_response.status_code = 404
|
||||
@@ -2165,7 +2165,7 @@ def test_agent_from_repository_agent_not_found(mock_get_agent, mock_get_auth_tok
|
||||
Agent(from_repository="test_agent")
|
||||
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI.get_agent")
|
||||
@patch("crewai.plus_api.PlusAPI.get_agent")
|
||||
@patch("crewai.utilities.agent_utils.Settings")
|
||||
@patch("crewai.utilities.agent_utils.console")
|
||||
def test_agent_from_repository_displays_org_info(
|
||||
@@ -2198,7 +2198,7 @@ def test_agent_from_repository_displays_org_info(
|
||||
assert agent.backstory == "test backstory"
|
||||
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI.get_agent")
|
||||
@patch("crewai.plus_api.PlusAPI.get_agent")
|
||||
@patch("crewai.utilities.agent_utils.Settings")
|
||||
@patch("crewai.utilities.agent_utils.console")
|
||||
def test_agent_from_repository_without_org_set(
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
from crewai.cli.authentication.main import Oauth2Settings
|
||||
from crewai.cli.authentication.providers.auth0 import Auth0Provider
|
||||
from crewai.auth.oauth2 import Oauth2Settings
|
||||
from crewai.auth.providers.auth0 import Auth0Provider
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
|
||||
from crewai.cli.authentication.main import Oauth2Settings
|
||||
from crewai.cli.authentication.providers.entra_id import EntraIdProvider
|
||||
from crewai.auth.oauth2 import Oauth2Settings
|
||||
from crewai.auth.providers.entra_id import EntraIdProvider
|
||||
|
||||
|
||||
class TestEntraIdProvider:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
|
||||
from crewai.cli.authentication.main import Oauth2Settings
|
||||
from crewai.cli.authentication.providers.keycloak import KeycloakProvider
|
||||
from crewai.auth.oauth2 import Oauth2Settings
|
||||
from crewai.auth.providers.keycloak import KeycloakProvider
|
||||
|
||||
|
||||
class TestKeycloakProvider:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
|
||||
from crewai.cli.authentication.main import Oauth2Settings
|
||||
from crewai.cli.authentication.providers.okta import OktaProvider
|
||||
from crewai.auth.oauth2 import Oauth2Settings
|
||||
from crewai.auth.providers.okta import OktaProvider
|
||||
|
||||
|
||||
class TestOktaProvider:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
from crewai.cli.authentication.main import Oauth2Settings
|
||||
from crewai.cli.authentication.providers.workos import WorkosProvider
|
||||
from crewai.auth.oauth2 import Oauth2Settings
|
||||
from crewai.auth.providers.workos import WorkosProvider
|
||||
|
||||
|
||||
class TestWorkosProvider:
|
||||
|
||||
@@ -3,8 +3,8 @@ from unittest.mock import MagicMock, call, patch
|
||||
|
||||
import pytest
|
||||
import httpx
|
||||
from crewai.cli.authentication.main import AuthenticationCommand
|
||||
from crewai.cli.constants import (
|
||||
from crewai.auth.oauth2 import AuthenticationCommand
|
||||
from crewai.constants import (
|
||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
|
||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
|
||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
|
||||
@@ -14,7 +14,7 @@ from crewai.cli.constants import (
|
||||
class TestAuthenticationCommand:
|
||||
def setup_method(self):
|
||||
# Mock Settings so we always use default constants regardless of local config.
|
||||
with patch("crewai.cli.authentication.main.Settings") as mock_settings:
|
||||
with patch("crewai.auth.oauth2.Settings") as mock_settings:
|
||||
instance = mock_settings.return_value
|
||||
instance.oauth2_provider = "workos"
|
||||
instance.oauth2_domain = CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN
|
||||
@@ -38,12 +38,12 @@ class TestAuthenticationCommand:
|
||||
),
|
||||
],
|
||||
)
|
||||
@patch("crewai.cli.authentication.main.AuthenticationCommand._get_device_code")
|
||||
@patch("crewai.auth.oauth2.AuthenticationCommand._get_device_code")
|
||||
@patch(
|
||||
"crewai.cli.authentication.main.AuthenticationCommand._display_auth_instructions"
|
||||
"crewai.auth.oauth2.AuthenticationCommand._display_auth_instructions"
|
||||
)
|
||||
@patch("crewai.cli.authentication.main.AuthenticationCommand._poll_for_token")
|
||||
@patch("crewai.cli.authentication.main.console.print")
|
||||
@patch("crewai.auth.oauth2.AuthenticationCommand._poll_for_token")
|
||||
@patch("crewai.auth.oauth2.console.print")
|
||||
def test_login(
|
||||
self,
|
||||
mock_console_print,
|
||||
@@ -82,8 +82,8 @@ class TestAuthenticationCommand:
|
||||
self.auth_command.oauth2_provider._get_domain() == expected_urls["domain"]
|
||||
)
|
||||
|
||||
@patch("crewai.cli.authentication.main.webbrowser")
|
||||
@patch("crewai.cli.authentication.main.console.print")
|
||||
@patch("crewai.auth.oauth2.webbrowser")
|
||||
@patch("crewai.auth.oauth2.console.print")
|
||||
def test_display_auth_instructions(self, mock_console_print, mock_webbrowser):
|
||||
device_code_data = {
|
||||
"verification_uri_complete": "https://example.com/auth",
|
||||
@@ -113,8 +113,8 @@ class TestAuthenticationCommand:
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("has_expiration", [True, False])
|
||||
@patch("crewai.cli.authentication.main.validate_jwt_token")
|
||||
@patch("crewai.cli.authentication.main.TokenManager.save_tokens")
|
||||
@patch("crewai.auth.oauth2.validate_jwt_token")
|
||||
@patch("crewai.auth.oauth2.TokenManager.save_tokens")
|
||||
def test_validate_and_save_token(
|
||||
self,
|
||||
mock_save_tokens,
|
||||
@@ -123,8 +123,8 @@ class TestAuthenticationCommand:
|
||||
jwt_config,
|
||||
has_expiration,
|
||||
):
|
||||
from crewai.cli.authentication.main import Oauth2Settings
|
||||
from crewai.cli.authentication.providers.workos import WorkosProvider
|
||||
from crewai.auth.oauth2 import Oauth2Settings
|
||||
from crewai.auth.providers.workos import WorkosProvider
|
||||
|
||||
if user_provider == "workos":
|
||||
self.auth_command.oauth2_provider = WorkosProvider(
|
||||
@@ -163,8 +163,8 @@ class TestAuthenticationCommand:
|
||||
mock_save_tokens.assert_called_once_with("test_access_token", 0)
|
||||
|
||||
@patch("crewai_cli.tools.main.ToolCommand")
|
||||
@patch("crewai.cli.authentication.main.Settings")
|
||||
@patch("crewai.cli.authentication.main.console.print")
|
||||
@patch("crewai.auth.oauth2.Settings")
|
||||
@patch("crewai.auth.oauth2.console.print")
|
||||
def test_login_to_tool_repository_success(
|
||||
self, mock_console_print, mock_settings, mock_tool_command
|
||||
):
|
||||
@@ -196,7 +196,7 @@ class TestAuthenticationCommand:
|
||||
mock_console_print.assert_has_calls(expected_calls)
|
||||
|
||||
@patch("crewai_cli.tools.main.ToolCommand")
|
||||
@patch("crewai.cli.authentication.main.console.print")
|
||||
@patch("crewai.auth.oauth2.console.print")
|
||||
def test_login_to_tool_repository_error(
|
||||
self, mock_console_print, mock_tool_command
|
||||
):
|
||||
@@ -226,7 +226,7 @@ class TestAuthenticationCommand:
|
||||
]
|
||||
mock_console_print.assert_has_calls(expected_calls)
|
||||
|
||||
@patch("crewai.cli.authentication.main.httpx.post")
|
||||
@patch("crewai.auth.oauth2.httpx.post")
|
||||
def test_get_device_code(self, mock_post):
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
@@ -262,8 +262,8 @@ class TestAuthenticationCommand:
|
||||
"verification_uri_complete": "https://example.com/auth",
|
||||
}
|
||||
|
||||
@patch("crewai.cli.authentication.main.httpx.post")
|
||||
@patch("crewai.cli.authentication.main.console.print")
|
||||
@patch("crewai.auth.oauth2.httpx.post")
|
||||
@patch("crewai.auth.oauth2.console.print")
|
||||
def test_poll_for_token_success(self, mock_console_print, mock_post):
|
||||
mock_response_success = MagicMock()
|
||||
mock_response_success.status_code = 200
|
||||
@@ -311,8 +311,8 @@ class TestAuthenticationCommand:
|
||||
]
|
||||
mock_console_print.assert_has_calls(expected_calls)
|
||||
|
||||
@patch("crewai.cli.authentication.main.httpx.post")
|
||||
@patch("crewai.cli.authentication.main.console.print")
|
||||
@patch("crewai.auth.oauth2.httpx.post")
|
||||
@patch("crewai.auth.oauth2.console.print")
|
||||
def test_poll_for_token_timeout(self, mock_console_print, mock_post):
|
||||
mock_response_pending = MagicMock()
|
||||
mock_response_pending.status_code = 400
|
||||
@@ -330,7 +330,7 @@ class TestAuthenticationCommand:
|
||||
"Timeout: Failed to get the token. Please try again.", style="bold red"
|
||||
)
|
||||
|
||||
@patch("crewai.cli.authentication.main.httpx.post")
|
||||
@patch("crewai.auth.oauth2.httpx.post")
|
||||
def test_poll_for_token_error(self, mock_post):
|
||||
"""Test the method to poll for token (error path)."""
|
||||
# Setup mock to return error
|
||||
|
||||
@@ -3,11 +3,11 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import jwt
|
||||
|
||||
from crewai.cli.authentication.utils import validate_jwt_token
|
||||
from crewai.auth.utils import validate_jwt_token
|
||||
|
||||
|
||||
@patch("crewai.cli.authentication.utils.PyJWKClient", return_value=MagicMock())
|
||||
@patch("crewai.cli.authentication.utils.jwt")
|
||||
@patch("crewai.auth.utils.PyJWKClient", return_value=MagicMock())
|
||||
@patch("crewai.auth.utils.jwt")
|
||||
class TestUtils(unittest.TestCase):
|
||||
def test_validate_jwt_token(self, mock_jwt, mock_pyjwkclient):
|
||||
mock_jwt.decode.return_value = {"exp": 1719859200}
|
||||
|
||||
@@ -27,9 +27,9 @@ def mock_crew():
|
||||
@pytest.fixture
|
||||
def mock_get_crews(mock_crew):
|
||||
with mock.patch(
|
||||
"crewai.cli.reset_memories_command.get_crews", return_value=[mock_crew]
|
||||
"crewai.utilities.reset_memories.get_crews", return_value=[mock_crew]
|
||||
) as mock_get_crew, mock.patch(
|
||||
"crewai.cli.reset_memories_command.get_flows", return_value=[]
|
||||
"crewai.utilities.reset_memories.get_flows", return_value=[]
|
||||
):
|
||||
yield mock_get_crew
|
||||
|
||||
@@ -169,9 +169,9 @@ def mock_flow():
|
||||
@pytest.fixture
|
||||
def mock_get_flows(mock_flow):
|
||||
with mock.patch(
|
||||
"crewai.cli.reset_memories_command.get_flows", return_value=[mock_flow]
|
||||
"crewai.utilities.reset_memories.get_flows", return_value=[mock_flow]
|
||||
) as mock_get_flow, mock.patch(
|
||||
"crewai.cli.reset_memories_command.get_crews", return_value=[]
|
||||
"crewai.utilities.reset_memories.get_crews", return_value=[]
|
||||
):
|
||||
yield mock_get_flow
|
||||
|
||||
@@ -196,9 +196,9 @@ def test_reset_flow_knowledge_no_effect(mock_get_flows, mock_flow, runner):
|
||||
|
||||
def test_reset_no_crew_or_flow_found(runner):
|
||||
with mock.patch(
|
||||
"crewai.cli.reset_memories_command.get_crews", return_value=[]
|
||||
"crewai.utilities.reset_memories.get_crews", return_value=[]
|
||||
), mock.patch(
|
||||
"crewai.cli.reset_memories_command.get_flows", return_value=[]
|
||||
"crewai.utilities.reset_memories.get_flows", return_value=[]
|
||||
):
|
||||
result = runner.invoke(reset_memories, ["-m"])
|
||||
assert "No crew or flow found." in result.output
|
||||
@@ -206,9 +206,9 @@ def test_reset_no_crew_or_flow_found(runner):
|
||||
|
||||
def test_reset_crew_and_flow_memory(mock_crew, mock_flow, runner):
|
||||
with mock.patch(
|
||||
"crewai.cli.reset_memories_command.get_crews", return_value=[mock_crew]
|
||||
"crewai.utilities.reset_memories.get_crews", return_value=[mock_crew]
|
||||
), mock.patch(
|
||||
"crewai.cli.reset_memories_command.get_flows", return_value=[mock_flow]
|
||||
"crewai.utilities.reset_memories.get_flows", return_value=[mock_flow]
|
||||
):
|
||||
result = runner.invoke(reset_memories, ["-m"])
|
||||
mock_crew.reset_memories.assert_called_once_with(command_type="memory")
|
||||
@@ -222,9 +222,9 @@ def test_reset_flow_memory_none(runner):
|
||||
mock_flow.name = "NoMemFlow"
|
||||
mock_flow.memory = None
|
||||
with mock.patch(
|
||||
"crewai.cli.reset_memories_command.get_crews", return_value=[]
|
||||
"crewai.utilities.reset_memories.get_crews", return_value=[]
|
||||
), mock.patch(
|
||||
"crewai.cli.reset_memories_command.get_flows", return_value=[mock_flow]
|
||||
"crewai.utilities.reset_memories.get_flows", return_value=[mock_flow]
|
||||
):
|
||||
result = runner.invoke(reset_memories, ["-m"])
|
||||
assert "[Flow (NoMemFlow)] Memory has been reset." in result.output
|
||||
|
||||
@@ -6,13 +6,13 @@ from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from crewai.cli.config import (
|
||||
from crewai.settings import (
|
||||
CLI_SETTINGS_KEYS,
|
||||
DEFAULT_CLI_SETTINGS,
|
||||
USER_SETTINGS_KEYS,
|
||||
Settings,
|
||||
)
|
||||
from crewai.cli.shared.token_manager import TokenManager
|
||||
from crewai.auth.token_manager import TokenManager
|
||||
|
||||
|
||||
class TestSettings(unittest.TestCase):
|
||||
@@ -69,7 +69,7 @@ class TestSettings(unittest.TestCase):
|
||||
for key in user_settings.keys():
|
||||
self.assertEqual(getattr(settings, key), None)
|
||||
|
||||
@patch("crewai.cli.config.TokenManager")
|
||||
@patch("crewai.settings.TokenManager")
|
||||
def test_reset_settings(self, mock_token_manager):
|
||||
user_settings = {key: f"value_for_{key}" for key in USER_SETTINGS_KEYS}
|
||||
cli_settings = {key: f"value_for_{key}" for key in CLI_SETTINGS_KEYS if key != "oauth2_extra"}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from crewai.cli.constants import ENV_VARS, MODELS, PROVIDERS
|
||||
from crewai.constants import ENV_VARS, MODELS, PROVIDERS
|
||||
|
||||
|
||||
def test_huggingface_in_providers():
|
||||
|
||||
@@ -1,101 +0,0 @@
|
||||
import pytest
|
||||
from crewai.cli.git import Repository
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def repository(fp):
|
||||
fp.register(["git", "--version"], stdout="git version 2.30.0\n")
|
||||
fp.register(["git", "rev-parse", "--is-inside-work-tree"], stdout="true\n")
|
||||
fp.register(["git", "fetch"], stdout="")
|
||||
return Repository(path=".")
|
||||
|
||||
|
||||
def test_init_with_invalid_git_repo(fp):
|
||||
fp.register(["git", "--version"], stdout="git version 2.30.0\n")
|
||||
fp.register(
|
||||
["git", "rev-parse", "--is-inside-work-tree"],
|
||||
returncode=1,
|
||||
stderr="fatal: not a git repository\n",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
Repository(path="invalid/path")
|
||||
|
||||
|
||||
def test_is_git_not_installed(fp):
|
||||
fp.register(["git", "--version"], returncode=1)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="Git is not installed or not found in your PATH."
|
||||
):
|
||||
Repository(path=".")
|
||||
|
||||
|
||||
def test_status(fp, repository):
|
||||
fp.register(
|
||||
["git", "status", "--branch", "--porcelain"],
|
||||
stdout="## main...origin/main [ahead 1]\n",
|
||||
)
|
||||
assert repository.status() == "## main...origin/main [ahead 1]"
|
||||
|
||||
|
||||
def test_has_uncommitted_changes(fp, repository):
|
||||
fp.register(
|
||||
["git", "status", "--branch", "--porcelain"],
|
||||
stdout="## main...origin/main\n M somefile.txt\n",
|
||||
)
|
||||
assert repository.has_uncommitted_changes() is True
|
||||
|
||||
|
||||
def test_is_ahead_or_behind(fp, repository):
|
||||
fp.register(
|
||||
["git", "status", "--branch", "--porcelain"],
|
||||
stdout="## main...origin/main [ahead 1]\n",
|
||||
)
|
||||
assert repository.is_ahead_or_behind() is True
|
||||
|
||||
|
||||
def test_is_synced_when_synced(fp, repository):
|
||||
fp.register(
|
||||
["git", "status", "--branch", "--porcelain"], stdout="## main...origin/main\n"
|
||||
)
|
||||
fp.register(
|
||||
["git", "status", "--branch", "--porcelain"], stdout="## main...origin/main\n"
|
||||
)
|
||||
assert repository.is_synced() is True
|
||||
|
||||
|
||||
def test_is_synced_with_uncommitted_changes(fp, repository):
|
||||
fp.register(
|
||||
["git", "status", "--branch", "--porcelain"],
|
||||
stdout="## main...origin/main\n M somefile.txt\n",
|
||||
)
|
||||
assert repository.is_synced() is False
|
||||
|
||||
|
||||
def test_is_synced_when_ahead_or_behind(fp, repository):
|
||||
fp.register(
|
||||
["git", "status", "--branch", "--porcelain"],
|
||||
stdout="## main...origin/main [ahead 1]\n",
|
||||
)
|
||||
fp.register(
|
||||
["git", "status", "--branch", "--porcelain"],
|
||||
stdout="## main...origin/main [ahead 1]\n",
|
||||
)
|
||||
assert repository.is_synced() is False
|
||||
|
||||
|
||||
def test_is_synced_with_uncommitted_changes_and_ahead(fp, repository):
|
||||
fp.register(
|
||||
["git", "status", "--branch", "--porcelain"],
|
||||
stdout="## main...origin/main [ahead 1]\n M somefile.txt\n",
|
||||
)
|
||||
assert repository.is_synced() is False
|
||||
|
||||
|
||||
def test_origin_url(fp, repository):
|
||||
fp.register(
|
||||
["git", "remote", "get-url", "origin"],
|
||||
stdout="https://github.com/user/repo.git\n",
|
||||
)
|
||||
assert repository.origin_url() == "https://github.com/user/repo.git"
|
||||
@@ -4,7 +4,7 @@ from unittest.mock import ANY, AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.cli.plus_api import PlusAPI
|
||||
from crewai.plus_api import PlusAPI
|
||||
|
||||
|
||||
class TestPlusAPI(unittest.TestCase):
|
||||
@@ -20,7 +20,7 @@ class TestPlusAPI(unittest.TestCase):
|
||||
self.assertTrue("CrewAI-CLI/" in self.api.headers["User-Agent"])
|
||||
self.assertTrue(self.api.headers["X-Crewai-Version"])
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
@patch("crewai.plus_api.PlusAPI._make_request")
|
||||
def test_login_to_tool_repository(self, mock_make_request):
|
||||
mock_response = MagicMock()
|
||||
mock_make_request.return_value = mock_response
|
||||
@@ -32,7 +32,7 @@ class TestPlusAPI(unittest.TestCase):
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
@patch("crewai.plus_api.PlusAPI._make_request")
|
||||
def test_login_to_tool_repository_with_user_identifier(self, mock_make_request):
|
||||
mock_response = MagicMock()
|
||||
mock_make_request.return_value = mock_response
|
||||
@@ -60,8 +60,8 @@ class TestPlusAPI(unittest.TestCase):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@patch("crewai.cli.plus_api.Settings")
|
||||
@patch("crewai.cli.plus_api.httpx.Client")
|
||||
@patch("crewai.plus_api.Settings")
|
||||
@patch("crewai.plus_api.httpx.Client")
|
||||
def test_login_to_tool_repository_with_org_uuid(
|
||||
self, mock_client_class, mock_settings_class
|
||||
):
|
||||
@@ -83,7 +83,7 @@ class TestPlusAPI(unittest.TestCase):
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
@patch("crewai.plus_api.PlusAPI._make_request")
|
||||
def test_get_tool(self, mock_make_request):
|
||||
mock_response = MagicMock()
|
||||
mock_make_request.return_value = mock_response
|
||||
@@ -94,8 +94,8 @@ class TestPlusAPI(unittest.TestCase):
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
@patch("crewai.cli.plus_api.Settings")
|
||||
@patch("crewai.cli.plus_api.httpx.Client")
|
||||
@patch("crewai.plus_api.Settings")
|
||||
@patch("crewai.plus_api.httpx.Client")
|
||||
def test_get_tool_with_org_uuid(self, mock_client_class, mock_settings_class):
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.org_uuid = self.org_uuid
|
||||
@@ -115,7 +115,7 @@ class TestPlusAPI(unittest.TestCase):
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
@patch("crewai.plus_api.PlusAPI._make_request")
|
||||
def test_publish_tool(self, mock_make_request):
|
||||
mock_response = MagicMock()
|
||||
mock_make_request.return_value = mock_response
|
||||
@@ -142,8 +142,8 @@ class TestPlusAPI(unittest.TestCase):
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
@patch("crewai.cli.plus_api.Settings")
|
||||
@patch("crewai.cli.plus_api.httpx.Client")
|
||||
@patch("crewai.plus_api.Settings")
|
||||
@patch("crewai.plus_api.httpx.Client")
|
||||
def test_publish_tool_with_org_uuid(self, mock_client_class, mock_settings_class):
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.org_uuid = self.org_uuid
|
||||
@@ -180,7 +180,7 @@ class TestPlusAPI(unittest.TestCase):
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
@patch("crewai.plus_api.PlusAPI._make_request")
|
||||
def test_publish_tool_without_description(self, mock_make_request):
|
||||
mock_response = MagicMock()
|
||||
mock_make_request.return_value = mock_response
|
||||
@@ -207,7 +207,7 @@ class TestPlusAPI(unittest.TestCase):
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
@patch("crewai.cli.plus_api.httpx.Client")
|
||||
@patch("crewai.plus_api.httpx.Client")
|
||||
def test_make_request(self, mock_client_class):
|
||||
mock_client_instance = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
@@ -222,35 +222,35 @@ class TestPlusAPI(unittest.TestCase):
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
@patch("crewai.plus_api.PlusAPI._make_request")
|
||||
def test_deploy_by_name(self, mock_make_request):
|
||||
self.api.deploy_by_name("test_project")
|
||||
mock_make_request.assert_called_once_with(
|
||||
"POST", "/crewai_plus/api/v1/crews/by-name/test_project/deploy"
|
||||
)
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
@patch("crewai.plus_api.PlusAPI._make_request")
|
||||
def test_deploy_by_uuid(self, mock_make_request):
|
||||
self.api.deploy_by_uuid("test_uuid")
|
||||
mock_make_request.assert_called_once_with(
|
||||
"POST", "/crewai_plus/api/v1/crews/test_uuid/deploy"
|
||||
)
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
@patch("crewai.plus_api.PlusAPI._make_request")
|
||||
def test_crew_status_by_name(self, mock_make_request):
|
||||
self.api.crew_status_by_name("test_project")
|
||||
mock_make_request.assert_called_once_with(
|
||||
"GET", "/crewai_plus/api/v1/crews/by-name/test_project/status"
|
||||
)
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
@patch("crewai.plus_api.PlusAPI._make_request")
|
||||
def test_crew_status_by_uuid(self, mock_make_request):
|
||||
self.api.crew_status_by_uuid("test_uuid")
|
||||
mock_make_request.assert_called_once_with(
|
||||
"GET", "/crewai_plus/api/v1/crews/test_uuid/status"
|
||||
)
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
@patch("crewai.plus_api.PlusAPI._make_request")
|
||||
def test_crew_by_name(self, mock_make_request):
|
||||
self.api.crew_by_name("test_project")
|
||||
mock_make_request.assert_called_once_with(
|
||||
@@ -262,7 +262,7 @@ class TestPlusAPI(unittest.TestCase):
|
||||
"GET", "/crewai_plus/api/v1/crews/by-name/test_project/logs/custom_log"
|
||||
)
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
@patch("crewai.plus_api.PlusAPI._make_request")
|
||||
def test_crew_by_uuid(self, mock_make_request):
|
||||
self.api.crew_by_uuid("test_uuid")
|
||||
mock_make_request.assert_called_once_with(
|
||||
@@ -274,26 +274,26 @@ class TestPlusAPI(unittest.TestCase):
|
||||
"GET", "/crewai_plus/api/v1/crews/test_uuid/logs/custom_log"
|
||||
)
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
@patch("crewai.plus_api.PlusAPI._make_request")
|
||||
def test_delete_crew_by_name(self, mock_make_request):
|
||||
self.api.delete_crew_by_name("test_project")
|
||||
mock_make_request.assert_called_once_with(
|
||||
"DELETE", "/crewai_plus/api/v1/crews/by-name/test_project"
|
||||
)
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
@patch("crewai.plus_api.PlusAPI._make_request")
|
||||
def test_delete_crew_by_uuid(self, mock_make_request):
|
||||
self.api.delete_crew_by_uuid("test_uuid")
|
||||
mock_make_request.assert_called_once_with(
|
||||
"DELETE", "/crewai_plus/api/v1/crews/test_uuid"
|
||||
)
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
@patch("crewai.plus_api.PlusAPI._make_request")
|
||||
def test_list_crews(self, mock_make_request):
|
||||
self.api.list_crews()
|
||||
mock_make_request.assert_called_once_with("GET", "/crewai_plus/api/v1/crews")
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
@patch("crewai.plus_api.PlusAPI._make_request")
|
||||
def test_create_crew(self, mock_make_request):
|
||||
payload = {"name": "test_crew"}
|
||||
self.api.create_crew(payload)
|
||||
@@ -301,7 +301,7 @@ class TestPlusAPI(unittest.TestCase):
|
||||
"POST", "/crewai_plus/api/v1/crews", json=payload
|
||||
)
|
||||
|
||||
@patch("crewai.cli.plus_api.Settings")
|
||||
@patch("crewai.plus_api.Settings")
|
||||
@patch.dict(os.environ, {"CREWAI_PLUS_URL": ""})
|
||||
def test_custom_base_url(self, mock_settings_class):
|
||||
mock_settings = MagicMock()
|
||||
@@ -342,7 +342,7 @@ async def test_get_agent(mock_async_client_class):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("httpx.AsyncClient")
|
||||
@patch("crewai.cli.plus_api.Settings")
|
||||
@patch("crewai.plus_api.Settings")
|
||||
async def test_get_agent_with_org_uuid(mock_settings_class, mock_async_client_class):
|
||||
org_uuid = "test-org-uuid"
|
||||
mock_settings = MagicMock()
|
||||
|
||||
@@ -10,20 +10,20 @@ from unittest.mock import patch
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
from crewai.cli.shared.token_manager import TokenManager
|
||||
from crewai.auth.token_manager import TokenManager
|
||||
|
||||
|
||||
class TestTokenManager(unittest.TestCase):
|
||||
"""Test cases for TokenManager."""
|
||||
|
||||
@patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key")
|
||||
@patch("crewai.auth.token_manager.TokenManager._get_or_create_key")
|
||||
def setUp(self, mock_get_key: unittest.mock.MagicMock) -> None:
|
||||
"""Set up test fixtures."""
|
||||
mock_get_key.return_value = Fernet.generate_key()
|
||||
self.token_manager = TokenManager()
|
||||
|
||||
@patch("crewai.cli.shared.token_manager.TokenManager._read_secure_file")
|
||||
@patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key")
|
||||
@patch("crewai.auth.token_manager.TokenManager._read_secure_file")
|
||||
@patch("crewai.auth.token_manager.TokenManager._get_or_create_key")
|
||||
def test_get_or_create_key_existing(
|
||||
self,
|
||||
mock_get_or_create: unittest.mock.MagicMock,
|
||||
@@ -45,7 +45,7 @@ class TestTokenManager(unittest.TestCase):
|
||||
with (
|
||||
patch.object(self.token_manager, "_read_secure_file", return_value=None) as mock_read,
|
||||
patch.object(self.token_manager, "_atomic_create_secure_file", return_value=True) as mock_atomic_create,
|
||||
patch("crewai.cli.shared.token_manager.Fernet.generate_key", return_value=mock_key) as mock_generate,
|
||||
patch("crewai.auth.token_manager.Fernet.generate_key", return_value=mock_key) as mock_generate,
|
||||
):
|
||||
result = self.token_manager._get_or_create_key()
|
||||
|
||||
@@ -62,14 +62,14 @@ class TestTokenManager(unittest.TestCase):
|
||||
with (
|
||||
patch.object(self.token_manager, "_read_secure_file", side_effect=[None, their_key]) as mock_read,
|
||||
patch.object(self.token_manager, "_atomic_create_secure_file", return_value=False) as mock_atomic_create,
|
||||
patch("crewai.cli.shared.token_manager.Fernet.generate_key", return_value=our_key),
|
||||
patch("crewai.auth.token_manager.Fernet.generate_key", return_value=our_key),
|
||||
):
|
||||
result = self.token_manager._get_or_create_key()
|
||||
|
||||
self.assertEqual(result, their_key)
|
||||
self.assertEqual(mock_read.call_count, 2)
|
||||
|
||||
@patch("crewai.cli.shared.token_manager.TokenManager._atomic_write_secure_file")
|
||||
@patch("crewai.auth.token_manager.TokenManager._atomic_write_secure_file")
|
||||
def test_save_tokens(
|
||||
self, mock_write: unittest.mock.MagicMock
|
||||
) -> None:
|
||||
@@ -88,7 +88,7 @@ class TestTokenManager(unittest.TestCase):
|
||||
expiration = datetime.fromisoformat(data["expiration"])
|
||||
self.assertEqual(expiration, datetime.fromtimestamp(expires_at))
|
||||
|
||||
@patch("crewai.cli.shared.token_manager.TokenManager._read_secure_file")
|
||||
@patch("crewai.auth.token_manager.TokenManager._read_secure_file")
|
||||
def test_get_token_valid(
|
||||
self, mock_read: unittest.mock.MagicMock
|
||||
) -> None:
|
||||
@@ -103,7 +103,7 @@ class TestTokenManager(unittest.TestCase):
|
||||
|
||||
self.assertEqual(result, access_token)
|
||||
|
||||
@patch("crewai.cli.shared.token_manager.TokenManager._read_secure_file")
|
||||
@patch("crewai.auth.token_manager.TokenManager._read_secure_file")
|
||||
def test_get_token_expired(
|
||||
self, mock_read: unittest.mock.MagicMock
|
||||
) -> None:
|
||||
@@ -118,7 +118,7 @@ class TestTokenManager(unittest.TestCase):
|
||||
|
||||
self.assertIsNone(result)
|
||||
|
||||
@patch("crewai.cli.shared.token_manager.TokenManager._read_secure_file")
|
||||
@patch("crewai.auth.token_manager.TokenManager._read_secure_file")
|
||||
def test_get_token_not_found(
|
||||
self, mock_read: unittest.mock.MagicMock
|
||||
) -> None:
|
||||
@@ -129,7 +129,7 @@ class TestTokenManager(unittest.TestCase):
|
||||
|
||||
self.assertIsNone(result)
|
||||
|
||||
@patch("crewai.cli.shared.token_manager.TokenManager._delete_secure_file")
|
||||
@patch("crewai.auth.token_manager.TokenManager._delete_secure_file")
|
||||
def test_clear_tokens(
|
||||
self, mock_delete: unittest.mock.MagicMock
|
||||
) -> None:
|
||||
@@ -159,7 +159,7 @@ class TestAtomicFileOperations(unittest.TestCase):
|
||||
import shutil
|
||||
shutil.rmtree(self.temp_dir, ignore_errors=True)
|
||||
|
||||
@patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key")
|
||||
@patch("crewai.auth.token_manager.TokenManager._get_or_create_key")
|
||||
def test_atomic_create_new_file(
|
||||
self, mock_get_key: unittest.mock.MagicMock
|
||||
) -> None:
|
||||
@@ -175,7 +175,7 @@ class TestAtomicFileOperations(unittest.TestCase):
|
||||
self.assertEqual(file_path.read_bytes(), b"content")
|
||||
self.assertEqual(file_path.stat().st_mode & 0o777, 0o600)
|
||||
|
||||
@patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key")
|
||||
@patch("crewai.auth.token_manager.TokenManager._get_or_create_key")
|
||||
def test_atomic_create_existing_file(
|
||||
self, mock_get_key: unittest.mock.MagicMock
|
||||
) -> None:
|
||||
@@ -192,7 +192,7 @@ class TestAtomicFileOperations(unittest.TestCase):
|
||||
self.assertFalse(result)
|
||||
self.assertEqual(file_path.read_bytes(), b"original")
|
||||
|
||||
@patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key")
|
||||
@patch("crewai.auth.token_manager.TokenManager._get_or_create_key")
|
||||
def test_atomic_write_new_file(
|
||||
self, mock_get_key: unittest.mock.MagicMock
|
||||
) -> None:
|
||||
@@ -207,7 +207,7 @@ class TestAtomicFileOperations(unittest.TestCase):
|
||||
self.assertEqual(file_path.read_bytes(), b"content")
|
||||
self.assertEqual(file_path.stat().st_mode & 0o777, 0o600)
|
||||
|
||||
@patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key")
|
||||
@patch("crewai.auth.token_manager.TokenManager._get_or_create_key")
|
||||
def test_atomic_write_overwrites(
|
||||
self, mock_get_key: unittest.mock.MagicMock
|
||||
) -> None:
|
||||
@@ -222,7 +222,7 @@ class TestAtomicFileOperations(unittest.TestCase):
|
||||
|
||||
self.assertEqual(file_path.read_bytes(), b"new content")
|
||||
|
||||
@patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key")
|
||||
@patch("crewai.auth.token_manager.TokenManager._get_or_create_key")
|
||||
def test_atomic_write_no_temp_file_on_success(
|
||||
self, mock_get_key: unittest.mock.MagicMock
|
||||
) -> None:
|
||||
@@ -236,7 +236,7 @@ class TestAtomicFileOperations(unittest.TestCase):
|
||||
temp_files = list(Path(self.temp_dir).glob(".test.txt.*"))
|
||||
self.assertEqual(len(temp_files), 0)
|
||||
|
||||
@patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key")
|
||||
@patch("crewai.auth.token_manager.TokenManager._get_or_create_key")
|
||||
def test_read_secure_file_exists(
|
||||
self, mock_get_key: unittest.mock.MagicMock
|
||||
) -> None:
|
||||
@@ -251,7 +251,7 @@ class TestAtomicFileOperations(unittest.TestCase):
|
||||
|
||||
self.assertEqual(result, b"content")
|
||||
|
||||
@patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key")
|
||||
@patch("crewai.auth.token_manager.TokenManager._get_or_create_key")
|
||||
def test_read_secure_file_not_exists(
|
||||
self, mock_get_key: unittest.mock.MagicMock
|
||||
) -> None:
|
||||
@@ -263,7 +263,7 @@ class TestAtomicFileOperations(unittest.TestCase):
|
||||
|
||||
self.assertIsNone(result)
|
||||
|
||||
@patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key")
|
||||
@patch("crewai.auth.token_manager.TokenManager._get_or_create_key")
|
||||
def test_delete_secure_file_exists(
|
||||
self, mock_get_key: unittest.mock.MagicMock
|
||||
) -> None:
|
||||
@@ -278,7 +278,7 @@ class TestAtomicFileOperations(unittest.TestCase):
|
||||
|
||||
self.assertFalse(file_path.exists())
|
||||
|
||||
@patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key")
|
||||
@patch("crewai.auth.token_manager.TokenManager._get_or_create_key")
|
||||
def test_delete_secure_file_not_exists(
|
||||
self, mock_get_key: unittest.mock.MagicMock
|
||||
) -> None:
|
||||
|
||||
@@ -4,7 +4,7 @@ import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from crewai.cli import utils
|
||||
from crewai.utilities import project_utils as utils
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -6,7 +6,7 @@ from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from crewai import __version__
|
||||
from crewai.cli.version import (
|
||||
from crewai.version import (
|
||||
_find_latest_non_yanked_version,
|
||||
_get_cache_file,
|
||||
_is_cache_valid,
|
||||
@@ -60,8 +60,8 @@ class TestVersionChecking:
|
||||
cache_data = {"version": "1.0.0"}
|
||||
assert _is_cache_valid(cache_data) is False
|
||||
|
||||
@patch("crewai.cli.version.Path.exists")
|
||||
@patch("crewai.cli.version.request.urlopen")
|
||||
@patch("crewai.version.Path.exists")
|
||||
@patch("crewai.version.request.urlopen")
|
||||
def test_get_latest_version_from_pypi_success(
|
||||
self, mock_urlopen: MagicMock, mock_exists: MagicMock
|
||||
) -> None:
|
||||
@@ -82,8 +82,8 @@ class TestVersionChecking:
|
||||
version = get_latest_version_from_pypi()
|
||||
assert version == "2.0.0"
|
||||
|
||||
@patch("crewai.cli.version.Path.exists")
|
||||
@patch("crewai.cli.version.request.urlopen")
|
||||
@patch("crewai.version.Path.exists")
|
||||
@patch("crewai.version.request.urlopen")
|
||||
def test_get_latest_version_from_pypi_failure(
|
||||
self, mock_urlopen: MagicMock, mock_exists: MagicMock
|
||||
) -> None:
|
||||
@@ -97,8 +97,8 @@ class TestVersionChecking:
|
||||
version = get_latest_version_from_pypi()
|
||||
assert version is None
|
||||
|
||||
@patch("crewai.cli.version.get_crewai_version")
|
||||
@patch("crewai.cli.version.get_latest_version_from_pypi")
|
||||
@patch("crewai.version.get_crewai_version")
|
||||
@patch("crewai.version.get_latest_version_from_pypi")
|
||||
def test_is_newer_version_available_true(
|
||||
self, mock_latest: MagicMock, mock_current: MagicMock
|
||||
) -> None:
|
||||
@@ -111,8 +111,8 @@ class TestVersionChecking:
|
||||
assert current == "1.0.0"
|
||||
assert latest == "2.0.0"
|
||||
|
||||
@patch("crewai.cli.version.get_crewai_version")
|
||||
@patch("crewai.cli.version.get_latest_version_from_pypi")
|
||||
@patch("crewai.version.get_crewai_version")
|
||||
@patch("crewai.version.get_latest_version_from_pypi")
|
||||
def test_is_newer_version_available_false(
|
||||
self, mock_latest: MagicMock, mock_current: MagicMock
|
||||
) -> None:
|
||||
@@ -125,8 +125,8 @@ class TestVersionChecking:
|
||||
assert current == "2.0.0"
|
||||
assert latest == "2.0.0"
|
||||
|
||||
@patch("crewai.cli.version.get_crewai_version")
|
||||
@patch("crewai.cli.version.get_latest_version_from_pypi")
|
||||
@patch("crewai.version.get_crewai_version")
|
||||
@patch("crewai.version.get_latest_version_from_pypi")
|
||||
def test_is_newer_version_available_with_none_latest(
|
||||
self, mock_latest: MagicMock, mock_current: MagicMock
|
||||
) -> None:
|
||||
@@ -260,8 +260,8 @@ class TestIsVersionYanked:
|
||||
class TestIsCurrentVersionYanked:
|
||||
"""Test is_current_version_yanked public function."""
|
||||
|
||||
@patch("crewai.cli.version.get_crewai_version")
|
||||
@patch("crewai.cli.version._get_cache_file")
|
||||
@patch("crewai.version.get_crewai_version")
|
||||
@patch("crewai.version._get_cache_file")
|
||||
def test_reads_from_valid_cache(
|
||||
self, mock_cache_file: MagicMock, mock_version: MagicMock, tmp_path: Path
|
||||
) -> None:
|
||||
@@ -282,8 +282,8 @@ class TestIsCurrentVersionYanked:
|
||||
assert is_yanked is True
|
||||
assert reason == "bad release"
|
||||
|
||||
@patch("crewai.cli.version.get_crewai_version")
|
||||
@patch("crewai.cli.version._get_cache_file")
|
||||
@patch("crewai.version.get_crewai_version")
|
||||
@patch("crewai.version._get_cache_file")
|
||||
def test_not_yanked_from_cache(
|
||||
self, mock_cache_file: MagicMock, mock_version: MagicMock, tmp_path: Path
|
||||
) -> None:
|
||||
@@ -304,9 +304,9 @@ class TestIsCurrentVersionYanked:
|
||||
assert is_yanked is False
|
||||
assert reason == ""
|
||||
|
||||
@patch("crewai.cli.version.get_latest_version_from_pypi")
|
||||
@patch("crewai.cli.version.get_crewai_version")
|
||||
@patch("crewai.cli.version._get_cache_file")
|
||||
@patch("crewai.version.get_latest_version_from_pypi")
|
||||
@patch("crewai.version.get_crewai_version")
|
||||
@patch("crewai.version._get_cache_file")
|
||||
def test_triggers_fetch_on_stale_cache(
|
||||
self,
|
||||
mock_cache_file: MagicMock,
|
||||
@@ -346,9 +346,9 @@ class TestIsCurrentVersionYanked:
|
||||
assert is_yanked is False
|
||||
mock_fetch.assert_called_once()
|
||||
|
||||
@patch("crewai.cli.version.get_latest_version_from_pypi")
|
||||
@patch("crewai.cli.version.get_crewai_version")
|
||||
@patch("crewai.cli.version._get_cache_file")
|
||||
@patch("crewai.version.get_latest_version_from_pypi")
|
||||
@patch("crewai.version.get_crewai_version")
|
||||
@patch("crewai.version._get_cache_file")
|
||||
def test_returns_false_on_fetch_failure(
|
||||
self,
|
||||
mock_cache_file: MagicMock,
|
||||
|
||||
@@ -11,7 +11,7 @@ from crewai.llms.providers.openai.completion import OpenAICompletion, ResponsesA
|
||||
from crewai.crew import Crew
|
||||
from crewai.agent import Agent
|
||||
from crewai.task import Task
|
||||
from crewai.cli.constants import DEFAULT_LLM_MODEL
|
||||
from crewai.constants import DEFAULT_LLM_MODEL
|
||||
|
||||
def test_openai_completion_is_used_when_openai_provider():
|
||||
"""
|
||||
|
||||
@@ -102,7 +102,7 @@ class TestBuildMCPConfigFromDict:
|
||||
|
||||
|
||||
class TestFetchAmpMCPConfigs:
|
||||
@patch("crewai.cli.plus_api.PlusAPI")
|
||||
@patch("crewai.plus_api.PlusAPI")
|
||||
@patch("crewai_tools.tools.crewai_platform_tools.misc.get_platform_integration_token", return_value="test-api-key")
|
||||
def test_fetches_configs_successfully(self, mock_get_token, mock_plus_api_class, resolver):
|
||||
mock_response = MagicMock()
|
||||
@@ -133,7 +133,7 @@ class TestFetchAmpMCPConfigs:
|
||||
mock_plus_api_class.assert_called_once_with(api_key="test-api-key")
|
||||
mock_plus_api.get_mcp_configs.assert_called_once_with(["notion", "github"])
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI")
|
||||
@patch("crewai.plus_api.PlusAPI")
|
||||
@patch("crewai_tools.tools.crewai_platform_tools.misc.get_platform_integration_token", return_value="test-api-key")
|
||||
def test_omits_missing_slugs(self, mock_get_token, mock_plus_api_class, resolver):
|
||||
mock_response = MagicMock()
|
||||
@@ -150,7 +150,7 @@ class TestFetchAmpMCPConfigs:
|
||||
assert "notion" in result
|
||||
assert "missing-server" not in result
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI")
|
||||
@patch("crewai.plus_api.PlusAPI")
|
||||
@patch("crewai_tools.tools.crewai_platform_tools.misc.get_platform_integration_token", return_value="test-api-key")
|
||||
def test_returns_empty_on_http_error(self, mock_get_token, mock_plus_api_class, resolver):
|
||||
mock_response = MagicMock()
|
||||
@@ -163,7 +163,7 @@ class TestFetchAmpMCPConfigs:
|
||||
|
||||
assert result == {}
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI")
|
||||
@patch("crewai.plus_api.PlusAPI")
|
||||
@patch("crewai_tools.tools.crewai_platform_tools.misc.get_platform_integration_token", return_value="test-api-key")
|
||||
def test_returns_empty_on_network_error(self, mock_get_token, mock_plus_api_class, resolver):
|
||||
import httpx
|
||||
|
||||
@@ -35,7 +35,7 @@ class TestTraceListenerSetup:
|
||||
# Need to patch all the places where get_auth_token is imported/used
|
||||
with (
|
||||
patch(
|
||||
"crewai.cli.authentication.token.get_auth_token",
|
||||
"crewai.auth.token.get_auth_token",
|
||||
return_value="mock_token_12345",
|
||||
),
|
||||
patch(
|
||||
|
||||
@@ -2,7 +2,7 @@ import os
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
from crewai.cli.constants import DEFAULT_LLM_MODEL
|
||||
from crewai.constants import DEFAULT_LLM_MODEL
|
||||
from crewai.llm import LLM
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.utilities.llm_utils import create_llm
|
||||
|
||||
Reference in New Issue
Block a user