From 7afca5daab4120b50c73733116bf834e13a42294 Mon Sep 17 00:00:00 2001 From: Greyson Lalonde Date: Sun, 15 Mar 2026 19:39:55 -0400 Subject: [PATCH] refactor: remove cli/ from crewai package and relocate to proper modules MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Move framework infrastructure out of crewai/cli/ to dedicated modules: - cli/authentication/ → crewai/auth/ - cli/config.py → crewai/settings.py - cli/constants.py → crewai/constants.py - cli/plus_api.py → crewai/plus_api.py - cli/version.py → crewai/version.py - cli/crew_chat.py → crewai/utilities/crew_chat.py - cli/reset_memories_command.py → crewai/utilities/reset_memories.py - cli/utils.py (framework parts) → crewai/utilities/project_utils.py Delete CLI-only duplicates (command.py, git.py, provider.py) already present in crewai_cli. Replace _login_to_tool_repository with a _post_login() hook in AuthenticationCommand. Update all imports and mock.patch paths across both packages and tests. --- lib/cli/src/crewai_cli/crew_chat.py | 6 +- .../src/crewai_cli/reset_memories_command.py | 4 +- lib/cli/src/crewai_cli/utils.py | 2 +- lib/crewai/src/crewai/auth/__init__.py | 7 + lib/crewai/src/crewai/auth/constants.py | 3 + .../authentication/main.py => auth/oauth2.py} | 65 ++--- .../src/crewai/auth/providers/__init__.py | 1 + .../providers/auth0.py | 6 +- .../providers/base_provider.py | 9 +- .../providers/entra_id.py | 6 +- .../providers/keycloak.py | 6 +- .../authentication => auth}/providers/okta.py | 6 +- .../providers/workos.py | 6 +- .../{cli/authentication => auth}/token.py | 6 +- .../{cli/shared => auth}/token_manager.py | 2 + .../{cli/authentication => auth}/utils.py | 26 +- lib/crewai/src/crewai/cli/__init__.py | 0 .../src/crewai/cli/authentication/__init__.py | 4 - .../crewai/cli/authentication/constants.py | 1 - .../cli/authentication/providers/__init__.py | 0 lib/crewai/src/crewai/cli/command.py | 76 ------ lib/crewai/src/crewai/cli/git.py | 89 ------- lib/crewai/src/crewai/cli/provider.py | 231 ------------------ lib/crewai/src/crewai/cli/shared/__init__.py | 0 lib/crewai/src/crewai/{cli => }/constants.py | 2 + .../listeners/tracing/trace_batch_manager.py | 10 +- .../listeners/tracing/trace_listener.py | 4 +- .../crewai/events/utils/console_formatter.py | 2 +- lib/crewai/src/crewai/mcp/tool_resolver.py | 3 +- lib/crewai/src/crewai/{cli => }/plus_api.py | 12 +- .../src/crewai/{cli/config.py => settings.py} | 53 ++-- .../src/crewai/utilities/agent_utils.py | 6 +- .../crewai/{cli => utilities}/crew_chat.py | 114 +++------ lib/crewai/src/crewai/utilities/llm_utils.py | 2 +- .../utils.py => utilities/project_utils.py} | 203 ++------------- .../reset_memories.py} | 7 +- lib/crewai/src/crewai/{cli => }/version.py | 4 +- lib/crewai/tests/agents/test_agent.py | 18 +- .../authentication/providers/test_auth0.py | 4 +- .../authentication/providers/test_entra_id.py | 4 +- .../authentication/providers/test_keycloak.py | 4 +- .../cli/authentication/providers/test_okta.py | 4 +- .../authentication/providers/test_workos.py | 4 +- .../cli/authentication/test_auth_main.py | 44 ++-- .../tests/cli/authentication/test_utils.py | 6 +- lib/crewai/tests/cli/test_cli.py | 20 +- lib/crewai/tests/cli/test_config.py | 6 +- lib/crewai/tests/cli/test_constants.py | 2 +- lib/crewai/tests/cli/test_git.py | 101 -------- lib/crewai/tests/cli/test_plus_api.py | 50 ++-- lib/crewai/tests/cli/test_token_manager.py | 40 +-- lib/crewai/tests/cli/test_utils.py | 2 +- lib/crewai/tests/cli/test_version.py | 42 ++-- lib/crewai/tests/llms/openai/test_openai.py | 2 +- lib/crewai/tests/mcp/test_amp_mcp.py | 8 +- lib/crewai/tests/tracing/test_tracing.py | 2 +- lib/crewai/tests/utilities/test_llm_utils.py | 2 +- 57 files changed, 324 insertions(+), 1025 deletions(-) create mode 100644 lib/crewai/src/crewai/auth/__init__.py create mode 100644 lib/crewai/src/crewai/auth/constants.py rename lib/crewai/src/crewai/{cli/authentication/main.py => auth/oauth2.py} (76%) create mode 100644 lib/crewai/src/crewai/auth/providers/__init__.py rename lib/crewai/src/crewai/{cli/authentication => auth}/providers/auth0.py (88%) rename lib/crewai/src/crewai/{cli/authentication => auth}/providers/base_provider.py (76%) rename lib/crewai/src/crewai/{cli/authentication => auth}/providers/entra_id.py (88%) rename lib/crewai/src/crewai/{cli/authentication => auth}/providers/keycloak.py (89%) rename lib/crewai/src/crewai/{cli/authentication => auth}/providers/okta.py (91%) rename lib/crewai/src/crewai/{cli/authentication => auth}/providers/workos.py (86%) rename lib/crewai/src/crewai/{cli/authentication => auth}/token.py (66%) rename lib/crewai/src/crewai/{cli/shared => auth}/token_manager.py (99%) rename lib/crewai/src/crewai/{cli/authentication => auth}/utils.py (77%) delete mode 100644 lib/crewai/src/crewai/cli/__init__.py delete mode 100644 lib/crewai/src/crewai/cli/authentication/__init__.py delete mode 100644 lib/crewai/src/crewai/cli/authentication/constants.py delete mode 100644 lib/crewai/src/crewai/cli/authentication/providers/__init__.py delete mode 100644 lib/crewai/src/crewai/cli/command.py delete mode 100644 lib/crewai/src/crewai/cli/git.py delete mode 100644 lib/crewai/src/crewai/cli/provider.py delete mode 100644 lib/crewai/src/crewai/cli/shared/__init__.py rename lib/crewai/src/crewai/{cli => }/constants.py (99%) rename lib/crewai/src/crewai/{cli => }/plus_api.py (96%) rename lib/crewai/src/crewai/{cli/config.py => settings.py} (78%) rename lib/crewai/src/crewai/{cli => utilities}/crew_chat.py (80%) rename lib/crewai/src/crewai/{cli/utils.py => utilities/project_utils.py} (67%) rename lib/crewai/src/crewai/{cli/reset_memories_command.py => utilities/reset_memories.py} (94%) rename lib/crewai/src/crewai/{cli => }/version.py (98%) delete mode 100644 lib/crewai/tests/cli/test_git.py diff --git a/lib/cli/src/crewai_cli/crew_chat.py b/lib/cli/src/crewai_cli/crew_chat.py index e2234e43f..c3f7cf06d 100644 --- a/lib/cli/src/crewai_cli/crew_chat.py +++ b/lib/cli/src/crewai_cli/crew_chat.py @@ -1,7 +1,7 @@ """Wrapper for the crew chat command. -Delegates to ``crewai.cli.crew_chat.run_chat`` when the full crewai package is -installed, otherwise prints a helpful error message. +Delegates to ``crewai.utilities.crew_chat.run_chat`` when the full crewai +package is installed, otherwise prints a helpful error message. """ from __future__ import annotations @@ -11,7 +11,7 @@ import click def run_chat() -> None: try: - from crewai.cli.crew_chat import run_chat as _run_chat + from crewai.utilities.crew_chat import run_chat as _run_chat except ImportError: click.secho( "The 'chat' command requires the full crewai package.\n" diff --git a/lib/cli/src/crewai_cli/reset_memories_command.py b/lib/cli/src/crewai_cli/reset_memories_command.py index bc871bf0e..9778bf628 100644 --- a/lib/cli/src/crewai_cli/reset_memories_command.py +++ b/lib/cli/src/crewai_cli/reset_memories_command.py @@ -1,6 +1,6 @@ """Wrapper for the reset-memories command. -Delegates to ``crewai.cli.reset_memories_command`` when the full crewai +Delegates to ``crewai.utilities.reset_memories`` when the full crewai package is installed, otherwise prints a helpful error message. """ @@ -17,7 +17,7 @@ def reset_memories_command( all: bool, ) -> None: try: - from crewai.cli.reset_memories_command import ( + from crewai.utilities.reset_memories import ( reset_memories_command as _reset, ) except ImportError: diff --git a/lib/cli/src/crewai_cli/utils.py b/lib/cli/src/crewai_cli/utils.py index 304ef1e8a..9d2a2f806 100644 --- a/lib/cli/src/crewai_cli/utils.py +++ b/lib/cli/src/crewai_cli/utils.py @@ -246,7 +246,7 @@ def is_valid_tool(obj: Any) -> bool: Falls back to crewai's ``is_valid_tool`` when available. """ try: - from crewai.cli.utils import is_valid_tool as _core_is_valid_tool + from crewai.utilities.project_utils import is_valid_tool as _core_is_valid_tool return _core_is_valid_tool(obj) except ImportError: diff --git a/lib/crewai/src/crewai/auth/__init__.py b/lib/crewai/src/crewai/auth/__init__.py new file mode 100644 index 000000000..f33f09a58 --- /dev/null +++ b/lib/crewai/src/crewai/auth/__init__.py @@ -0,0 +1,7 @@ +"""Authentication utilities for the CrewAI platform.""" + +from crewai.auth.oauth2 import AuthenticationCommand +from crewai.auth.token import AuthError, get_auth_token + + +__all__ = ["AuthError", "AuthenticationCommand", "get_auth_token"] diff --git a/lib/crewai/src/crewai/auth/constants.py b/lib/crewai/src/crewai/auth/constants.py new file mode 100644 index 000000000..3fa25aa4e --- /dev/null +++ b/lib/crewai/src/crewai/auth/constants.py @@ -0,0 +1,3 @@ +"""Authentication constants.""" + +ALGORITHMS = ["RS256"] diff --git a/lib/crewai/src/crewai/cli/authentication/main.py b/lib/crewai/src/crewai/auth/oauth2.py similarity index 76% rename from lib/crewai/src/crewai/cli/authentication/main.py rename to lib/crewai/src/crewai/auth/oauth2.py index 61c85bb44..170f804f8 100644 --- a/lib/crewai/src/crewai/cli/authentication/main.py +++ b/lib/crewai/src/crewai/auth/oauth2.py @@ -1,3 +1,5 @@ +"""OAuth2 authentication for the CrewAI platform.""" + import time from typing import TYPE_CHECKING, Any, TypeVar, cast import webbrowser @@ -6,9 +8,9 @@ import httpx from pydantic import BaseModel, Field from rich.console import Console -from crewai.cli.authentication.utils import validate_jwt_token -from crewai.cli.config import Settings -from crewai.cli.shared.token_manager import TokenManager +from crewai.auth.token_manager import TokenManager +from crewai.auth.utils import validate_jwt_token +from crewai.settings import Settings console = Console() @@ -17,6 +19,8 @@ TOauth2Settings = TypeVar("TOauth2Settings", bound="Oauth2Settings") class Oauth2Settings(BaseModel): + """OAuth2 provider configuration.""" + provider: str = Field( description="OAuth2 provider used for authentication (e.g., workos, okta, auth0)." ) @@ -38,7 +42,6 @@ class Oauth2Settings(BaseModel): @classmethod def from_settings(cls: type[TOauth2Settings]) -> TOauth2Settings: """Create an Oauth2Settings instance from the CLI settings.""" - settings = Settings() return cls( @@ -51,23 +54,25 @@ class Oauth2Settings(BaseModel): if TYPE_CHECKING: - from crewai.cli.authentication.providers.base_provider import BaseProvider + from crewai.auth.providers.base_provider import BaseProvider class ProviderFactory: + """Factory for creating OAuth2 providers from settings.""" + @classmethod def from_settings( cls: type["ProviderFactory"], # noqa: UP037 settings: Oauth2Settings | None = None, ) -> "BaseProvider": # noqa: UP037 + """Create a provider instance from settings.""" settings = settings or Oauth2Settings.from_settings() import importlib module = importlib.import_module( - f"crewai.cli.authentication.providers.{settings.provider.lower()}" + f"crewai.auth.providers.{settings.provider.lower()}" ) - # Converts from snake_case to CamelCase to obtain the provider class name. provider = getattr( module, f"{''.join(word.capitalize() for word in settings.provider.split('_'))}Provider", @@ -77,6 +82,8 @@ class ProviderFactory: class AuthenticationCommand: + """Handles authentication with the CrewAI platform.""" + def __init__(self) -> None: self.token_manager = TokenManager() self.oauth2_provider = ProviderFactory.from_settings() @@ -92,7 +99,6 @@ class AuthenticationCommand: def _get_device_code(self) -> dict[str, Any]: """Get the device code to authenticate the user.""" - device_code_payload = { "client_id": self.oauth2_provider.get_client_id(), "scope": " ".join(self.oauth2_provider.get_oauth_scopes()), @@ -108,7 +114,6 @@ class AuthenticationCommand: def _display_auth_instructions(self, device_code_data: dict[str, str]) -> None: """Display the authentication instructions to the user.""" - verification_uri = device_code_data.get( "verification_uri_complete", device_code_data.get("verification_uri", "") ) @@ -119,7 +124,6 @@ class AuthenticationCommand: def _poll_for_token(self, device_code_data: dict[str, Any]) -> None: """Polls the server for the token until it is received, or max attempts are reached.""" - token_payload = { "grant_type": "urn:ietf:params:oauth:grant-type:device_code", "device_code": device_code_data["device_code"], @@ -143,7 +147,7 @@ class AuthenticationCommand: style="bold green", ) - self._login_to_tool_repository() + self._post_login() console.print("\n[bold green]Welcome to CrewAI AMP![/bold green]\n") return @@ -162,7 +166,6 @@ class AuthenticationCommand: def _validate_and_save_token(self, token_data: dict[str, Any]) -> None: """Validates the JWT token and saves the token to the token manager.""" - jwt_token = token_data["access_token"] issuer = self.oauth2_provider.get_issuer() jwt_token_data = { @@ -177,39 +180,5 @@ class AuthenticationCommand: expires_at = decoded_token.get("exp", 0) self.token_manager.save_tokens(jwt_token, expires_at) - def _login_to_tool_repository(self) -> None: - """Login to the tool repository.""" - - from crewai_cli.tools.main import ToolCommand - - try: - console.print( - "Now logging you in to the Tool Repository... ", - style="bold blue", - end="", - ) - - ToolCommand().login() - - console.print( - "Success!\n", - style="bold green", - ) - - settings = Settings() - - console.print( - f"You are now authenticated to the tool repository for organization [bold cyan]'{settings.org_name if settings.org_name else settings.org_uuid}'[/bold cyan]", - style="green", - ) - except Exception: - console.print( - "\n[bold yellow]Warning:[/bold yellow] Authentication with the Tool Repository failed.", - style="yellow", - ) - console.print( - "Other features will work normally, but you may experience limitations " - "with downloading and publishing tools." - "\nRun [bold]crewai login[/bold] to try logging in again.\n", - style="yellow", - ) + def _post_login(self) -> None: + """Hook called after successful login. Override in subclasses for additional behavior.""" diff --git a/lib/crewai/src/crewai/auth/providers/__init__.py b/lib/crewai/src/crewai/auth/providers/__init__.py new file mode 100644 index 000000000..c495fe55b --- /dev/null +++ b/lib/crewai/src/crewai/auth/providers/__init__.py @@ -0,0 +1 @@ +"""OAuth2 authentication providers.""" diff --git a/lib/crewai/src/crewai/cli/authentication/providers/auth0.py b/lib/crewai/src/crewai/auth/providers/auth0.py similarity index 88% rename from lib/crewai/src/crewai/cli/authentication/providers/auth0.py rename to lib/crewai/src/crewai/auth/providers/auth0.py index b27e3d168..a86c33cc5 100644 --- a/lib/crewai/src/crewai/cli/authentication/providers/auth0.py +++ b/lib/crewai/src/crewai/auth/providers/auth0.py @@ -1,7 +1,11 @@ -from crewai.cli.authentication.providers.base_provider import BaseProvider +"""Auth0 OAuth2 provider.""" + +from crewai.auth.providers.base_provider import BaseProvider class Auth0Provider(BaseProvider): + """Auth0 OAuth2 provider implementation.""" + def get_authorize_url(self) -> str: return f"https://{self._get_domain()}/oauth/device/code" diff --git a/lib/crewai/src/crewai/cli/authentication/providers/base_provider.py b/lib/crewai/src/crewai/auth/providers/base_provider.py similarity index 76% rename from lib/crewai/src/crewai/cli/authentication/providers/base_provider.py rename to lib/crewai/src/crewai/auth/providers/base_provider.py index 9412ca283..d69d8d673 100644 --- a/lib/crewai/src/crewai/cli/authentication/providers/base_provider.py +++ b/lib/crewai/src/crewai/auth/providers/base_provider.py @@ -1,9 +1,13 @@ +"""Base OAuth2 provider interface.""" + from abc import ABC, abstractmethod -from crewai.cli.authentication.main import Oauth2Settings +from crewai.auth.oauth2 import Oauth2Settings class BaseProvider(ABC): + """Abstract base class for OAuth2 providers.""" + def __init__(self, settings: Oauth2Settings): self.settings = settings @@ -26,8 +30,9 @@ class BaseProvider(ABC): def get_client_id(self) -> str: ... def get_required_fields(self) -> list[str]: - """Returns which provider-specific fields inside the "extra" dict will be required""" + """Returns which provider-specific fields inside the "extra" dict will be required.""" return [] def get_oauth_scopes(self) -> list[str]: + """Returns the OAuth scopes to request.""" return ["openid", "profile", "email"] diff --git a/lib/crewai/src/crewai/cli/authentication/providers/entra_id.py b/lib/crewai/src/crewai/auth/providers/entra_id.py similarity index 88% rename from lib/crewai/src/crewai/cli/authentication/providers/entra_id.py rename to lib/crewai/src/crewai/auth/providers/entra_id.py index c08ea4ec7..1bd1dc9a8 100644 --- a/lib/crewai/src/crewai/cli/authentication/providers/entra_id.py +++ b/lib/crewai/src/crewai/auth/providers/entra_id.py @@ -1,9 +1,13 @@ +"""Entra ID (Azure AD) OAuth2 provider.""" + from typing import cast -from crewai.cli.authentication.providers.base_provider import BaseProvider +from crewai.auth.providers.base_provider import BaseProvider class EntraIdProvider(BaseProvider): + """Entra ID (Azure AD) OAuth2 provider implementation.""" + def get_authorize_url(self) -> str: return f"{self._base_url()}/oauth2/v2.0/devicecode" diff --git a/lib/crewai/src/crewai/cli/authentication/providers/keycloak.py b/lib/crewai/src/crewai/auth/providers/keycloak.py similarity index 89% rename from lib/crewai/src/crewai/cli/authentication/providers/keycloak.py rename to lib/crewai/src/crewai/auth/providers/keycloak.py index e7b076121..b2b82947e 100644 --- a/lib/crewai/src/crewai/cli/authentication/providers/keycloak.py +++ b/lib/crewai/src/crewai/auth/providers/keycloak.py @@ -1,7 +1,11 @@ -from crewai.cli.authentication.providers.base_provider import BaseProvider +"""Keycloak OAuth2 provider.""" + +from crewai.auth.providers.base_provider import BaseProvider class KeycloakProvider(BaseProvider): + """Keycloak OAuth2 provider implementation.""" + def get_authorize_url(self) -> str: return f"{self._oauth2_base_url()}/realms/{self.settings.extra.get('realm')}/protocol/openid-connect/auth/device" diff --git a/lib/crewai/src/crewai/cli/authentication/providers/okta.py b/lib/crewai/src/crewai/auth/providers/okta.py similarity index 91% rename from lib/crewai/src/crewai/cli/authentication/providers/okta.py rename to lib/crewai/src/crewai/auth/providers/okta.py index 90f5e2908..13f051360 100644 --- a/lib/crewai/src/crewai/cli/authentication/providers/okta.py +++ b/lib/crewai/src/crewai/auth/providers/okta.py @@ -1,7 +1,11 @@ -from crewai.cli.authentication.providers.base_provider import BaseProvider +"""Okta OAuth2 provider.""" + +from crewai.auth.providers.base_provider import BaseProvider class OktaProvider(BaseProvider): + """Okta OAuth2 provider implementation.""" + def get_authorize_url(self) -> str: return f"{self._oauth2_base_url()}/v1/device/authorize" diff --git a/lib/crewai/src/crewai/cli/authentication/providers/workos.py b/lib/crewai/src/crewai/auth/providers/workos.py similarity index 86% rename from lib/crewai/src/crewai/cli/authentication/providers/workos.py rename to lib/crewai/src/crewai/auth/providers/workos.py index 7cffdf890..dda5fe62f 100644 --- a/lib/crewai/src/crewai/cli/authentication/providers/workos.py +++ b/lib/crewai/src/crewai/auth/providers/workos.py @@ -1,7 +1,11 @@ -from crewai.cli.authentication.providers.base_provider import BaseProvider +"""WorkOS OAuth2 provider.""" + +from crewai.auth.providers.base_provider import BaseProvider class WorkosProvider(BaseProvider): + """WorkOS OAuth2 provider implementation.""" + def get_authorize_url(self) -> str: return f"https://{self._get_domain()}/oauth2/device_authorization" diff --git a/lib/crewai/src/crewai/cli/authentication/token.py b/lib/crewai/src/crewai/auth/token.py similarity index 66% rename from lib/crewai/src/crewai/cli/authentication/token.py rename to lib/crewai/src/crewai/auth/token.py index 7a1d05c98..bc9501502 100644 --- a/lib/crewai/src/crewai/cli/authentication/token.py +++ b/lib/crewai/src/crewai/auth/token.py @@ -1,8 +1,10 @@ -from crewai.cli.shared.token_manager import TokenManager +"""Authentication token retrieval.""" + +from crewai.auth.token_manager import TokenManager class AuthError(Exception): - pass + """Raised when authentication fails.""" def get_auth_token() -> str: diff --git a/lib/crewai/src/crewai/cli/shared/token_manager.py b/lib/crewai/src/crewai/auth/token_manager.py similarity index 99% rename from lib/crewai/src/crewai/cli/shared/token_manager.py rename to lib/crewai/src/crewai/auth/token_manager.py index 02c176924..9f1807d33 100644 --- a/lib/crewai/src/crewai/cli/shared/token_manager.py +++ b/lib/crewai/src/crewai/auth/token_manager.py @@ -1,3 +1,5 @@ +"""Manages encrypted token storage.""" + from datetime import datetime import json import os diff --git a/lib/crewai/src/crewai/cli/authentication/utils.py b/lib/crewai/src/crewai/auth/utils.py similarity index 77% rename from lib/crewai/src/crewai/cli/authentication/utils.py rename to lib/crewai/src/crewai/auth/utils.py index 7311b9d42..c8e406793 100644 --- a/lib/crewai/src/crewai/cli/authentication/utils.py +++ b/lib/crewai/src/crewai/auth/utils.py @@ -1,3 +1,5 @@ +"""JWT token validation utilities.""" + from typing import Any import jwt @@ -7,18 +9,20 @@ from jwt import PyJWKClient def validate_jwt_token( jwt_token: str, jwks_url: str, issuer: str, audience: str ) -> Any: - """ - Verify the token's signature and claims using PyJWT. - :param jwt_token: The JWT (JWS) string to validate. - :param jwks_url: The URL of the JWKS endpoint. - :param issuer: The expected issuer of the token. - :param audience: The expected audience of the token. - :return: The decoded token. - :raises Exception: If the token is invalid for any reason (e.g., signature mismatch, - expired, incorrect issuer/audience, JWKS fetching error, - missing required claims). - """ + """Verify the token's signature and claims using PyJWT. + Args: + jwt_token: The JWT (JWS) string to validate. + jwks_url: The URL of the JWKS endpoint. + issuer: The expected issuer of the token. + audience: The expected audience of the token. + + Returns: + The decoded token. + + Raises: + Exception: If the token is invalid for any reason. + """ try: jwk_client = PyJWKClient(jwks_url) signing_key = jwk_client.get_signing_key_from_jwt(jwt_token) diff --git a/lib/crewai/src/crewai/cli/__init__.py b/lib/crewai/src/crewai/cli/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/lib/crewai/src/crewai/cli/authentication/__init__.py b/lib/crewai/src/crewai/cli/authentication/__init__.py deleted file mode 100644 index 98070be42..000000000 --- a/lib/crewai/src/crewai/cli/authentication/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from crewai.cli.authentication.main import AuthenticationCommand - - -__all__ = ["AuthenticationCommand"] diff --git a/lib/crewai/src/crewai/cli/authentication/constants.py b/lib/crewai/src/crewai/cli/authentication/constants.py deleted file mode 100644 index a9457b36a..000000000 --- a/lib/crewai/src/crewai/cli/authentication/constants.py +++ /dev/null @@ -1 +0,0 @@ -ALGORITHMS = ["RS256"] diff --git a/lib/crewai/src/crewai/cli/authentication/providers/__init__.py b/lib/crewai/src/crewai/cli/authentication/providers/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/lib/crewai/src/crewai/cli/command.py b/lib/crewai/src/crewai/cli/command.py deleted file mode 100644 index 139f69373..000000000 --- a/lib/crewai/src/crewai/cli/command.py +++ /dev/null @@ -1,76 +0,0 @@ -import json - -import httpx -from rich.console import Console - -from crewai.cli.authentication.token import get_auth_token -from crewai.cli.plus_api import PlusAPI -from crewai.telemetry.telemetry import Telemetry - - -console = Console() - - -class BaseCommand: - def __init__(self) -> None: - self._telemetry = Telemetry() - self._telemetry.set_tracer() - - -class PlusAPIMixin: - def __init__(self, telemetry: Telemetry) -> None: - try: - telemetry.set_tracer() - self.plus_api_client = PlusAPI(api_key=get_auth_token()) - except Exception: - telemetry.deploy_signup_error_span() - console.print( - "Please sign up/login to CrewAI+ before using the CLI.", - style="bold red", - ) - console.print("Run 'crewai login' to sign up/login.", style="bold green") - raise SystemExit from None - - def _validate_response(self, response: httpx.Response) -> None: - """ - Handle and display error messages from API responses. - - Args: - response (httpx.Response): The response from the Plus API - """ - try: - json_response = response.json() - except (json.JSONDecodeError, ValueError): - console.print( - "Failed to parse response from Enterprise API failed. Details:", - style="bold red", - ) - console.print(f"Status Code: {response.status_code}") - console.print( - f"Response:\n{response.content.decode('utf-8', errors='replace')}" - ) - raise SystemExit from None - - if response.status_code == 422: - console.print( - "Failed to complete operation. Please fix the following errors:", - style="bold red", - ) - for field, messages in json_response.items(): - for message in messages: - console.print( - f"* [bold red]{field.capitalize()}[/bold red] {message}" - ) - raise SystemExit - - if not response.is_success: - console.print( - "Request to Enterprise API failed. Details:", style="bold red" - ) - details = ( - json_response.get("error") - or json_response.get("message") - or response.content.decode("utf-8", errors="replace") - ) - console.print(f"{details}") - raise SystemExit diff --git a/lib/crewai/src/crewai/cli/git.py b/lib/crewai/src/crewai/cli/git.py deleted file mode 100644 index fb08c391a..000000000 --- a/lib/crewai/src/crewai/cli/git.py +++ /dev/null @@ -1,89 +0,0 @@ -from functools import lru_cache -import subprocess - - -class Repository: - def __init__(self, path: str = ".") -> None: - self.path = path - - if not self.is_git_installed(): - raise ValueError("Git is not installed or not found in your PATH.") - - if not self.is_git_repo(): - raise ValueError(f"{self.path} is not a Git repository.") - - self.fetch() - - @staticmethod - def is_git_installed() -> bool: - """Check if Git is installed and available in the system.""" - try: - subprocess.run( - ["git", "--version"], # noqa: S607 - capture_output=True, - check=True, - text=True, - ) - return True - except (subprocess.CalledProcessError, FileNotFoundError): - return False - - def fetch(self) -> None: - """Fetch latest updates from the remote.""" - subprocess.run(["git", "fetch"], cwd=self.path, check=True) # noqa: S607 - - def status(self) -> str: - """Get the git status in porcelain format.""" - return subprocess.check_output( - ["git", "status", "--branch", "--porcelain"], # noqa: S607 - cwd=self.path, - encoding="utf-8", - ).strip() - - @lru_cache(maxsize=None) # noqa: B019 - def is_git_repo(self) -> bool: - """Check if the current directory is a git repository. - - Notes: - - TODO: This method is cached to avoid redundant checks, but using lru_cache on methods can lead to memory leaks - """ - try: - subprocess.check_output( - ["git", "rev-parse", "--is-inside-work-tree"], # noqa: S607 - cwd=self.path, - encoding="utf-8", - ) - return True - except subprocess.CalledProcessError: - return False - - def has_uncommitted_changes(self) -> bool: - """Check if the repository has uncommitted changes.""" - return len(self.status().splitlines()) > 1 - - def is_ahead_or_behind(self) -> bool: - """Check if the repository is ahead or behind the remote.""" - for line in self.status().splitlines(): - if line.startswith("##") and ("ahead" in line or "behind" in line): - return True - return False - - def is_synced(self) -> bool: - """Return True if the Git repository is fully synced with the remote, False otherwise.""" - if self.has_uncommitted_changes() or self.is_ahead_or_behind(): - return False - return True - - def origin_url(self) -> str | None: - """Get the Git repository's remote URL.""" - try: - result = subprocess.run( - ["git", "remote", "get-url", "origin"], # noqa: S607 - cwd=self.path, - capture_output=True, - text=True, - check=True, - ) - return result.stdout.strip() - except subprocess.CalledProcessError: - return None diff --git a/lib/crewai/src/crewai/cli/provider.py b/lib/crewai/src/crewai/cli/provider.py deleted file mode 100644 index 1f1e4ec40..000000000 --- a/lib/crewai/src/crewai/cli/provider.py +++ /dev/null @@ -1,231 +0,0 @@ -from collections import defaultdict -from collections.abc import Sequence -import json -import os -from pathlib import Path -import time -from typing import Any - -import certifi -import click -import httpx - -from crewai.cli.constants import JSON_URL, MODELS, PROVIDERS - - -def select_choice(prompt_message: str, choices: Sequence[str]) -> str | None: - """Presents a list of choices to the user and prompts them to select one. - - Args: - prompt_message: The message to display to the user before presenting the choices. - choices: A list of options to present to the user. - - Returns: - The selected choice from the list, or None if the user chooses to quit. - """ - - provider_models = get_provider_data() - if not provider_models: - return None - click.secho(prompt_message, fg="cyan") - for idx, choice in enumerate(choices, start=1): - click.secho(f"{idx}. {choice}", fg="cyan") - click.secho("q. Quit", fg="cyan") - - while True: - choice = click.prompt( - "Enter the number of your choice or 'q' to quit", type=str - ) - - if choice.lower() == "q": - return None - - try: - selected_index = int(choice) - 1 - if 0 <= selected_index < len(choices): - return choices[selected_index] - except ValueError: - pass - - click.secho( - "Invalid selection. Please select a number between 1 and 6 or 'q' to quit.", - fg="red", - ) - - -def select_provider(provider_models: dict[str, list[str]]) -> str | None | bool: - """Presents a list of providers to the user and prompts them to select one. - - Args: - provider_models: A dictionary of provider models. - - Returns: - The selected provider, None if user explicitly quits, or False if no selection. - """ - predefined_providers = [p.lower() for p in PROVIDERS] - all_providers = sorted(set(predefined_providers + list(provider_models.keys()))) - - provider = select_choice( - "Select a provider to set up:", [*predefined_providers, "other"] - ) - if provider is None: # User typed 'q' - return None - - if provider == "other": - provider = select_choice("Select a provider from the full list:", all_providers) - if provider is None: # User typed 'q' - return None - - return provider.lower() if provider else False - - -def select_model(provider: str, provider_models: dict[str, list[str]]) -> str | None: - """Presents a list of models for a given provider to the user and prompts them to select one. - - Args: - provider: The provider for which to select a model. - provider_models: A dictionary of provider models. - - Returns: - The selected model, or None if the operation is aborted or an invalid selection is made. - """ - predefined_providers = [p.lower() for p in PROVIDERS] - - if provider in predefined_providers: - available_models = MODELS.get(provider, []) - else: - available_models = provider_models.get(provider, []) - - if not available_models: - click.secho(f"No models available for provider '{provider}'.", fg="red") - return None - - return select_choice( - f"Select a model to use for {provider.capitalize()}:", available_models - ) - - -def load_provider_data(cache_file: Path, cache_expiry: int) -> dict[str, Any] | None: - """Loads provider data from a cache file if it exists and is not expired. - - If the cache is expired or corrupted, it fetches the data from the web. - - Args: - cache_file: The path to the cache file. - cache_expiry: The cache expiry time in seconds. - - Returns: - The loaded provider data or None if the operation fails. - """ - current_time = time.time() - if ( - cache_file.exists() - and (current_time - cache_file.stat().st_mtime) < cache_expiry - ): - data = read_cache_file(cache_file) - if data: - return data - click.secho( - "Cache is corrupted. Fetching provider data from the web...", fg="yellow" - ) - else: - click.secho( - "Cache expired or not found. Fetching provider data from the web...", - fg="cyan", - ) - return fetch_provider_data(cache_file) - - -def read_cache_file(cache_file: Path) -> dict[str, Any] | None: - """Reads and returns the JSON content from a cache file. - - Args: - cache_file: The path to the cache file. - - Returns: - The JSON content of the cache file or None if the JSON is invalid. - """ - try: - with open(cache_file, "r") as f: - data: dict[str, Any] = json.load(f) - return data - except json.JSONDecodeError: - return None - - -def fetch_provider_data(cache_file: Path) -> dict[str, Any] | None: - """Fetches provider data from a specified URL and caches it to a file. - - Args: - cache_file: The path to the cache file. - - Returns: - The fetched provider data or None if the operation fails. - """ - ssl_config = os.environ["SSL_CERT_FILE"] = certifi.where() - - try: - with httpx.stream("GET", JSON_URL, timeout=60, verify=ssl_config) as response: - response.raise_for_status() - data = download_data(response) - with open(cache_file, "w") as f: - json.dump(data, f) - return data - except httpx.HTTPError as e: - click.secho(f"Error fetching provider data: {e}", fg="red") - except json.JSONDecodeError: - click.secho("Error parsing provider data. Invalid JSON format.", fg="red") - return None - - -def download_data(response: httpx.Response) -> dict[str, Any]: - """Downloads data from a given HTTP response and returns the JSON content. - - Args: - response: The HTTP response object. - - Returns: - The JSON content of the response. - """ - total_size = int(response.headers.get("content-length", 0)) - block_size = 8192 - data_chunks: list[bytes] = [] - bar: Any - with click.progressbar( - length=total_size, label="Downloading", show_pos=True - ) as bar: - for chunk in response.iter_bytes(block_size): - if chunk: - data_chunks.append(chunk) - bar.update(len(chunk)) - data_content = b"".join(data_chunks) - result: dict[str, Any] = json.loads(data_content.decode("utf-8")) - return result - - -def get_provider_data() -> dict[str, list[str]] | None: - """Retrieves provider data from a cache file. - - Filters out models based on provider criteria, and returns a dictionary of providers - mapped to their models. - - Returns: - A dictionary of providers mapped to their models or None if the operation fails. - """ - cache_dir = Path.home() / ".crewai" - cache_dir.mkdir(exist_ok=True) - cache_file = cache_dir / "provider_cache.json" - cache_expiry = 24 * 3600 - - data = load_provider_data(cache_file, cache_expiry) - if not data: - return None - - provider_models = defaultdict(list) - for model_name, properties in data.items(): - provider = properties.get("litellm_provider", "").strip().lower() - if "http" in provider or provider == "other": - continue - if provider: - provider_models[provider].append(model_name) - return provider_models diff --git a/lib/crewai/src/crewai/cli/shared/__init__.py b/lib/crewai/src/crewai/cli/shared/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/lib/crewai/src/crewai/cli/constants.py b/lib/crewai/src/crewai/constants.py similarity index 99% rename from lib/crewai/src/crewai/cli/constants.py rename to lib/crewai/src/crewai/constants.py index 2ef8dcc7f..33b012666 100644 --- a/lib/crewai/src/crewai/cli/constants.py +++ b/lib/crewai/src/crewai/constants.py @@ -1,3 +1,5 @@ +"""CrewAI constants.""" + from typing import Any diff --git a/lib/crewai/src/crewai/events/listeners/tracing/trace_batch_manager.py b/lib/crewai/src/crewai/events/listeners/tracing/trace_batch_manager.py index da25792fb..231ade0a8 100644 --- a/lib/crewai/src/crewai/events/listeners/tracing/trace_batch_manager.py +++ b/lib/crewai/src/crewai/events/listeners/tracing/trace_batch_manager.py @@ -8,17 +8,17 @@ import uuid from rich.console import Console from rich.panel import Panel -from crewai.cli.authentication.token import AuthError, get_auth_token -from crewai.cli.config import Settings -from crewai.cli.constants import DEFAULT_CREWAI_ENTERPRISE_URL -from crewai.cli.plus_api import PlusAPI -from crewai.cli.version import get_crewai_version +from crewai.auth.token import AuthError, get_auth_token +from crewai.constants import DEFAULT_CREWAI_ENTERPRISE_URL from crewai.events.listeners.tracing.types import TraceEvent from crewai.events.listeners.tracing.utils import ( get_user_id, is_tracing_enabled_in_context, should_auto_collect_first_time_traces, ) +from crewai.plus_api import PlusAPI +from crewai.settings import Settings +from crewai.version import get_crewai_version logger = getLogger(__name__) diff --git a/lib/crewai/src/crewai/events/listeners/tracing/trace_listener.py b/lib/crewai/src/crewai/events/listeners/tracing/trace_listener.py index 0e4d7d8a2..b2ca172d3 100644 --- a/lib/crewai/src/crewai/events/listeners/tracing/trace_listener.py +++ b/lib/crewai/src/crewai/events/listeners/tracing/trace_listener.py @@ -6,8 +6,7 @@ import uuid from typing_extensions import Self -from crewai.cli.authentication.token import AuthError, get_auth_token -from crewai.cli.version import get_crewai_version +from crewai.auth.token import AuthError, get_auth_token from crewai.events.base_event_listener import BaseEventListener from crewai.events.base_events import BaseEvent from crewai.events.event_bus import CrewAIEventsBus @@ -110,6 +109,7 @@ from crewai.events.types.tool_usage_events import ( ToolUsageStartedEvent, ) from crewai.events.utils.console_formatter import ConsoleFormatter +from crewai.version import get_crewai_version class TraceCollectionListener(BaseEventListener): diff --git a/lib/crewai/src/crewai/events/utils/console_formatter.py b/lib/crewai/src/crewai/events/utils/console_formatter.py index a3019ffcf..b86455b67 100644 --- a/lib/crewai/src/crewai/events/utils/console_formatter.py +++ b/lib/crewai/src/crewai/events/utils/console_formatter.py @@ -8,7 +8,7 @@ from rich.live import Live from rich.panel import Panel from rich.text import Text -from crewai.cli.version import is_current_version_yanked, is_newer_version_available +from crewai.version import is_current_version_yanked, is_newer_version_available _disable_version_check: ContextVar[bool] = ContextVar( diff --git a/lib/crewai/src/crewai/mcp/tool_resolver.py b/lib/crewai/src/crewai/mcp/tool_resolver.py index 2ef7364ac..cc1c8f387 100644 --- a/lib/crewai/src/crewai/mcp/tool_resolver.py +++ b/lib/crewai/src/crewai/mcp/tool_resolver.py @@ -195,7 +195,7 @@ class MCPToolResolver: get_platform_integration_token, ) - from crewai.cli.plus_api import PlusAPI + from crewai.plus_api import PlusAPI plus_api = PlusAPI(api_key=get_platform_integration_token()) response = plus_api.get_mcp_configs(slugs) @@ -285,6 +285,7 @@ class MCPToolResolver: independent transport so that parallel tool executions never share state. """ + transport: StdioTransport | HTTPTransport | SSETransport if isinstance(mcp_config, MCPServerStdio): transport = StdioTransport( command=mcp_config.command, diff --git a/lib/crewai/src/crewai/cli/plus_api.py b/lib/crewai/src/crewai/plus_api.py similarity index 96% rename from lib/crewai/src/crewai/cli/plus_api.py rename to lib/crewai/src/crewai/plus_api.py index e32e5220d..64ecfe4e5 100644 --- a/lib/crewai/src/crewai/cli/plus_api.py +++ b/lib/crewai/src/crewai/plus_api.py @@ -1,18 +1,18 @@ +"""CrewAI+ API client.""" + import os from typing import Any from urllib.parse import urljoin import httpx -from crewai.cli.config import Settings -from crewai.cli.constants import DEFAULT_CREWAI_ENTERPRISE_URL -from crewai.cli.version import get_crewai_version +from crewai.constants import DEFAULT_CREWAI_ENTERPRISE_URL +from crewai.settings import Settings +from crewai.version import get_crewai_version class PlusAPI: - """ - This class exposes methods for working with the CrewAI+ API. - """ + """Client for working with the CrewAI+ API.""" TOOLS_RESOURCE = "/crewai_plus/api/v1/tools" ORGANIZATIONS_RESOURCE = "/crewai_plus/api/v1/me/organizations" diff --git a/lib/crewai/src/crewai/cli/config.py b/lib/crewai/src/crewai/settings.py similarity index 78% rename from lib/crewai/src/crewai/cli/config.py rename to lib/crewai/src/crewai/settings.py index d156d8488..66ee7080d 100644 --- a/lib/crewai/src/crewai/cli/config.py +++ b/lib/crewai/src/crewai/settings.py @@ -1,3 +1,5 @@ +"""CrewAI platform settings management.""" + import json from logging import getLogger from pathlib import Path @@ -6,14 +8,14 @@ from typing import Any from pydantic import BaseModel, Field -from crewai.cli.constants import ( +from crewai.auth.token_manager import TokenManager +from crewai.constants import ( CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE, CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID, CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN, CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER, DEFAULT_CREWAI_ENTERPRISE_URL, ) -from crewai.cli.shared.token_manager import TokenManager logger = getLogger(__name__) @@ -22,8 +24,7 @@ DEFAULT_CONFIG_PATH = Path.home() / ".config" / "crewai" / "settings.json" def get_writable_config_path() -> Path | None: - """ - Find a writable location for the config file with fallback options. + """Find a writable location for the config file with fallback options. Tries in order: 1. Default: ~/.config/crewai/settings.json @@ -32,12 +33,12 @@ def get_writable_config_path() -> Path | None: 4. In-memory only (returns None) Returns: - Path object for writable config location, or None if no writable location found + Path object for writable config location, or None if no writable location found. """ fallback_paths = [ - DEFAULT_CONFIG_PATH, # Default location - Path(tempfile.gettempdir()) / "crewai_settings.json", # Temporary directory - Path.cwd() / "crewai_settings.json", # Current working directory + DEFAULT_CONFIG_PATH, + Path(tempfile.gettempdir()) / "crewai_settings.json", + Path.cwd() / "crewai_settings.json", ] for config_path in fallback_paths: @@ -46,7 +47,7 @@ def get_writable_config_path() -> Path | None: test_file = config_path.parent / ".crewai_write_test" try: test_file.write_text("test") - test_file.unlink() # Clean up test file + test_file.unlink() logger.info(f"Using config path: {config_path}") return config_path except Exception: # noqa: S112 @@ -58,7 +59,6 @@ def get_writable_config_path() -> Path | None: return None -# Settings that are related to the user's account USER_SETTINGS_KEYS = [ "tool_repository_username", "tool_repository_password", @@ -66,7 +66,6 @@ USER_SETTINGS_KEYS = [ "org_uuid", ] -# Settings that are related to the CLI CLI_SETTINGS_KEYS = [ "enterprise_base_url", "oauth2_provider", @@ -76,7 +75,6 @@ CLI_SETTINGS_KEYS = [ "oauth2_extra", ] -# Default values for CLI settings DEFAULT_CLI_SETTINGS = { "enterprise_base_url": DEFAULT_CREWAI_ENTERPRISE_URL, "oauth2_provider": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER, @@ -86,13 +84,11 @@ DEFAULT_CLI_SETTINGS = { "oauth2_extra": {}, } -# Readonly settings - cannot be set by the user READONLY_SETTINGS_KEYS = [ "org_name", "org_uuid", ] -# Hidden settings - not displayed by the 'list' command and cannot be set by the user HIDDEN_SETTINGS_KEYS = [ "config_path", "tool_repository_username", @@ -101,8 +97,10 @@ HIDDEN_SETTINGS_KEYS = [ class Settings(BaseModel): + """CrewAI platform settings.""" + enterprise_base_url: str | None = Field( - default=DEFAULT_CLI_SETTINGS["enterprise_base_url"], + default=DEFAULT_CREWAI_ENTERPRISE_URL, description="Base URL of the CrewAI AMP instance", ) tool_repository_username: str | None = Field( @@ -121,22 +119,22 @@ class Settings(BaseModel): oauth2_provider: str = Field( description="OAuth2 provider used for authentication (e.g., workos, okta, auth0).", - default=DEFAULT_CLI_SETTINGS["oauth2_provider"], + default=CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER, ) oauth2_audience: str | None = Field( description="OAuth2 audience value, typically used to identify the target API or resource.", - default=DEFAULT_CLI_SETTINGS["oauth2_audience"], + default=CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE, ) oauth2_client_id: str = Field( - default=DEFAULT_CLI_SETTINGS["oauth2_client_id"], + default=CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID, description="OAuth2 client ID issued by the provider, used during authentication requests.", ) oauth2_domain: str = Field( description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens.", - default=DEFAULT_CLI_SETTINGS["oauth2_domain"], + default=CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN, ) oauth2_extra: dict[str, Any] = Field( @@ -145,14 +143,12 @@ class Settings(BaseModel): ) def __init__(self, config_path: Path | None = None, **data: dict[str, Any]) -> None: - """Load Settings from config path with fallback support""" + """Load Settings from config path with fallback support.""" if config_path is None: config_path = get_writable_config_path() - # If config_path is None, we're in memory-only mode if config_path is None: merged_data = {**data} - # Dummy path for memory-only mode super().__init__(config_path=Path("/dev/null"), **merged_data) return @@ -160,7 +156,6 @@ class Settings(BaseModel): config_path.parent.mkdir(parents=True, exist_ok=True) except Exception: merged_data = {**data} - # Dummy path for memory-only mode super().__init__(config_path=Path("/dev/null"), **merged_data) return @@ -176,19 +171,19 @@ class Settings(BaseModel): super().__init__(config_path=config_path, **merged_data) def clear_user_settings(self) -> None: - """Clear all user settings""" + """Clear all user settings.""" self._reset_user_settings() self.dump() def reset(self) -> None: - """Reset all settings to default values""" + """Reset all settings to default values.""" self._reset_user_settings() self._reset_cli_settings() self._clear_auth_tokens() self.dump() def dump(self) -> None: - """Save current settings to settings.json""" + """Save current settings to settings.json.""" if str(self.config_path) == "/dev/null": return @@ -207,15 +202,15 @@ class Settings(BaseModel): pass def _reset_user_settings(self) -> None: - """Reset all user settings to default values""" + """Reset all user settings to default values.""" for key in USER_SETTINGS_KEYS: setattr(self, key, None) def _reset_cli_settings(self) -> None: - """Reset all CLI settings to default values""" + """Reset all CLI settings to default values.""" for key in CLI_SETTINGS_KEYS: setattr(self, key, DEFAULT_CLI_SETTINGS.get(key)) def _clear_auth_tokens(self) -> None: - """Clear all authentication tokens""" + """Clear all authentication tokens.""" TokenManager().clear_tokens() diff --git a/lib/crewai/src/crewai/utilities/agent_utils.py b/lib/crewai/src/crewai/utilities/agent_utils.py index e0aee388b..700cb1bd1 100644 --- a/lib/crewai/src/crewai/utilities/agent_utils.py +++ b/lib/crewai/src/crewai/utilities/agent_utils.py @@ -19,8 +19,8 @@ from crewai.agents.parser import ( OutputParserError, parse, ) -from crewai.cli.config import Settings from crewai.llms.base_llm import BaseLLM +from crewai.settings import Settings from crewai.tools import BaseTool as CrewAITool from crewai.tools.base_tool import BaseTool from crewai.tools.structured_tool import CrewStructuredTool @@ -1047,8 +1047,8 @@ def load_agent_from_repository(from_repository: str) -> dict[str, Any]: if callable(_create_plus_client_hook): client = _create_plus_client_hook() else: - from crewai.cli.authentication.token import get_auth_token - from crewai.cli.plus_api import PlusAPI + from crewai.auth.token import get_auth_token + from crewai.plus_api import PlusAPI client = PlusAPI(api_key=get_auth_token()) _print_current_organization() diff --git a/lib/crewai/src/crewai/cli/crew_chat.py b/lib/crewai/src/crewai/utilities/crew_chat.py similarity index 80% rename from lib/crewai/src/crewai/cli/crew_chat.py rename to lib/crewai/src/crewai/utilities/crew_chat.py index bbbd51c0c..7d3b35397 100644 --- a/lib/crewai/src/crewai/cli/crew_chat.py +++ b/lib/crewai/src/crewai/utilities/crew_chat.py @@ -1,3 +1,5 @@ +"""Interactive chat interface for CrewAI crews.""" + import contextvars import json from pathlib import Path @@ -12,15 +14,15 @@ import click from packaging import version import tomli -from crewai.cli.utils import read_toml -from crewai.cli.version import get_crewai_version from crewai.crew import Crew from crewai.llm import LLM from crewai.llms.base_llm import BaseLLM from crewai.types.crew_chat import ChatInputField, ChatInputs from crewai.utilities.llm_utils import create_llm from crewai.utilities.printer import Printer +from crewai.utilities.project_utils import read_toml from crewai.utilities.types import LLMMessage +from crewai.version import get_crewai_version _printer = Printer() @@ -31,15 +33,14 @@ MIN_REQUIRED_VERSION: Final[Literal["0.98.0"]] = "0.98.0" def check_conversational_crews_version( crewai_version: str, pyproject_data: dict[str, Any] ) -> bool: - """ - Check if the installed crewAI version supports conversational crews. + """Check if the installed crewAI version supports conversational crews. Args: crewai_version: The current version of crewAI. pyproject_data: Dictionary containing pyproject.toml data. Returns: - bool: True if version check passes, False otherwise. + True if version check passes, False otherwise. """ try: if version.parse(crewai_version) < version.parse(MIN_REQUIRED_VERSION): @@ -56,8 +57,8 @@ def check_conversational_crews_version( def run_chat() -> None: - """ - Runs an interactive chat loop using the Crew's chat LLM with function calling. + """Run an interactive chat loop using the Crew's chat LLM with function calling. + Incorporates crew_name, crew_description, and input fields to build a tool schema. Exits if crew_name or crew_description are missing. """ @@ -72,14 +73,12 @@ def run_chat() -> None: if not chat_llm: return - # Indicate that the crew is being analyzed click.secho( "\nAnalyzing crew and required inputs - this may take 3 to 30 seconds " "depending on the complexity of your crew.", fg="white", ) - # Start loading indicator loading_complete = threading.Event() ctx = contextvars.copy_context() loading_thread = threading.Thread( @@ -92,16 +91,13 @@ def run_chat() -> None: crew_tool_schema = generate_crew_tool_schema(crew_chat_inputs) system_message = build_system_message(crew_chat_inputs) - # Call the LLM to generate the introductory message introductory_message = chat_llm.call( messages=[{"role": "system", "content": system_message}] ) finally: - # Stop loading indicator loading_complete.set() loading_thread.join() - # Indicate that the analysis is complete click.secho("\nFinished analyzing crew.\n", fg="white") click.secho(f"Assistant: {introductory_message}\n", fg="green") @@ -127,7 +123,7 @@ def show_loading(event: threading.Event) -> None: def initialize_chat_llm(crew: Crew) -> LLM | BaseLLM | None: - """Initializes the chat LLM and handles exceptions.""" + """Initialize the chat LLM and handle exceptions.""" try: return create_llm(crew.chat_llm) except Exception as e: @@ -139,7 +135,7 @@ def initialize_chat_llm(crew: Crew) -> LLM | BaseLLM | None: def build_system_message(crew_chat_inputs: ChatInputs) -> str: - """Builds the initial system message for the chat.""" + """Build the initial system message for the chat.""" required_fields_str = ( ", ".join( f"{field.name} (desc: {field.description or 'n/a'})" @@ -168,7 +164,7 @@ def build_system_message(crew_chat_inputs: ChatInputs) -> str: def create_tool_function(crew: Crew, messages: list[LLMMessage]) -> Any: - """Creates a wrapper function for running the crew tool with messages.""" + """Create a wrapper function for running the crew tool with messages.""" def run_crew_tool_with_messages(**kwargs: Any) -> str: return run_crew_tool(crew, messages, **kwargs) @@ -179,13 +175,11 @@ def create_tool_function(crew: Crew, messages: list[LLMMessage]) -> Any: def flush_input() -> None: """Flush any pending input from the user.""" if platform.system() == "Windows": - # Windows platform import msvcrt while msvcrt.kbhit(): # type: ignore[attr-defined] msvcrt.getch() # type: ignore[attr-defined] else: - # Unix-like platforms (Linux, macOS) import termios termios.tcflush(sys.stdin, termios.TCIFLUSH) @@ -200,7 +194,6 @@ def chat_loop( """Main chat loop for interacting with the user.""" while True: try: - # Flush any pending input before accepting new input flush_input() user_input = get_user_input() @@ -250,11 +243,9 @@ def handle_user_input( messages.append({"role": "user", "content": user_input}) - # Indicate that assistant is processing click.echo() click.secho("Assistant is processing your input. Please wait...", fg="green") - # Process assistant's response final_response = chat_llm.call( messages=messages, tools=[crew_tool_schema], @@ -266,12 +257,11 @@ def handle_user_input( def generate_crew_tool_schema(crew_inputs: ChatInputs) -> dict[str, Any]: - """ - Dynamically build a Littellm 'function' schema for the given crew. + """Dynamically build a Littellm 'function' schema for the given crew. - crew_name: The name of the crew (used for the function 'name'). - crew_inputs: A ChatInputs object containing crew_description - and a list of input fields (each with a name & description). + Args: + crew_inputs: A ChatInputs object containing crew_description + and a list of input fields (each with a name & description). """ properties = {} for field in crew_inputs.inputs: @@ -297,70 +287,51 @@ def generate_crew_tool_schema(crew_inputs: ChatInputs) -> dict[str, Any]: def run_crew_tool(crew: Crew, messages: list[LLMMessage], **kwargs: Any) -> str: - """ - Runs the crew using crew.kickoff(inputs=kwargs) and returns the output. + """Run the crew using crew.kickoff(inputs=kwargs) and return the output. Args: - crew (Crew): The crew instance to run. - messages (List[Dict[str, str]]): The chat messages up to this point. + crew: The crew instance to run. + messages: The chat messages up to this point. **kwargs: The inputs collected from the user. Returns: - str: The output from the crew's execution. - - Raises: - SystemExit: Exits the chat if an error occurs during crew execution. + The output from the crew's execution. """ try: - # Serialize 'messages' to JSON string before adding to kwargs kwargs["crew_chat_messages"] = json.dumps(messages) - - # Run the crew with the provided inputs crew_output = crew.kickoff(inputs=kwargs) - - # Convert CrewOutput to a string to send back to the user return str(crew_output) except Exception as e: - # Exit the chat and show the error message click.secho("An error occurred while running the crew:", fg="red") click.secho(str(e), fg="red") sys.exit(1) def load_crew_and_name() -> tuple[Crew, str]: - """ - Loads the crew by importing the crew class from the user's project. + """Load the crew by importing the crew class from the user's project. Returns: - Tuple[Crew, str]: A tuple containing the Crew instance and the name of the crew. + A tuple containing the Crew instance and the name of the crew. """ - # Get the current working directory cwd = Path.cwd() - # Path to the pyproject.toml file pyproject_path = cwd / "pyproject.toml" if not pyproject_path.exists(): raise FileNotFoundError("pyproject.toml not found in the current directory.") - # Load the pyproject.toml file using 'tomli' with pyproject_path.open("rb") as f: pyproject_data = tomli.load(f) - # Get the project name from the 'project' section project_name = pyproject_data["project"]["name"] folder_name = project_name - # Derive the crew class name from the project name - # E.g., if project_name is 'my_project', crew_class_name is 'MyProject' crew_class_name = project_name.replace("_", " ").title().replace(" ", "") - # Add the 'src' directory to sys.path src_path = cwd / "src" if str(src_path) not in sys.path: sys.path.insert(0, str(src_path)) - # Import the crew module crew_module_name = f"{folder_name}.crew" try: crew_module = __import__(crew_module_name, fromlist=[crew_class_name]) @@ -369,7 +340,6 @@ def load_crew_and_name() -> tuple[Crew, str]: f"Failed to import crew module {crew_module_name}: {e}" ) from e - # Get the crew class from the module try: crew_class = getattr(crew_module, crew_class_name) except AttributeError as e: @@ -377,7 +347,6 @@ def load_crew_and_name() -> tuple[Crew, str]: f"Crew class {crew_class_name} not found in module {crew_module_name}" ) from e - # Instantiate the crew crew_instance = crew_class().crew() return crew_instance, crew_class_name @@ -385,27 +354,23 @@ def load_crew_and_name() -> tuple[Crew, str]: def generate_crew_chat_inputs( crew: Crew, crew_name: str, chat_llm: LLM | BaseLLM ) -> ChatInputs: - """ - Generates the ChatInputs required for the crew by analyzing the tasks and agents. + """Generate the ChatInputs required for the crew by analyzing the tasks and agents. Args: - crew (Crew): The crew object containing tasks and agents. - crew_name (str): The name of the crew. + crew: The crew object containing tasks and agents. + crew_name: The name of the crew. chat_llm: The chat language model to use for AI calls. Returns: - ChatInputs: An object containing the crew's name, description, and input fields. + An object containing the crew's name, description, and input fields. """ - # Extract placeholders from tasks and agents required_inputs = fetch_required_inputs(crew) - # Generate descriptions for each input using AI input_fields = [] for input_name in required_inputs: description = generate_input_description_with_ai(input_name, crew, chat_llm) input_fields.append(ChatInputField(name=input_name, description=description)) - # Generate crew description using AI crew_description = generate_crew_description_with_ai(crew, chat_llm) return ChatInputs( @@ -414,13 +379,13 @@ def generate_crew_chat_inputs( def fetch_required_inputs(crew: Crew) -> set[str]: - """Extracts placeholders from the crew's tasks and agents. + """Extract placeholders from the crew's tasks and agents. Args: - crew (Crew): The crew object. + crew: The crew object. Returns: - Set[str]: A set of placeholder names. + A set of placeholder names. """ return crew.fetch_inputs() @@ -428,18 +393,16 @@ def fetch_required_inputs(crew: Crew) -> set[str]: def generate_input_description_with_ai( input_name: str, crew: Crew, chat_llm: LLM | BaseLLM ) -> str: - """ - Generates an input description using AI based on the context of the crew. + """Generate an input description using AI based on the context of the crew. Args: - input_name (str): The name of the input placeholder. - crew (Crew): The crew object. + input_name: The name of the input placeholder. + crew: The crew object. chat_llm: The chat language model to use for AI calls. Returns: - str: A concise description of the input. + A concise description of the input. """ - # Gather context from tasks and agents where the input is used context_texts = [] placeholder_pattern = re.compile(r"\{(.+?)}") @@ -448,7 +411,6 @@ def generate_input_description_with_ai( f"{{{input_name}}}" in task.description or f"{{{input_name}}}" in task.expected_output ): - # Replace placeholders with input names task_description = placeholder_pattern.sub( lambda m: m.group(1), task.description or "" ) @@ -463,7 +425,6 @@ def generate_input_description_with_ai( or f"{{{input_name}}}" in agent.goal or f"{{{input_name}}}" in agent.backstory ): - # Replace placeholders with input names agent_role = placeholder_pattern.sub(lambda m: m.group(1), agent.role or "") agent_goal = placeholder_pattern.sub(lambda m: m.group(1), agent.goal or "") agent_backstory = placeholder_pattern.sub( @@ -475,7 +436,6 @@ def generate_input_description_with_ai( context = "\n".join(context_texts) if not context: - # If no context is found for the input, raise an exception as per instruction raise ValueError(f"No context found for input '{input_name}'.") prompt = ( @@ -489,22 +449,19 @@ def generate_input_description_with_ai( def generate_crew_description_with_ai(crew: Crew, chat_llm: LLM | BaseLLM) -> str: - """ - Generates a brief description of the crew using AI. + """Generate a brief description of the crew using AI. Args: - crew (Crew): The crew object. + crew: The crew object. chat_llm: The chat language model to use for AI calls. Returns: - str: A concise description of the crew's purpose (15 words or less). + A concise description of the crew's purpose (15 words or less). """ - # Gather context from tasks and agents context_texts = [] placeholder_pattern = re.compile(r"\{(.+?)}") for task in crew.tasks: - # Replace placeholders with input names task_description = placeholder_pattern.sub( lambda m: m.group(1), task.description or "" ) @@ -514,7 +471,6 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm: LLM | BaseLLM) -> st context_texts.append(f"Task Description: {task_description}") context_texts.append(f"Expected Output: {expected_output}") for agent in crew.agents: - # Replace placeholders with input names agent_role = placeholder_pattern.sub(lambda m: m.group(1), agent.role or "") agent_goal = placeholder_pattern.sub(lambda m: m.group(1), agent.goal or "") agent_backstory = placeholder_pattern.sub( diff --git a/lib/crewai/src/crewai/utilities/llm_utils.py b/lib/crewai/src/crewai/utilities/llm_utils.py index 55a42968a..91c582b2f 100644 --- a/lib/crewai/src/crewai/utilities/llm_utils.py +++ b/lib/crewai/src/crewai/utilities/llm_utils.py @@ -2,7 +2,7 @@ import logging import os from typing import Any, Final -from crewai.cli.constants import DEFAULT_LLM_MODEL, ENV_VARS, LITELLM_PARAMS +from crewai.constants import DEFAULT_LLM_MODEL, ENV_VARS, LITELLM_PARAMS from crewai.llm import LLM from crewai.llms.base_llm import BaseLLM diff --git a/lib/crewai/src/crewai/cli/utils.py b/lib/crewai/src/crewai/utilities/project_utils.py similarity index 67% rename from lib/crewai/src/crewai/cli/utils.py rename to lib/crewai/src/crewai/utilities/project_utils.py index 714130632..72d3e2aef 100644 --- a/lib/crewai/src/crewai/cli/utils.py +++ b/lib/crewai/src/crewai/utilities/project_utils.py @@ -1,20 +1,19 @@ +"""Project utility functions for discovering crews, flows, and tools.""" + from functools import reduce import importlib.util from inspect import getmro, isclass, isfunction, ismethod import os from pathlib import Path -import shutil import sys from typing import Any, cast, get_type_hints -import click from rich.console import Console import tomli -from crewai.cli.config import Settings -from crewai.cli.constants import ENV_VARS from crewai.crew import Crew from crewai.flow import Flow +from crewai.settings import Settings if sys.version_info >= (3, 11): @@ -23,25 +22,6 @@ if sys.version_info >= (3, 11): console = Console() -def copy_template( - src: Path, dst: Path, name: str, class_name: str, folder_name: str -) -> None: - """Copy a file from src to dst.""" - with open(src, "r") as file: - content = file.read() - - # Interpolate the content - content = content.replace("{{name}}", name) - content = content.replace("{{crew_name}}", class_name) - content = content.replace("{{folder_name}}", folder_name) - - # Write the interpolated content to the new file - with open(dst, "w") as file: - file.write(content) - - click.secho(f" - Created {dst}", fg="green") - - def read_toml(file_path: str = "pyproject.toml") -> dict[str, Any]: """Read the content of a TOML file and return it as a dictionary.""" with open(file_path, "rb") as f: @@ -49,6 +29,7 @@ def read_toml(file_path: str = "pyproject.toml") -> dict[str, Any]: def parse_toml(content: str) -> dict[str, Any]: + """Parse a TOML string and return it as a dictionary.""" if sys.version_info >= (3, 11): return tomllib.loads(content) return tomli.loads(content) @@ -104,7 +85,6 @@ def _get_project_attribute( style="bold red", ) except Exception as e: - # Handle TOML decode errors for Python 3.11+ if sys.version_info >= (3, 11) and isinstance(e, tomllib.TOMLDecodeError): console.print( f"Error: {pyproject_path} is not a valid TOML file.", style="bold red" @@ -128,164 +108,18 @@ def _get_nested_value(data: dict[str, Any], keys: list[str]) -> Any: return reduce(dict.__getitem__, keys, data) -def fetch_and_json_env_file(env_file_path: str = ".env") -> dict[str, Any]: - """Fetch the environment variables from a .env file and return them as a dictionary.""" - try: - # Read the .env file - with open(env_file_path, "r") as f: - env_content = f.read() - - # Parse the .env file content to a dictionary - env_dict = {} - for line in env_content.splitlines(): - if line.strip() and not line.strip().startswith("#"): - key, value = line.split("=", 1) - env_dict[key.strip()] = value.strip() - - return env_dict - - except FileNotFoundError: - console.print(f"Error: {env_file_path} not found.", style="bold red") - except Exception as e: - console.print(f"Error reading the .env file: {e}", style="bold red") - - return {} - - -def tree_copy(source: Path, destination: Path) -> None: - """Copies the entire directory structure from the source to the destination.""" - for item in os.listdir(source): - source_item = os.path.join(source, item) - destination_item = os.path.join(destination, item) - if os.path.isdir(source_item): - shutil.copytree(source_item, destination_item) - else: - shutil.copy2(source_item, destination_item) - - -def tree_find_and_replace(directory: Path, find: str, replace: str) -> None: - """Recursively searches through a directory, replacing a target string in - both file contents and filenames with a specified replacement string. - """ - for path, dirs, files in os.walk(os.path.abspath(directory), topdown=False): - for filename in files: - filepath = os.path.join(path, filename) - - with open(filepath, "r", encoding="utf-8", errors="ignore") as file: - contents = file.read() - with open(filepath, "w") as file: - file.write(contents.replace(find, replace)) - - if find in filename: - new_filename = filename.replace(find, replace) - new_filepath = os.path.join(path, new_filename) - os.rename(filepath, new_filepath) - - for dirname in dirs: - if find in dirname: - new_dirname = dirname.replace(find, replace) - new_dirpath = os.path.join(path, new_dirname) - old_dirpath = os.path.join(path, dirname) - os.rename(old_dirpath, new_dirpath) - - -def load_env_vars(folder_path: Path) -> dict[str, Any]: - """ - Loads environment variables from a .env file in the specified folder path. - - Args: - - folder_path (Path): The path to the folder containing the .env file. - - Returns: - - dict: A dictionary of environment variables. - """ - env_file_path = folder_path / ".env" - env_vars = {} - if env_file_path.exists(): - with open(env_file_path, "r") as file: - for line in file: - key, _, value = line.strip().partition("=") - if key and value: - env_vars[key] = value - return env_vars - - -def update_env_vars( - env_vars: dict[str, Any], provider: str, model: str -) -> dict[str, Any] | None: - """ - Updates environment variables with the API key for the selected provider and model. - - Args: - - env_vars (dict): Environment variables dictionary. - - provider (str): Selected provider. - - model (str): Selected model. - - Returns: - - None - """ - provider_config = cast( - list[str], - ENV_VARS.get( - provider, - [ - click.prompt( - f"Enter the environment variable name for your {provider.capitalize()} API key", - type=str, - ) - ], - ), - ) - - api_key_var = provider_config[0] - - if api_key_var not in env_vars: - try: - env_vars[api_key_var] = click.prompt( - f"Enter your {provider.capitalize()} API key", type=str, hide_input=True - ) - except click.exceptions.Abort: - click.secho("Operation aborted by the user.", fg="red") - return None - else: - click.secho(f"API key already exists for {provider.capitalize()}.", fg="yellow") - - env_vars["MODEL"] = model - click.secho(f"Selected model: {model}", fg="green") - return env_vars - - -def write_env_file(folder_path: Path, env_vars: dict[str, Any]) -> None: - """ - Writes environment variables to a .env file in the specified folder. - - Args: - - folder_path (Path): The path to the folder where the .env file will be written. - - env_vars (dict): A dictionary of environment variables to write. - """ - env_file_path = folder_path / ".env" - with open(env_file_path, "w") as file: - for key, value in env_vars.items(): - file.write(f"{key.upper()}={value}\n") - - def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]: """Get the crew instances from a file.""" crew_instances = [] try: - import importlib.util - - # Add the current directory to sys.path to ensure imports resolve correctly current_dir = os.getcwd() if current_dir not in sys.path: sys.path.insert(0, current_dir) - # If we're not in src directory but there's a src directory, add it to path src_dir = os.path.join(current_dir, "src") if os.path.isdir(src_dir) and src_dir not in sys.path: sys.path.insert(0, src_dir) - # Search in both current directory and src directory if it exists search_paths = [".", "src"] if os.path.isdir("src") else ["."] for search_path in search_paths: @@ -316,7 +150,6 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]: ) continue - # If we found crew instances, break out of the loop if crew_instances: break @@ -334,7 +167,6 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]: ) continue - # If we found crew instances in this search path, break out of the search paths loop if crew_instances: break @@ -352,6 +184,7 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]: def get_crew_instance(module_attr: Any) -> Crew | None: + """Get a Crew instance from a module attribute.""" if ( callable(module_attr) and hasattr(module_attr, "is_crew_class") @@ -372,6 +205,7 @@ def get_crew_instance(module_attr: Any) -> Crew | None: def fetch_crews(module_attr: Any) -> list[Crew]: + """Fetch crew instances from a module attribute.""" crew_instances: list[Crew] = [] if crew_instance := get_crew_instance(module_attr): @@ -386,7 +220,7 @@ def fetch_crews(module_attr: Any) -> list[Crew]: return crew_instances -def get_flow_instance(module_attr: Any) -> Flow | None: +def get_flow_instance(module_attr: Any) -> Flow[Any] | None: """Check if a module attribute is a user-defined Flow subclass and return an instance. Args: @@ -413,13 +247,12 @@ _SKIP_DIRS = frozenset( ) -def get_flows(flow_path: str = "main.py") -> list[Flow]: +def get_flows(flow_path: str = "main.py") -> list[Flow[Any]]: """Get the flow instances from project files. Walks the project directory looking for files matching ``flow_path`` (default ``main.py``), loads each module, and extracts Flow subclass - instances. Directories that are clearly not user source code (virtual - environments, ``.git``, etc.) are pruned to avoid noisy import errors. + instances. Args: flow_path: Filename to search for (default ``main.py``). @@ -427,7 +260,7 @@ def get_flows(flow_path: str = "main.py") -> list[Flow]: Returns: A list of discovered Flow instances. """ - flow_instances: list[Flow] = [] + flow_instances: list[Flow[Any]] = [] try: current_dir = os.getcwd() if current_dir not in sys.path: @@ -486,6 +319,7 @@ def get_flows(flow_path: str = "main.py") -> list[Flow]: def is_valid_tool(obj: Any) -> bool: + """Check if an object is a valid CrewAI tool.""" from crewai.tools.base_tool import Tool if isclass(obj): @@ -498,12 +332,12 @@ def is_valid_tool(obj: Any) -> bool: def extract_available_exports(dir_path: str = "src") -> list[dict[str, Any]]: - """ - Extract available tool classes from the project's __init__.py files. + """Extract available tool classes from the project's __init__.py files. + Only includes classes that inherit from BaseTool or functions decorated with @tool. Returns: - list: A list of valid tool class names or ["BaseTool"] if none found + A list of valid tool class names or ["BaseTool"] if none found. """ try: init_files = Path(dir_path).glob("**/__init__.py") @@ -530,6 +364,7 @@ def extract_available_exports(dir_path: str = "src") -> list[dict[str, Any]]: def build_env_with_tool_repository_credentials( repository_handle: str, ) -> dict[str, Any]: + """Build environment variables with tool repository credentials.""" repository_handle = repository_handle.upper().replace("-", "_") settings = Settings() @@ -545,9 +380,7 @@ def build_env_with_tool_repository_credentials( def _load_tools_from_init(init_file: Path) -> list[dict[str, Any]]: - """ - Load and validate tools from a given __init__.py file. - """ + """Load and validate tools from a given __init__.py file.""" spec = importlib.util.spec_from_file_location("temp_module", init_file) if not spec or not spec.loader: @@ -583,9 +416,7 @@ def _load_tools_from_init(init_file: Path) -> list[dict[str, Any]]: def _print_no_tools_warning() -> None: - """ - Display warning and usage instructions if no tools were found. - """ + """Display warning and usage instructions if no tools were found.""" console.print( "\n[bold yellow]Warning: No valid tools were exposed in your __init__.py file![/bold yellow]" ) diff --git a/lib/crewai/src/crewai/cli/reset_memories_command.py b/lib/crewai/src/crewai/utilities/reset_memories.py similarity index 94% rename from lib/crewai/src/crewai/cli/reset_memories_command.py rename to lib/crewai/src/crewai/utilities/reset_memories.py index 4128d0651..50d4a633e 100644 --- a/lib/crewai/src/crewai/cli/reset_memories_command.py +++ b/lib/crewai/src/crewai/utilities/reset_memories.py @@ -1,12 +1,15 @@ +"""Memory reset utilities for CrewAI crews and flows.""" + import subprocess +from typing import Any import click -from crewai.cli.utils import get_crews, get_flows from crewai.flow import Flow +from crewai.utilities.project_utils import get_crews, get_flows -def _reset_flow_memory(flow: Flow) -> None: +def _reset_flow_memory(flow: Flow[Any]) -> None: """Reset memory for a single flow instance. Handles Memory, MemoryScope (both have .reset()), and MemorySlice diff --git a/lib/crewai/src/crewai/cli/version.py b/lib/crewai/src/crewai/version.py similarity index 98% rename from lib/crewai/src/crewai/cli/version.py rename to lib/crewai/src/crewai/version.py index 60eb3a95a..4aac4252a 100644 --- a/lib/crewai/src/crewai/cli/version.py +++ b/lib/crewai/src/crewai/version.py @@ -1,4 +1,4 @@ -"""Version utilities for CrewAI CLI.""" +"""Version utilities for CrewAI.""" from collections.abc import Mapping from datetime import datetime, timedelta @@ -26,7 +26,7 @@ def _get_cache_file() -> Path: def get_crewai_version() -> str: - """Get the version number of CrewAI running the CLI.""" + """Get the version number of the installed CrewAI package.""" return importlib.metadata.version("crewai") diff --git a/lib/crewai/tests/agents/test_agent.py b/lib/crewai/tests/agents/test_agent.py index 4f6a84602..efda9b476 100644 --- a/lib/crewai/tests/agents/test_agent.py +++ b/lib/crewai/tests/agents/test_agent.py @@ -6,7 +6,7 @@ from unittest import mock from unittest.mock import MagicMock, patch from crewai.agents.crew_agent_executor import AgentFinish, CrewAgentExecutor -from crewai.cli.constants import DEFAULT_LLM_MODEL +from crewai.constants import DEFAULT_LLM_MODEL from crewai.events.event_bus import crewai_event_bus from crewai.events.types.tool_usage_events import ToolUsageFinishedEvent from crewai.knowledge.knowledge import Knowledge @@ -2046,12 +2046,12 @@ def test_get_knowledge_search_query(): @pytest.fixture def mock_get_auth_token(): with patch( - "crewai.cli.authentication.token.get_auth_token", return_value="test_token" + "crewai.auth.token.get_auth_token", return_value="test_token" ): yield -@patch("crewai.cli.plus_api.PlusAPI.get_agent") +@patch("crewai.plus_api.PlusAPI.get_agent") def test_agent_from_repository(mock_get_agent, mock_get_auth_token): from crewai_tools import ( FileReadTool, @@ -2092,7 +2092,7 @@ def test_agent_from_repository(mock_get_agent, mock_get_auth_token): assert agent.tools[1].file_path == "test.txt" -@patch("crewai.cli.plus_api.PlusAPI.get_agent") +@patch("crewai.plus_api.PlusAPI.get_agent") def test_agent_from_repository_override_attributes(mock_get_agent, mock_get_auth_token): from crewai_tools import SerperDevTool @@ -2116,7 +2116,7 @@ def test_agent_from_repository_override_attributes(mock_get_agent, mock_get_auth assert isinstance(agent.tools[0], SerperDevTool) -@patch("crewai.cli.plus_api.PlusAPI.get_agent") +@patch("crewai.plus_api.PlusAPI.get_agent") def test_agent_from_repository_with_invalid_tools(mock_get_agent, mock_get_auth_token): mock_get_response = MagicMock() mock_get_response.status_code = 200 @@ -2139,7 +2139,7 @@ def test_agent_from_repository_with_invalid_tools(mock_get_agent, mock_get_auth_ Agent(from_repository="test_agent") -@patch("crewai.cli.plus_api.PlusAPI.get_agent") +@patch("crewai.plus_api.PlusAPI.get_agent") def test_agent_from_repository_internal_error(mock_get_agent, mock_get_auth_token): mock_get_response = MagicMock() mock_get_response.status_code = 500 @@ -2152,7 +2152,7 @@ def test_agent_from_repository_internal_error(mock_get_agent, mock_get_auth_toke Agent(from_repository="test_agent") -@patch("crewai.cli.plus_api.PlusAPI.get_agent") +@patch("crewai.plus_api.PlusAPI.get_agent") def test_agent_from_repository_agent_not_found(mock_get_agent, mock_get_auth_token): mock_get_response = MagicMock() mock_get_response.status_code = 404 @@ -2165,7 +2165,7 @@ def test_agent_from_repository_agent_not_found(mock_get_agent, mock_get_auth_tok Agent(from_repository="test_agent") -@patch("crewai.cli.plus_api.PlusAPI.get_agent") +@patch("crewai.plus_api.PlusAPI.get_agent") @patch("crewai.utilities.agent_utils.Settings") @patch("crewai.utilities.agent_utils.console") def test_agent_from_repository_displays_org_info( @@ -2198,7 +2198,7 @@ def test_agent_from_repository_displays_org_info( assert agent.backstory == "test backstory" -@patch("crewai.cli.plus_api.PlusAPI.get_agent") +@patch("crewai.plus_api.PlusAPI.get_agent") @patch("crewai.utilities.agent_utils.Settings") @patch("crewai.utilities.agent_utils.console") def test_agent_from_repository_without_org_set( diff --git a/lib/crewai/tests/cli/authentication/providers/test_auth0.py b/lib/crewai/tests/cli/authentication/providers/test_auth0.py index e513a1fb7..7b2c40edc 100644 --- a/lib/crewai/tests/cli/authentication/providers/test_auth0.py +++ b/lib/crewai/tests/cli/authentication/providers/test_auth0.py @@ -1,6 +1,6 @@ import pytest -from crewai.cli.authentication.main import Oauth2Settings -from crewai.cli.authentication.providers.auth0 import Auth0Provider +from crewai.auth.oauth2 import Oauth2Settings +from crewai.auth.providers.auth0 import Auth0Provider diff --git a/lib/crewai/tests/cli/authentication/providers/test_entra_id.py b/lib/crewai/tests/cli/authentication/providers/test_entra_id.py index 889023955..4237a6054 100644 --- a/lib/crewai/tests/cli/authentication/providers/test_entra_id.py +++ b/lib/crewai/tests/cli/authentication/providers/test_entra_id.py @@ -1,7 +1,7 @@ import pytest -from crewai.cli.authentication.main import Oauth2Settings -from crewai.cli.authentication.providers.entra_id import EntraIdProvider +from crewai.auth.oauth2 import Oauth2Settings +from crewai.auth.providers.entra_id import EntraIdProvider class TestEntraIdProvider: diff --git a/lib/crewai/tests/cli/authentication/providers/test_keycloak.py b/lib/crewai/tests/cli/authentication/providers/test_keycloak.py index 05d71b271..cf87e6625 100644 --- a/lib/crewai/tests/cli/authentication/providers/test_keycloak.py +++ b/lib/crewai/tests/cli/authentication/providers/test_keycloak.py @@ -1,7 +1,7 @@ import pytest -from crewai.cli.authentication.main import Oauth2Settings -from crewai.cli.authentication.providers.keycloak import KeycloakProvider +from crewai.auth.oauth2 import Oauth2Settings +from crewai.auth.providers.keycloak import KeycloakProvider class TestKeycloakProvider: diff --git a/lib/crewai/tests/cli/authentication/providers/test_okta.py b/lib/crewai/tests/cli/authentication/providers/test_okta.py index 5108b1bb6..ec76202ca 100644 --- a/lib/crewai/tests/cli/authentication/providers/test_okta.py +++ b/lib/crewai/tests/cli/authentication/providers/test_okta.py @@ -1,7 +1,7 @@ import pytest -from crewai.cli.authentication.main import Oauth2Settings -from crewai.cli.authentication.providers.okta import OktaProvider +from crewai.auth.oauth2 import Oauth2Settings +from crewai.auth.providers.okta import OktaProvider class TestOktaProvider: diff --git a/lib/crewai/tests/cli/authentication/providers/test_workos.py b/lib/crewai/tests/cli/authentication/providers/test_workos.py index 7eda774d6..791bc531b 100644 --- a/lib/crewai/tests/cli/authentication/providers/test_workos.py +++ b/lib/crewai/tests/cli/authentication/providers/test_workos.py @@ -1,6 +1,6 @@ import pytest -from crewai.cli.authentication.main import Oauth2Settings -from crewai.cli.authentication.providers.workos import WorkosProvider +from crewai.auth.oauth2 import Oauth2Settings +from crewai.auth.providers.workos import WorkosProvider class TestWorkosProvider: diff --git a/lib/crewai/tests/cli/authentication/test_auth_main.py b/lib/crewai/tests/cli/authentication/test_auth_main.py index 5cae84552..260ae237e 100644 --- a/lib/crewai/tests/cli/authentication/test_auth_main.py +++ b/lib/crewai/tests/cli/authentication/test_auth_main.py @@ -3,8 +3,8 @@ from unittest.mock import MagicMock, call, patch import pytest import httpx -from crewai.cli.authentication.main import AuthenticationCommand -from crewai.cli.constants import ( +from crewai.auth.oauth2 import AuthenticationCommand +from crewai.constants import ( CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE, CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID, CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN, @@ -14,7 +14,7 @@ from crewai.cli.constants import ( class TestAuthenticationCommand: def setup_method(self): # Mock Settings so we always use default constants regardless of local config. - with patch("crewai.cli.authentication.main.Settings") as mock_settings: + with patch("crewai.auth.oauth2.Settings") as mock_settings: instance = mock_settings.return_value instance.oauth2_provider = "workos" instance.oauth2_domain = CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN @@ -38,12 +38,12 @@ class TestAuthenticationCommand: ), ], ) - @patch("crewai.cli.authentication.main.AuthenticationCommand._get_device_code") + @patch("crewai.auth.oauth2.AuthenticationCommand._get_device_code") @patch( - "crewai.cli.authentication.main.AuthenticationCommand._display_auth_instructions" + "crewai.auth.oauth2.AuthenticationCommand._display_auth_instructions" ) - @patch("crewai.cli.authentication.main.AuthenticationCommand._poll_for_token") - @patch("crewai.cli.authentication.main.console.print") + @patch("crewai.auth.oauth2.AuthenticationCommand._poll_for_token") + @patch("crewai.auth.oauth2.console.print") def test_login( self, mock_console_print, @@ -82,8 +82,8 @@ class TestAuthenticationCommand: self.auth_command.oauth2_provider._get_domain() == expected_urls["domain"] ) - @patch("crewai.cli.authentication.main.webbrowser") - @patch("crewai.cli.authentication.main.console.print") + @patch("crewai.auth.oauth2.webbrowser") + @patch("crewai.auth.oauth2.console.print") def test_display_auth_instructions(self, mock_console_print, mock_webbrowser): device_code_data = { "verification_uri_complete": "https://example.com/auth", @@ -113,8 +113,8 @@ class TestAuthenticationCommand: ], ) @pytest.mark.parametrize("has_expiration", [True, False]) - @patch("crewai.cli.authentication.main.validate_jwt_token") - @patch("crewai.cli.authentication.main.TokenManager.save_tokens") + @patch("crewai.auth.oauth2.validate_jwt_token") + @patch("crewai.auth.oauth2.TokenManager.save_tokens") def test_validate_and_save_token( self, mock_save_tokens, @@ -123,8 +123,8 @@ class TestAuthenticationCommand: jwt_config, has_expiration, ): - from crewai.cli.authentication.main import Oauth2Settings - from crewai.cli.authentication.providers.workos import WorkosProvider + from crewai.auth.oauth2 import Oauth2Settings + from crewai.auth.providers.workos import WorkosProvider if user_provider == "workos": self.auth_command.oauth2_provider = WorkosProvider( @@ -163,8 +163,8 @@ class TestAuthenticationCommand: mock_save_tokens.assert_called_once_with("test_access_token", 0) @patch("crewai_cli.tools.main.ToolCommand") - @patch("crewai.cli.authentication.main.Settings") - @patch("crewai.cli.authentication.main.console.print") + @patch("crewai.auth.oauth2.Settings") + @patch("crewai.auth.oauth2.console.print") def test_login_to_tool_repository_success( self, mock_console_print, mock_settings, mock_tool_command ): @@ -196,7 +196,7 @@ class TestAuthenticationCommand: mock_console_print.assert_has_calls(expected_calls) @patch("crewai_cli.tools.main.ToolCommand") - @patch("crewai.cli.authentication.main.console.print") + @patch("crewai.auth.oauth2.console.print") def test_login_to_tool_repository_error( self, mock_console_print, mock_tool_command ): @@ -226,7 +226,7 @@ class TestAuthenticationCommand: ] mock_console_print.assert_has_calls(expected_calls) - @patch("crewai.cli.authentication.main.httpx.post") + @patch("crewai.auth.oauth2.httpx.post") def test_get_device_code(self, mock_post): mock_response = MagicMock() mock_response.json.return_value = { @@ -262,8 +262,8 @@ class TestAuthenticationCommand: "verification_uri_complete": "https://example.com/auth", } - @patch("crewai.cli.authentication.main.httpx.post") - @patch("crewai.cli.authentication.main.console.print") + @patch("crewai.auth.oauth2.httpx.post") + @patch("crewai.auth.oauth2.console.print") def test_poll_for_token_success(self, mock_console_print, mock_post): mock_response_success = MagicMock() mock_response_success.status_code = 200 @@ -311,8 +311,8 @@ class TestAuthenticationCommand: ] mock_console_print.assert_has_calls(expected_calls) - @patch("crewai.cli.authentication.main.httpx.post") - @patch("crewai.cli.authentication.main.console.print") + @patch("crewai.auth.oauth2.httpx.post") + @patch("crewai.auth.oauth2.console.print") def test_poll_for_token_timeout(self, mock_console_print, mock_post): mock_response_pending = MagicMock() mock_response_pending.status_code = 400 @@ -330,7 +330,7 @@ class TestAuthenticationCommand: "Timeout: Failed to get the token. Please try again.", style="bold red" ) - @patch("crewai.cli.authentication.main.httpx.post") + @patch("crewai.auth.oauth2.httpx.post") def test_poll_for_token_error(self, mock_post): """Test the method to poll for token (error path).""" # Setup mock to return error diff --git a/lib/crewai/tests/cli/authentication/test_utils.py b/lib/crewai/tests/cli/authentication/test_utils.py index 5df00db18..dbd16c842 100644 --- a/lib/crewai/tests/cli/authentication/test_utils.py +++ b/lib/crewai/tests/cli/authentication/test_utils.py @@ -3,11 +3,11 @@ from unittest.mock import MagicMock, patch import jwt -from crewai.cli.authentication.utils import validate_jwt_token +from crewai.auth.utils import validate_jwt_token -@patch("crewai.cli.authentication.utils.PyJWKClient", return_value=MagicMock()) -@patch("crewai.cli.authentication.utils.jwt") +@patch("crewai.auth.utils.PyJWKClient", return_value=MagicMock()) +@patch("crewai.auth.utils.jwt") class TestUtils(unittest.TestCase): def test_validate_jwt_token(self, mock_jwt, mock_pyjwkclient): mock_jwt.decode.return_value = {"exp": 1719859200} diff --git a/lib/crewai/tests/cli/test_cli.py b/lib/crewai/tests/cli/test_cli.py index 2a2f9e1c9..e4710564c 100644 --- a/lib/crewai/tests/cli/test_cli.py +++ b/lib/crewai/tests/cli/test_cli.py @@ -27,9 +27,9 @@ def mock_crew(): @pytest.fixture def mock_get_crews(mock_crew): with mock.patch( - "crewai.cli.reset_memories_command.get_crews", return_value=[mock_crew] + "crewai.utilities.reset_memories.get_crews", return_value=[mock_crew] ) as mock_get_crew, mock.patch( - "crewai.cli.reset_memories_command.get_flows", return_value=[] + "crewai.utilities.reset_memories.get_flows", return_value=[] ): yield mock_get_crew @@ -169,9 +169,9 @@ def mock_flow(): @pytest.fixture def mock_get_flows(mock_flow): with mock.patch( - "crewai.cli.reset_memories_command.get_flows", return_value=[mock_flow] + "crewai.utilities.reset_memories.get_flows", return_value=[mock_flow] ) as mock_get_flow, mock.patch( - "crewai.cli.reset_memories_command.get_crews", return_value=[] + "crewai.utilities.reset_memories.get_crews", return_value=[] ): yield mock_get_flow @@ -196,9 +196,9 @@ def test_reset_flow_knowledge_no_effect(mock_get_flows, mock_flow, runner): def test_reset_no_crew_or_flow_found(runner): with mock.patch( - "crewai.cli.reset_memories_command.get_crews", return_value=[] + "crewai.utilities.reset_memories.get_crews", return_value=[] ), mock.patch( - "crewai.cli.reset_memories_command.get_flows", return_value=[] + "crewai.utilities.reset_memories.get_flows", return_value=[] ): result = runner.invoke(reset_memories, ["-m"]) assert "No crew or flow found." in result.output @@ -206,9 +206,9 @@ def test_reset_no_crew_or_flow_found(runner): def test_reset_crew_and_flow_memory(mock_crew, mock_flow, runner): with mock.patch( - "crewai.cli.reset_memories_command.get_crews", return_value=[mock_crew] + "crewai.utilities.reset_memories.get_crews", return_value=[mock_crew] ), mock.patch( - "crewai.cli.reset_memories_command.get_flows", return_value=[mock_flow] + "crewai.utilities.reset_memories.get_flows", return_value=[mock_flow] ): result = runner.invoke(reset_memories, ["-m"]) mock_crew.reset_memories.assert_called_once_with(command_type="memory") @@ -222,9 +222,9 @@ def test_reset_flow_memory_none(runner): mock_flow.name = "NoMemFlow" mock_flow.memory = None with mock.patch( - "crewai.cli.reset_memories_command.get_crews", return_value=[] + "crewai.utilities.reset_memories.get_crews", return_value=[] ), mock.patch( - "crewai.cli.reset_memories_command.get_flows", return_value=[mock_flow] + "crewai.utilities.reset_memories.get_flows", return_value=[mock_flow] ): result = runner.invoke(reset_memories, ["-m"]) assert "[Flow (NoMemFlow)] Memory has been reset." in result.output diff --git a/lib/crewai/tests/cli/test_config.py b/lib/crewai/tests/cli/test_config.py index 4dec94ee3..eefdbdecd 100644 --- a/lib/crewai/tests/cli/test_config.py +++ b/lib/crewai/tests/cli/test_config.py @@ -6,13 +6,13 @@ from datetime import datetime, timedelta from pathlib import Path from unittest.mock import MagicMock, patch -from crewai.cli.config import ( +from crewai.settings import ( CLI_SETTINGS_KEYS, DEFAULT_CLI_SETTINGS, USER_SETTINGS_KEYS, Settings, ) -from crewai.cli.shared.token_manager import TokenManager +from crewai.auth.token_manager import TokenManager class TestSettings(unittest.TestCase): @@ -69,7 +69,7 @@ class TestSettings(unittest.TestCase): for key in user_settings.keys(): self.assertEqual(getattr(settings, key), None) - @patch("crewai.cli.config.TokenManager") + @patch("crewai.settings.TokenManager") def test_reset_settings(self, mock_token_manager): user_settings = {key: f"value_for_{key}" for key in USER_SETTINGS_KEYS} cli_settings = {key: f"value_for_{key}" for key in CLI_SETTINGS_KEYS if key != "oauth2_extra"} diff --git a/lib/crewai/tests/cli/test_constants.py b/lib/crewai/tests/cli/test_constants.py index 013d8ff8c..346875c8f 100644 --- a/lib/crewai/tests/cli/test_constants.py +++ b/lib/crewai/tests/cli/test_constants.py @@ -1,4 +1,4 @@ -from crewai.cli.constants import ENV_VARS, MODELS, PROVIDERS +from crewai.constants import ENV_VARS, MODELS, PROVIDERS def test_huggingface_in_providers(): diff --git a/lib/crewai/tests/cli/test_git.py b/lib/crewai/tests/cli/test_git.py deleted file mode 100644 index b77106d3f..000000000 --- a/lib/crewai/tests/cli/test_git.py +++ /dev/null @@ -1,101 +0,0 @@ -import pytest -from crewai.cli.git import Repository - - -@pytest.fixture() -def repository(fp): - fp.register(["git", "--version"], stdout="git version 2.30.0\n") - fp.register(["git", "rev-parse", "--is-inside-work-tree"], stdout="true\n") - fp.register(["git", "fetch"], stdout="") - return Repository(path=".") - - -def test_init_with_invalid_git_repo(fp): - fp.register(["git", "--version"], stdout="git version 2.30.0\n") - fp.register( - ["git", "rev-parse", "--is-inside-work-tree"], - returncode=1, - stderr="fatal: not a git repository\n", - ) - - with pytest.raises(ValueError): - Repository(path="invalid/path") - - -def test_is_git_not_installed(fp): - fp.register(["git", "--version"], returncode=1) - - with pytest.raises( - ValueError, match="Git is not installed or not found in your PATH." - ): - Repository(path=".") - - -def test_status(fp, repository): - fp.register( - ["git", "status", "--branch", "--porcelain"], - stdout="## main...origin/main [ahead 1]\n", - ) - assert repository.status() == "## main...origin/main [ahead 1]" - - -def test_has_uncommitted_changes(fp, repository): - fp.register( - ["git", "status", "--branch", "--porcelain"], - stdout="## main...origin/main\n M somefile.txt\n", - ) - assert repository.has_uncommitted_changes() is True - - -def test_is_ahead_or_behind(fp, repository): - fp.register( - ["git", "status", "--branch", "--porcelain"], - stdout="## main...origin/main [ahead 1]\n", - ) - assert repository.is_ahead_or_behind() is True - - -def test_is_synced_when_synced(fp, repository): - fp.register( - ["git", "status", "--branch", "--porcelain"], stdout="## main...origin/main\n" - ) - fp.register( - ["git", "status", "--branch", "--porcelain"], stdout="## main...origin/main\n" - ) - assert repository.is_synced() is True - - -def test_is_synced_with_uncommitted_changes(fp, repository): - fp.register( - ["git", "status", "--branch", "--porcelain"], - stdout="## main...origin/main\n M somefile.txt\n", - ) - assert repository.is_synced() is False - - -def test_is_synced_when_ahead_or_behind(fp, repository): - fp.register( - ["git", "status", "--branch", "--porcelain"], - stdout="## main...origin/main [ahead 1]\n", - ) - fp.register( - ["git", "status", "--branch", "--porcelain"], - stdout="## main...origin/main [ahead 1]\n", - ) - assert repository.is_synced() is False - - -def test_is_synced_with_uncommitted_changes_and_ahead(fp, repository): - fp.register( - ["git", "status", "--branch", "--porcelain"], - stdout="## main...origin/main [ahead 1]\n M somefile.txt\n", - ) - assert repository.is_synced() is False - - -def test_origin_url(fp, repository): - fp.register( - ["git", "remote", "get-url", "origin"], - stdout="https://github.com/user/repo.git\n", - ) - assert repository.origin_url() == "https://github.com/user/repo.git" diff --git a/lib/crewai/tests/cli/test_plus_api.py b/lib/crewai/tests/cli/test_plus_api.py index 95a322e21..5a1d76823 100644 --- a/lib/crewai/tests/cli/test_plus_api.py +++ b/lib/crewai/tests/cli/test_plus_api.py @@ -4,7 +4,7 @@ from unittest.mock import ANY, AsyncMock, MagicMock, patch import pytest -from crewai.cli.plus_api import PlusAPI +from crewai.plus_api import PlusAPI class TestPlusAPI(unittest.TestCase): @@ -20,7 +20,7 @@ class TestPlusAPI(unittest.TestCase): self.assertTrue("CrewAI-CLI/" in self.api.headers["User-Agent"]) self.assertTrue(self.api.headers["X-Crewai-Version"]) - @patch("crewai.cli.plus_api.PlusAPI._make_request") + @patch("crewai.plus_api.PlusAPI._make_request") def test_login_to_tool_repository(self, mock_make_request): mock_response = MagicMock() mock_make_request.return_value = mock_response @@ -32,7 +32,7 @@ class TestPlusAPI(unittest.TestCase): ) self.assertEqual(response, mock_response) - @patch("crewai.cli.plus_api.PlusAPI._make_request") + @patch("crewai.plus_api.PlusAPI._make_request") def test_login_to_tool_repository_with_user_identifier(self, mock_make_request): mock_response = MagicMock() mock_make_request.return_value = mock_response @@ -60,8 +60,8 @@ class TestPlusAPI(unittest.TestCase): **kwargs, ) - @patch("crewai.cli.plus_api.Settings") - @patch("crewai.cli.plus_api.httpx.Client") + @patch("crewai.plus_api.Settings") + @patch("crewai.plus_api.httpx.Client") def test_login_to_tool_repository_with_org_uuid( self, mock_client_class, mock_settings_class ): @@ -83,7 +83,7 @@ class TestPlusAPI(unittest.TestCase): ) self.assertEqual(response, mock_response) - @patch("crewai.cli.plus_api.PlusAPI._make_request") + @patch("crewai.plus_api.PlusAPI._make_request") def test_get_tool(self, mock_make_request): mock_response = MagicMock() mock_make_request.return_value = mock_response @@ -94,8 +94,8 @@ class TestPlusAPI(unittest.TestCase): ) self.assertEqual(response, mock_response) - @patch("crewai.cli.plus_api.Settings") - @patch("crewai.cli.plus_api.httpx.Client") + @patch("crewai.plus_api.Settings") + @patch("crewai.plus_api.httpx.Client") def test_get_tool_with_org_uuid(self, mock_client_class, mock_settings_class): mock_settings = MagicMock() mock_settings.org_uuid = self.org_uuid @@ -115,7 +115,7 @@ class TestPlusAPI(unittest.TestCase): ) self.assertEqual(response, mock_response) - @patch("crewai.cli.plus_api.PlusAPI._make_request") + @patch("crewai.plus_api.PlusAPI._make_request") def test_publish_tool(self, mock_make_request): mock_response = MagicMock() mock_make_request.return_value = mock_response @@ -142,8 +142,8 @@ class TestPlusAPI(unittest.TestCase): ) self.assertEqual(response, mock_response) - @patch("crewai.cli.plus_api.Settings") - @patch("crewai.cli.plus_api.httpx.Client") + @patch("crewai.plus_api.Settings") + @patch("crewai.plus_api.httpx.Client") def test_publish_tool_with_org_uuid(self, mock_client_class, mock_settings_class): mock_settings = MagicMock() mock_settings.org_uuid = self.org_uuid @@ -180,7 +180,7 @@ class TestPlusAPI(unittest.TestCase): ) self.assertEqual(response, mock_response) - @patch("crewai.cli.plus_api.PlusAPI._make_request") + @patch("crewai.plus_api.PlusAPI._make_request") def test_publish_tool_without_description(self, mock_make_request): mock_response = MagicMock() mock_make_request.return_value = mock_response @@ -207,7 +207,7 @@ class TestPlusAPI(unittest.TestCase): ) self.assertEqual(response, mock_response) - @patch("crewai.cli.plus_api.httpx.Client") + @patch("crewai.plus_api.httpx.Client") def test_make_request(self, mock_client_class): mock_client_instance = MagicMock() mock_response = MagicMock() @@ -222,35 +222,35 @@ class TestPlusAPI(unittest.TestCase): ) self.assertEqual(response, mock_response) - @patch("crewai.cli.plus_api.PlusAPI._make_request") + @patch("crewai.plus_api.PlusAPI._make_request") def test_deploy_by_name(self, mock_make_request): self.api.deploy_by_name("test_project") mock_make_request.assert_called_once_with( "POST", "/crewai_plus/api/v1/crews/by-name/test_project/deploy" ) - @patch("crewai.cli.plus_api.PlusAPI._make_request") + @patch("crewai.plus_api.PlusAPI._make_request") def test_deploy_by_uuid(self, mock_make_request): self.api.deploy_by_uuid("test_uuid") mock_make_request.assert_called_once_with( "POST", "/crewai_plus/api/v1/crews/test_uuid/deploy" ) - @patch("crewai.cli.plus_api.PlusAPI._make_request") + @patch("crewai.plus_api.PlusAPI._make_request") def test_crew_status_by_name(self, mock_make_request): self.api.crew_status_by_name("test_project") mock_make_request.assert_called_once_with( "GET", "/crewai_plus/api/v1/crews/by-name/test_project/status" ) - @patch("crewai.cli.plus_api.PlusAPI._make_request") + @patch("crewai.plus_api.PlusAPI._make_request") def test_crew_status_by_uuid(self, mock_make_request): self.api.crew_status_by_uuid("test_uuid") mock_make_request.assert_called_once_with( "GET", "/crewai_plus/api/v1/crews/test_uuid/status" ) - @patch("crewai.cli.plus_api.PlusAPI._make_request") + @patch("crewai.plus_api.PlusAPI._make_request") def test_crew_by_name(self, mock_make_request): self.api.crew_by_name("test_project") mock_make_request.assert_called_once_with( @@ -262,7 +262,7 @@ class TestPlusAPI(unittest.TestCase): "GET", "/crewai_plus/api/v1/crews/by-name/test_project/logs/custom_log" ) - @patch("crewai.cli.plus_api.PlusAPI._make_request") + @patch("crewai.plus_api.PlusAPI._make_request") def test_crew_by_uuid(self, mock_make_request): self.api.crew_by_uuid("test_uuid") mock_make_request.assert_called_once_with( @@ -274,26 +274,26 @@ class TestPlusAPI(unittest.TestCase): "GET", "/crewai_plus/api/v1/crews/test_uuid/logs/custom_log" ) - @patch("crewai.cli.plus_api.PlusAPI._make_request") + @patch("crewai.plus_api.PlusAPI._make_request") def test_delete_crew_by_name(self, mock_make_request): self.api.delete_crew_by_name("test_project") mock_make_request.assert_called_once_with( "DELETE", "/crewai_plus/api/v1/crews/by-name/test_project" ) - @patch("crewai.cli.plus_api.PlusAPI._make_request") + @patch("crewai.plus_api.PlusAPI._make_request") def test_delete_crew_by_uuid(self, mock_make_request): self.api.delete_crew_by_uuid("test_uuid") mock_make_request.assert_called_once_with( "DELETE", "/crewai_plus/api/v1/crews/test_uuid" ) - @patch("crewai.cli.plus_api.PlusAPI._make_request") + @patch("crewai.plus_api.PlusAPI._make_request") def test_list_crews(self, mock_make_request): self.api.list_crews() mock_make_request.assert_called_once_with("GET", "/crewai_plus/api/v1/crews") - @patch("crewai.cli.plus_api.PlusAPI._make_request") + @patch("crewai.plus_api.PlusAPI._make_request") def test_create_crew(self, mock_make_request): payload = {"name": "test_crew"} self.api.create_crew(payload) @@ -301,7 +301,7 @@ class TestPlusAPI(unittest.TestCase): "POST", "/crewai_plus/api/v1/crews", json=payload ) - @patch("crewai.cli.plus_api.Settings") + @patch("crewai.plus_api.Settings") @patch.dict(os.environ, {"CREWAI_PLUS_URL": ""}) def test_custom_base_url(self, mock_settings_class): mock_settings = MagicMock() @@ -342,7 +342,7 @@ async def test_get_agent(mock_async_client_class): @pytest.mark.asyncio @patch("httpx.AsyncClient") -@patch("crewai.cli.plus_api.Settings") +@patch("crewai.plus_api.Settings") async def test_get_agent_with_org_uuid(mock_settings_class, mock_async_client_class): org_uuid = "test-org-uuid" mock_settings = MagicMock() diff --git a/lib/crewai/tests/cli/test_token_manager.py b/lib/crewai/tests/cli/test_token_manager.py index 5d7fc5790..12407ae01 100644 --- a/lib/crewai/tests/cli/test_token_manager.py +++ b/lib/crewai/tests/cli/test_token_manager.py @@ -10,20 +10,20 @@ from unittest.mock import patch from cryptography.fernet import Fernet -from crewai.cli.shared.token_manager import TokenManager +from crewai.auth.token_manager import TokenManager class TestTokenManager(unittest.TestCase): """Test cases for TokenManager.""" - @patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key") + @patch("crewai.auth.token_manager.TokenManager._get_or_create_key") def setUp(self, mock_get_key: unittest.mock.MagicMock) -> None: """Set up test fixtures.""" mock_get_key.return_value = Fernet.generate_key() self.token_manager = TokenManager() - @patch("crewai.cli.shared.token_manager.TokenManager._read_secure_file") - @patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key") + @patch("crewai.auth.token_manager.TokenManager._read_secure_file") + @patch("crewai.auth.token_manager.TokenManager._get_or_create_key") def test_get_or_create_key_existing( self, mock_get_or_create: unittest.mock.MagicMock, @@ -45,7 +45,7 @@ class TestTokenManager(unittest.TestCase): with ( patch.object(self.token_manager, "_read_secure_file", return_value=None) as mock_read, patch.object(self.token_manager, "_atomic_create_secure_file", return_value=True) as mock_atomic_create, - patch("crewai.cli.shared.token_manager.Fernet.generate_key", return_value=mock_key) as mock_generate, + patch("crewai.auth.token_manager.Fernet.generate_key", return_value=mock_key) as mock_generate, ): result = self.token_manager._get_or_create_key() @@ -62,14 +62,14 @@ class TestTokenManager(unittest.TestCase): with ( patch.object(self.token_manager, "_read_secure_file", side_effect=[None, their_key]) as mock_read, patch.object(self.token_manager, "_atomic_create_secure_file", return_value=False) as mock_atomic_create, - patch("crewai.cli.shared.token_manager.Fernet.generate_key", return_value=our_key), + patch("crewai.auth.token_manager.Fernet.generate_key", return_value=our_key), ): result = self.token_manager._get_or_create_key() self.assertEqual(result, their_key) self.assertEqual(mock_read.call_count, 2) - @patch("crewai.cli.shared.token_manager.TokenManager._atomic_write_secure_file") + @patch("crewai.auth.token_manager.TokenManager._atomic_write_secure_file") def test_save_tokens( self, mock_write: unittest.mock.MagicMock ) -> None: @@ -88,7 +88,7 @@ class TestTokenManager(unittest.TestCase): expiration = datetime.fromisoformat(data["expiration"]) self.assertEqual(expiration, datetime.fromtimestamp(expires_at)) - @patch("crewai.cli.shared.token_manager.TokenManager._read_secure_file") + @patch("crewai.auth.token_manager.TokenManager._read_secure_file") def test_get_token_valid( self, mock_read: unittest.mock.MagicMock ) -> None: @@ -103,7 +103,7 @@ class TestTokenManager(unittest.TestCase): self.assertEqual(result, access_token) - @patch("crewai.cli.shared.token_manager.TokenManager._read_secure_file") + @patch("crewai.auth.token_manager.TokenManager._read_secure_file") def test_get_token_expired( self, mock_read: unittest.mock.MagicMock ) -> None: @@ -118,7 +118,7 @@ class TestTokenManager(unittest.TestCase): self.assertIsNone(result) - @patch("crewai.cli.shared.token_manager.TokenManager._read_secure_file") + @patch("crewai.auth.token_manager.TokenManager._read_secure_file") def test_get_token_not_found( self, mock_read: unittest.mock.MagicMock ) -> None: @@ -129,7 +129,7 @@ class TestTokenManager(unittest.TestCase): self.assertIsNone(result) - @patch("crewai.cli.shared.token_manager.TokenManager._delete_secure_file") + @patch("crewai.auth.token_manager.TokenManager._delete_secure_file") def test_clear_tokens( self, mock_delete: unittest.mock.MagicMock ) -> None: @@ -159,7 +159,7 @@ class TestAtomicFileOperations(unittest.TestCase): import shutil shutil.rmtree(self.temp_dir, ignore_errors=True) - @patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key") + @patch("crewai.auth.token_manager.TokenManager._get_or_create_key") def test_atomic_create_new_file( self, mock_get_key: unittest.mock.MagicMock ) -> None: @@ -175,7 +175,7 @@ class TestAtomicFileOperations(unittest.TestCase): self.assertEqual(file_path.read_bytes(), b"content") self.assertEqual(file_path.stat().st_mode & 0o777, 0o600) - @patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key") + @patch("crewai.auth.token_manager.TokenManager._get_or_create_key") def test_atomic_create_existing_file( self, mock_get_key: unittest.mock.MagicMock ) -> None: @@ -192,7 +192,7 @@ class TestAtomicFileOperations(unittest.TestCase): self.assertFalse(result) self.assertEqual(file_path.read_bytes(), b"original") - @patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key") + @patch("crewai.auth.token_manager.TokenManager._get_or_create_key") def test_atomic_write_new_file( self, mock_get_key: unittest.mock.MagicMock ) -> None: @@ -207,7 +207,7 @@ class TestAtomicFileOperations(unittest.TestCase): self.assertEqual(file_path.read_bytes(), b"content") self.assertEqual(file_path.stat().st_mode & 0o777, 0o600) - @patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key") + @patch("crewai.auth.token_manager.TokenManager._get_or_create_key") def test_atomic_write_overwrites( self, mock_get_key: unittest.mock.MagicMock ) -> None: @@ -222,7 +222,7 @@ class TestAtomicFileOperations(unittest.TestCase): self.assertEqual(file_path.read_bytes(), b"new content") - @patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key") + @patch("crewai.auth.token_manager.TokenManager._get_or_create_key") def test_atomic_write_no_temp_file_on_success( self, mock_get_key: unittest.mock.MagicMock ) -> None: @@ -236,7 +236,7 @@ class TestAtomicFileOperations(unittest.TestCase): temp_files = list(Path(self.temp_dir).glob(".test.txt.*")) self.assertEqual(len(temp_files), 0) - @patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key") + @patch("crewai.auth.token_manager.TokenManager._get_or_create_key") def test_read_secure_file_exists( self, mock_get_key: unittest.mock.MagicMock ) -> None: @@ -251,7 +251,7 @@ class TestAtomicFileOperations(unittest.TestCase): self.assertEqual(result, b"content") - @patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key") + @patch("crewai.auth.token_manager.TokenManager._get_or_create_key") def test_read_secure_file_not_exists( self, mock_get_key: unittest.mock.MagicMock ) -> None: @@ -263,7 +263,7 @@ class TestAtomicFileOperations(unittest.TestCase): self.assertIsNone(result) - @patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key") + @patch("crewai.auth.token_manager.TokenManager._get_or_create_key") def test_delete_secure_file_exists( self, mock_get_key: unittest.mock.MagicMock ) -> None: @@ -278,7 +278,7 @@ class TestAtomicFileOperations(unittest.TestCase): self.assertFalse(file_path.exists()) - @patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key") + @patch("crewai.auth.token_manager.TokenManager._get_or_create_key") def test_delete_secure_file_not_exists( self, mock_get_key: unittest.mock.MagicMock ) -> None: diff --git a/lib/crewai/tests/cli/test_utils.py b/lib/crewai/tests/cli/test_utils.py index 5baf1cffe..3c8b02dd1 100644 --- a/lib/crewai/tests/cli/test_utils.py +++ b/lib/crewai/tests/cli/test_utils.py @@ -4,7 +4,7 @@ import tempfile from pathlib import Path import pytest -from crewai.cli import utils +from crewai.utilities import project_utils as utils @pytest.fixture diff --git a/lib/crewai/tests/cli/test_version.py b/lib/crewai/tests/cli/test_version.py index 4e53ea923..39cbbaaa2 100644 --- a/lib/crewai/tests/cli/test_version.py +++ b/lib/crewai/tests/cli/test_version.py @@ -6,7 +6,7 @@ from pathlib import Path from unittest.mock import MagicMock, patch from crewai import __version__ -from crewai.cli.version import ( +from crewai.version import ( _find_latest_non_yanked_version, _get_cache_file, _is_cache_valid, @@ -60,8 +60,8 @@ class TestVersionChecking: cache_data = {"version": "1.0.0"} assert _is_cache_valid(cache_data) is False - @patch("crewai.cli.version.Path.exists") - @patch("crewai.cli.version.request.urlopen") + @patch("crewai.version.Path.exists") + @patch("crewai.version.request.urlopen") def test_get_latest_version_from_pypi_success( self, mock_urlopen: MagicMock, mock_exists: MagicMock ) -> None: @@ -82,8 +82,8 @@ class TestVersionChecking: version = get_latest_version_from_pypi() assert version == "2.0.0" - @patch("crewai.cli.version.Path.exists") - @patch("crewai.cli.version.request.urlopen") + @patch("crewai.version.Path.exists") + @patch("crewai.version.request.urlopen") def test_get_latest_version_from_pypi_failure( self, mock_urlopen: MagicMock, mock_exists: MagicMock ) -> None: @@ -97,8 +97,8 @@ class TestVersionChecking: version = get_latest_version_from_pypi() assert version is None - @patch("crewai.cli.version.get_crewai_version") - @patch("crewai.cli.version.get_latest_version_from_pypi") + @patch("crewai.version.get_crewai_version") + @patch("crewai.version.get_latest_version_from_pypi") def test_is_newer_version_available_true( self, mock_latest: MagicMock, mock_current: MagicMock ) -> None: @@ -111,8 +111,8 @@ class TestVersionChecking: assert current == "1.0.0" assert latest == "2.0.0" - @patch("crewai.cli.version.get_crewai_version") - @patch("crewai.cli.version.get_latest_version_from_pypi") + @patch("crewai.version.get_crewai_version") + @patch("crewai.version.get_latest_version_from_pypi") def test_is_newer_version_available_false( self, mock_latest: MagicMock, mock_current: MagicMock ) -> None: @@ -125,8 +125,8 @@ class TestVersionChecking: assert current == "2.0.0" assert latest == "2.0.0" - @patch("crewai.cli.version.get_crewai_version") - @patch("crewai.cli.version.get_latest_version_from_pypi") + @patch("crewai.version.get_crewai_version") + @patch("crewai.version.get_latest_version_from_pypi") def test_is_newer_version_available_with_none_latest( self, mock_latest: MagicMock, mock_current: MagicMock ) -> None: @@ -260,8 +260,8 @@ class TestIsVersionYanked: class TestIsCurrentVersionYanked: """Test is_current_version_yanked public function.""" - @patch("crewai.cli.version.get_crewai_version") - @patch("crewai.cli.version._get_cache_file") + @patch("crewai.version.get_crewai_version") + @patch("crewai.version._get_cache_file") def test_reads_from_valid_cache( self, mock_cache_file: MagicMock, mock_version: MagicMock, tmp_path: Path ) -> None: @@ -282,8 +282,8 @@ class TestIsCurrentVersionYanked: assert is_yanked is True assert reason == "bad release" - @patch("crewai.cli.version.get_crewai_version") - @patch("crewai.cli.version._get_cache_file") + @patch("crewai.version.get_crewai_version") + @patch("crewai.version._get_cache_file") def test_not_yanked_from_cache( self, mock_cache_file: MagicMock, mock_version: MagicMock, tmp_path: Path ) -> None: @@ -304,9 +304,9 @@ class TestIsCurrentVersionYanked: assert is_yanked is False assert reason == "" - @patch("crewai.cli.version.get_latest_version_from_pypi") - @patch("crewai.cli.version.get_crewai_version") - @patch("crewai.cli.version._get_cache_file") + @patch("crewai.version.get_latest_version_from_pypi") + @patch("crewai.version.get_crewai_version") + @patch("crewai.version._get_cache_file") def test_triggers_fetch_on_stale_cache( self, mock_cache_file: MagicMock, @@ -346,9 +346,9 @@ class TestIsCurrentVersionYanked: assert is_yanked is False mock_fetch.assert_called_once() - @patch("crewai.cli.version.get_latest_version_from_pypi") - @patch("crewai.cli.version.get_crewai_version") - @patch("crewai.cli.version._get_cache_file") + @patch("crewai.version.get_latest_version_from_pypi") + @patch("crewai.version.get_crewai_version") + @patch("crewai.version._get_cache_file") def test_returns_false_on_fetch_failure( self, mock_cache_file: MagicMock, diff --git a/lib/crewai/tests/llms/openai/test_openai.py b/lib/crewai/tests/llms/openai/test_openai.py index 069823a7a..3b9953e7a 100644 --- a/lib/crewai/tests/llms/openai/test_openai.py +++ b/lib/crewai/tests/llms/openai/test_openai.py @@ -11,7 +11,7 @@ from crewai.llms.providers.openai.completion import OpenAICompletion, ResponsesA from crewai.crew import Crew from crewai.agent import Agent from crewai.task import Task -from crewai.cli.constants import DEFAULT_LLM_MODEL +from crewai.constants import DEFAULT_LLM_MODEL def test_openai_completion_is_used_when_openai_provider(): """ diff --git a/lib/crewai/tests/mcp/test_amp_mcp.py b/lib/crewai/tests/mcp/test_amp_mcp.py index f13484a8d..5b86a525d 100644 --- a/lib/crewai/tests/mcp/test_amp_mcp.py +++ b/lib/crewai/tests/mcp/test_amp_mcp.py @@ -102,7 +102,7 @@ class TestBuildMCPConfigFromDict: class TestFetchAmpMCPConfigs: - @patch("crewai.cli.plus_api.PlusAPI") + @patch("crewai.plus_api.PlusAPI") @patch("crewai_tools.tools.crewai_platform_tools.misc.get_platform_integration_token", return_value="test-api-key") def test_fetches_configs_successfully(self, mock_get_token, mock_plus_api_class, resolver): mock_response = MagicMock() @@ -133,7 +133,7 @@ class TestFetchAmpMCPConfigs: mock_plus_api_class.assert_called_once_with(api_key="test-api-key") mock_plus_api.get_mcp_configs.assert_called_once_with(["notion", "github"]) - @patch("crewai.cli.plus_api.PlusAPI") + @patch("crewai.plus_api.PlusAPI") @patch("crewai_tools.tools.crewai_platform_tools.misc.get_platform_integration_token", return_value="test-api-key") def test_omits_missing_slugs(self, mock_get_token, mock_plus_api_class, resolver): mock_response = MagicMock() @@ -150,7 +150,7 @@ class TestFetchAmpMCPConfigs: assert "notion" in result assert "missing-server" not in result - @patch("crewai.cli.plus_api.PlusAPI") + @patch("crewai.plus_api.PlusAPI") @patch("crewai_tools.tools.crewai_platform_tools.misc.get_platform_integration_token", return_value="test-api-key") def test_returns_empty_on_http_error(self, mock_get_token, mock_plus_api_class, resolver): mock_response = MagicMock() @@ -163,7 +163,7 @@ class TestFetchAmpMCPConfigs: assert result == {} - @patch("crewai.cli.plus_api.PlusAPI") + @patch("crewai.plus_api.PlusAPI") @patch("crewai_tools.tools.crewai_platform_tools.misc.get_platform_integration_token", return_value="test-api-key") def test_returns_empty_on_network_error(self, mock_get_token, mock_plus_api_class, resolver): import httpx diff --git a/lib/crewai/tests/tracing/test_tracing.py b/lib/crewai/tests/tracing/test_tracing.py index c2558c17c..6030e9b6c 100644 --- a/lib/crewai/tests/tracing/test_tracing.py +++ b/lib/crewai/tests/tracing/test_tracing.py @@ -35,7 +35,7 @@ class TestTraceListenerSetup: # Need to patch all the places where get_auth_token is imported/used with ( patch( - "crewai.cli.authentication.token.get_auth_token", + "crewai.auth.token.get_auth_token", return_value="mock_token_12345", ), patch( diff --git a/lib/crewai/tests/utilities/test_llm_utils.py b/lib/crewai/tests/utilities/test_llm_utils.py index 5d7d70b76..255f29d7d 100644 --- a/lib/crewai/tests/utilities/test_llm_utils.py +++ b/lib/crewai/tests/utilities/test_llm_utils.py @@ -2,7 +2,7 @@ import os from typing import Any from unittest.mock import patch -from crewai.cli.constants import DEFAULT_LLM_MODEL +from crewai.constants import DEFAULT_LLM_MODEL from crewai.llm import LLM from crewai.llms.base_llm import BaseLLM from crewai.utilities.llm_utils import create_llm