refactor: remove cli/ from crewai package and relocate to proper modules

Move framework infrastructure out of crewai/cli/ to dedicated modules:
- cli/authentication/ → crewai/auth/
- cli/config.py → crewai/settings.py
- cli/constants.py → crewai/constants.py
- cli/plus_api.py → crewai/plus_api.py
- cli/version.py → crewai/version.py
- cli/crew_chat.py → crewai/utilities/crew_chat.py
- cli/reset_memories_command.py → crewai/utilities/reset_memories.py
- cli/utils.py (framework parts) → crewai/utilities/project_utils.py

Delete CLI-only duplicates (command.py, git.py, provider.py) already
present in crewai_cli. Replace _login_to_tool_repository with a
_post_login() hook in AuthenticationCommand. Update all imports and
mock.patch paths across both packages and tests.
This commit is contained in:
Greyson Lalonde
2026-03-15 19:39:55 -04:00
parent cf1636c300
commit 7afca5daab
57 changed files with 324 additions and 1025 deletions

View File

@@ -1,7 +1,7 @@
"""Wrapper for the crew chat command.
Delegates to ``crewai.cli.crew_chat.run_chat`` when the full crewai package is
installed, otherwise prints a helpful error message.
Delegates to ``crewai.utilities.crew_chat.run_chat`` when the full crewai
package is installed, otherwise prints a helpful error message.
"""
from __future__ import annotations
@@ -11,7 +11,7 @@ import click
def run_chat() -> None:
try:
from crewai.cli.crew_chat import run_chat as _run_chat
from crewai.utilities.crew_chat import run_chat as _run_chat
except ImportError:
click.secho(
"The 'chat' command requires the full crewai package.\n"

View File

@@ -1,6 +1,6 @@
"""Wrapper for the reset-memories command.
Delegates to ``crewai.cli.reset_memories_command`` when the full crewai
Delegates to ``crewai.utilities.reset_memories`` when the full crewai
package is installed, otherwise prints a helpful error message.
"""
@@ -17,7 +17,7 @@ def reset_memories_command(
all: bool,
) -> None:
try:
from crewai.cli.reset_memories_command import (
from crewai.utilities.reset_memories import (
reset_memories_command as _reset,
)
except ImportError:

View File

@@ -246,7 +246,7 @@ def is_valid_tool(obj: Any) -> bool:
Falls back to crewai's ``is_valid_tool`` when available.
"""
try:
from crewai.cli.utils import is_valid_tool as _core_is_valid_tool
from crewai.utilities.project_utils import is_valid_tool as _core_is_valid_tool
return _core_is_valid_tool(obj)
except ImportError:

View File

@@ -0,0 +1,7 @@
"""Authentication utilities for the CrewAI platform."""
from crewai.auth.oauth2 import AuthenticationCommand
from crewai.auth.token import AuthError, get_auth_token
__all__ = ["AuthError", "AuthenticationCommand", "get_auth_token"]

View File

@@ -0,0 +1,3 @@
"""Authentication constants."""
ALGORITHMS = ["RS256"]

View File

@@ -1,3 +1,5 @@
"""OAuth2 authentication for the CrewAI platform."""
import time
from typing import TYPE_CHECKING, Any, TypeVar, cast
import webbrowser
@@ -6,9 +8,9 @@ import httpx
from pydantic import BaseModel, Field
from rich.console import Console
from crewai.cli.authentication.utils import validate_jwt_token
from crewai.cli.config import Settings
from crewai.cli.shared.token_manager import TokenManager
from crewai.auth.token_manager import TokenManager
from crewai.auth.utils import validate_jwt_token
from crewai.settings import Settings
console = Console()
@@ -17,6 +19,8 @@ TOauth2Settings = TypeVar("TOauth2Settings", bound="Oauth2Settings")
class Oauth2Settings(BaseModel):
"""OAuth2 provider configuration."""
provider: str = Field(
description="OAuth2 provider used for authentication (e.g., workos, okta, auth0)."
)
@@ -38,7 +42,6 @@ class Oauth2Settings(BaseModel):
@classmethod
def from_settings(cls: type[TOauth2Settings]) -> TOauth2Settings:
"""Create an Oauth2Settings instance from the CLI settings."""
settings = Settings()
return cls(
@@ -51,23 +54,25 @@ class Oauth2Settings(BaseModel):
if TYPE_CHECKING:
from crewai.cli.authentication.providers.base_provider import BaseProvider
from crewai.auth.providers.base_provider import BaseProvider
class ProviderFactory:
"""Factory for creating OAuth2 providers from settings."""
@classmethod
def from_settings(
cls: type["ProviderFactory"], # noqa: UP037
settings: Oauth2Settings | None = None,
) -> "BaseProvider": # noqa: UP037
"""Create a provider instance from settings."""
settings = settings or Oauth2Settings.from_settings()
import importlib
module = importlib.import_module(
f"crewai.cli.authentication.providers.{settings.provider.lower()}"
f"crewai.auth.providers.{settings.provider.lower()}"
)
# Converts from snake_case to CamelCase to obtain the provider class name.
provider = getattr(
module,
f"{''.join(word.capitalize() for word in settings.provider.split('_'))}Provider",
@@ -77,6 +82,8 @@ class ProviderFactory:
class AuthenticationCommand:
"""Handles authentication with the CrewAI platform."""
def __init__(self) -> None:
self.token_manager = TokenManager()
self.oauth2_provider = ProviderFactory.from_settings()
@@ -92,7 +99,6 @@ class AuthenticationCommand:
def _get_device_code(self) -> dict[str, Any]:
"""Get the device code to authenticate the user."""
device_code_payload = {
"client_id": self.oauth2_provider.get_client_id(),
"scope": " ".join(self.oauth2_provider.get_oauth_scopes()),
@@ -108,7 +114,6 @@ class AuthenticationCommand:
def _display_auth_instructions(self, device_code_data: dict[str, str]) -> None:
"""Display the authentication instructions to the user."""
verification_uri = device_code_data.get(
"verification_uri_complete", device_code_data.get("verification_uri", "")
)
@@ -119,7 +124,6 @@ class AuthenticationCommand:
def _poll_for_token(self, device_code_data: dict[str, Any]) -> None:
"""Polls the server for the token until it is received, or max attempts are reached."""
token_payload = {
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
"device_code": device_code_data["device_code"],
@@ -143,7 +147,7 @@ class AuthenticationCommand:
style="bold green",
)
self._login_to_tool_repository()
self._post_login()
console.print("\n[bold green]Welcome to CrewAI AMP![/bold green]\n")
return
@@ -162,7 +166,6 @@ class AuthenticationCommand:
def _validate_and_save_token(self, token_data: dict[str, Any]) -> None:
"""Validates the JWT token and saves the token to the token manager."""
jwt_token = token_data["access_token"]
issuer = self.oauth2_provider.get_issuer()
jwt_token_data = {
@@ -177,39 +180,5 @@ class AuthenticationCommand:
expires_at = decoded_token.get("exp", 0)
self.token_manager.save_tokens(jwt_token, expires_at)
def _login_to_tool_repository(self) -> None:
"""Login to the tool repository."""
from crewai_cli.tools.main import ToolCommand
try:
console.print(
"Now logging you in to the Tool Repository... ",
style="bold blue",
end="",
)
ToolCommand().login()
console.print(
"Success!\n",
style="bold green",
)
settings = Settings()
console.print(
f"You are now authenticated to the tool repository for organization [bold cyan]'{settings.org_name if settings.org_name else settings.org_uuid}'[/bold cyan]",
style="green",
)
except Exception:
console.print(
"\n[bold yellow]Warning:[/bold yellow] Authentication with the Tool Repository failed.",
style="yellow",
)
console.print(
"Other features will work normally, but you may experience limitations "
"with downloading and publishing tools."
"\nRun [bold]crewai login[/bold] to try logging in again.\n",
style="yellow",
)
def _post_login(self) -> None:
"""Hook called after successful login. Override in subclasses for additional behavior."""

View File

@@ -0,0 +1 @@
"""OAuth2 authentication providers."""

View File

@@ -1,7 +1,11 @@
from crewai.cli.authentication.providers.base_provider import BaseProvider
"""Auth0 OAuth2 provider."""
from crewai.auth.providers.base_provider import BaseProvider
class Auth0Provider(BaseProvider):
"""Auth0 OAuth2 provider implementation."""
def get_authorize_url(self) -> str:
return f"https://{self._get_domain()}/oauth/device/code"

View File

@@ -1,9 +1,13 @@
"""Base OAuth2 provider interface."""
from abc import ABC, abstractmethod
from crewai.cli.authentication.main import Oauth2Settings
from crewai.auth.oauth2 import Oauth2Settings
class BaseProvider(ABC):
"""Abstract base class for OAuth2 providers."""
def __init__(self, settings: Oauth2Settings):
self.settings = settings
@@ -26,8 +30,9 @@ class BaseProvider(ABC):
def get_client_id(self) -> str: ...
def get_required_fields(self) -> list[str]:
"""Returns which provider-specific fields inside the "extra" dict will be required"""
"""Returns which provider-specific fields inside the "extra" dict will be required."""
return []
def get_oauth_scopes(self) -> list[str]:
"""Returns the OAuth scopes to request."""
return ["openid", "profile", "email"]

View File

@@ -1,9 +1,13 @@
"""Entra ID (Azure AD) OAuth2 provider."""
from typing import cast
from crewai.cli.authentication.providers.base_provider import BaseProvider
from crewai.auth.providers.base_provider import BaseProvider
class EntraIdProvider(BaseProvider):
"""Entra ID (Azure AD) OAuth2 provider implementation."""
def get_authorize_url(self) -> str:
return f"{self._base_url()}/oauth2/v2.0/devicecode"

View File

@@ -1,7 +1,11 @@
from crewai.cli.authentication.providers.base_provider import BaseProvider
"""Keycloak OAuth2 provider."""
from crewai.auth.providers.base_provider import BaseProvider
class KeycloakProvider(BaseProvider):
"""Keycloak OAuth2 provider implementation."""
def get_authorize_url(self) -> str:
return f"{self._oauth2_base_url()}/realms/{self.settings.extra.get('realm')}/protocol/openid-connect/auth/device"

View File

@@ -1,7 +1,11 @@
from crewai.cli.authentication.providers.base_provider import BaseProvider
"""Okta OAuth2 provider."""
from crewai.auth.providers.base_provider import BaseProvider
class OktaProvider(BaseProvider):
"""Okta OAuth2 provider implementation."""
def get_authorize_url(self) -> str:
return f"{self._oauth2_base_url()}/v1/device/authorize"

View File

@@ -1,7 +1,11 @@
from crewai.cli.authentication.providers.base_provider import BaseProvider
"""WorkOS OAuth2 provider."""
from crewai.auth.providers.base_provider import BaseProvider
class WorkosProvider(BaseProvider):
"""WorkOS OAuth2 provider implementation."""
def get_authorize_url(self) -> str:
return f"https://{self._get_domain()}/oauth2/device_authorization"

View File

@@ -1,8 +1,10 @@
from crewai.cli.shared.token_manager import TokenManager
"""Authentication token retrieval."""
from crewai.auth.token_manager import TokenManager
class AuthError(Exception):
pass
"""Raised when authentication fails."""
def get_auth_token() -> str:

View File

@@ -1,3 +1,5 @@
"""Manages encrypted token storage."""
from datetime import datetime
import json
import os

View File

@@ -1,3 +1,5 @@
"""JWT token validation utilities."""
from typing import Any
import jwt
@@ -7,18 +9,20 @@ from jwt import PyJWKClient
def validate_jwt_token(
jwt_token: str, jwks_url: str, issuer: str, audience: str
) -> Any:
"""
Verify the token's signature and claims using PyJWT.
:param jwt_token: The JWT (JWS) string to validate.
:param jwks_url: The URL of the JWKS endpoint.
:param issuer: The expected issuer of the token.
:param audience: The expected audience of the token.
:return: The decoded token.
:raises Exception: If the token is invalid for any reason (e.g., signature mismatch,
expired, incorrect issuer/audience, JWKS fetching error,
missing required claims).
"""
"""Verify the token's signature and claims using PyJWT.
Args:
jwt_token: The JWT (JWS) string to validate.
jwks_url: The URL of the JWKS endpoint.
issuer: The expected issuer of the token.
audience: The expected audience of the token.
Returns:
The decoded token.
Raises:
Exception: If the token is invalid for any reason.
"""
try:
jwk_client = PyJWKClient(jwks_url)
signing_key = jwk_client.get_signing_key_from_jwt(jwt_token)

View File

@@ -1,4 +0,0 @@
from crewai.cli.authentication.main import AuthenticationCommand
__all__ = ["AuthenticationCommand"]

View File

@@ -1 +0,0 @@
ALGORITHMS = ["RS256"]

View File

@@ -1,76 +0,0 @@
import json
import httpx
from rich.console import Console
from crewai.cli.authentication.token import get_auth_token
from crewai.cli.plus_api import PlusAPI
from crewai.telemetry.telemetry import Telemetry
console = Console()
class BaseCommand:
def __init__(self) -> None:
self._telemetry = Telemetry()
self._telemetry.set_tracer()
class PlusAPIMixin:
def __init__(self, telemetry: Telemetry) -> None:
try:
telemetry.set_tracer()
self.plus_api_client = PlusAPI(api_key=get_auth_token())
except Exception:
telemetry.deploy_signup_error_span()
console.print(
"Please sign up/login to CrewAI+ before using the CLI.",
style="bold red",
)
console.print("Run 'crewai login' to sign up/login.", style="bold green")
raise SystemExit from None
def _validate_response(self, response: httpx.Response) -> None:
"""
Handle and display error messages from API responses.
Args:
response (httpx.Response): The response from the Plus API
"""
try:
json_response = response.json()
except (json.JSONDecodeError, ValueError):
console.print(
"Failed to parse response from Enterprise API failed. Details:",
style="bold red",
)
console.print(f"Status Code: {response.status_code}")
console.print(
f"Response:\n{response.content.decode('utf-8', errors='replace')}"
)
raise SystemExit from None
if response.status_code == 422:
console.print(
"Failed to complete operation. Please fix the following errors:",
style="bold red",
)
for field, messages in json_response.items():
for message in messages:
console.print(
f"* [bold red]{field.capitalize()}[/bold red] {message}"
)
raise SystemExit
if not response.is_success:
console.print(
"Request to Enterprise API failed. Details:", style="bold red"
)
details = (
json_response.get("error")
or json_response.get("message")
or response.content.decode("utf-8", errors="replace")
)
console.print(f"{details}")
raise SystemExit

View File

@@ -1,89 +0,0 @@
from functools import lru_cache
import subprocess
class Repository:
def __init__(self, path: str = ".") -> None:
self.path = path
if not self.is_git_installed():
raise ValueError("Git is not installed or not found in your PATH.")
if not self.is_git_repo():
raise ValueError(f"{self.path} is not a Git repository.")
self.fetch()
@staticmethod
def is_git_installed() -> bool:
"""Check if Git is installed and available in the system."""
try:
subprocess.run(
["git", "--version"], # noqa: S607
capture_output=True,
check=True,
text=True,
)
return True
except (subprocess.CalledProcessError, FileNotFoundError):
return False
def fetch(self) -> None:
"""Fetch latest updates from the remote."""
subprocess.run(["git", "fetch"], cwd=self.path, check=True) # noqa: S607
def status(self) -> str:
"""Get the git status in porcelain format."""
return subprocess.check_output(
["git", "status", "--branch", "--porcelain"], # noqa: S607
cwd=self.path,
encoding="utf-8",
).strip()
@lru_cache(maxsize=None) # noqa: B019
def is_git_repo(self) -> bool:
"""Check if the current directory is a git repository.
Notes:
- TODO: This method is cached to avoid redundant checks, but using lru_cache on methods can lead to memory leaks
"""
try:
subprocess.check_output(
["git", "rev-parse", "--is-inside-work-tree"], # noqa: S607
cwd=self.path,
encoding="utf-8",
)
return True
except subprocess.CalledProcessError:
return False
def has_uncommitted_changes(self) -> bool:
"""Check if the repository has uncommitted changes."""
return len(self.status().splitlines()) > 1
def is_ahead_or_behind(self) -> bool:
"""Check if the repository is ahead or behind the remote."""
for line in self.status().splitlines():
if line.startswith("##") and ("ahead" in line or "behind" in line):
return True
return False
def is_synced(self) -> bool:
"""Return True if the Git repository is fully synced with the remote, False otherwise."""
if self.has_uncommitted_changes() or self.is_ahead_or_behind():
return False
return True
def origin_url(self) -> str | None:
"""Get the Git repository's remote URL."""
try:
result = subprocess.run(
["git", "remote", "get-url", "origin"], # noqa: S607
cwd=self.path,
capture_output=True,
text=True,
check=True,
)
return result.stdout.strip()
except subprocess.CalledProcessError:
return None

View File

@@ -1,231 +0,0 @@
from collections import defaultdict
from collections.abc import Sequence
import json
import os
from pathlib import Path
import time
from typing import Any
import certifi
import click
import httpx
from crewai.cli.constants import JSON_URL, MODELS, PROVIDERS
def select_choice(prompt_message: str, choices: Sequence[str]) -> str | None:
"""Presents a list of choices to the user and prompts them to select one.
Args:
prompt_message: The message to display to the user before presenting the choices.
choices: A list of options to present to the user.
Returns:
The selected choice from the list, or None if the user chooses to quit.
"""
provider_models = get_provider_data()
if not provider_models:
return None
click.secho(prompt_message, fg="cyan")
for idx, choice in enumerate(choices, start=1):
click.secho(f"{idx}. {choice}", fg="cyan")
click.secho("q. Quit", fg="cyan")
while True:
choice = click.prompt(
"Enter the number of your choice or 'q' to quit", type=str
)
if choice.lower() == "q":
return None
try:
selected_index = int(choice) - 1
if 0 <= selected_index < len(choices):
return choices[selected_index]
except ValueError:
pass
click.secho(
"Invalid selection. Please select a number between 1 and 6 or 'q' to quit.",
fg="red",
)
def select_provider(provider_models: dict[str, list[str]]) -> str | None | bool:
"""Presents a list of providers to the user and prompts them to select one.
Args:
provider_models: A dictionary of provider models.
Returns:
The selected provider, None if user explicitly quits, or False if no selection.
"""
predefined_providers = [p.lower() for p in PROVIDERS]
all_providers = sorted(set(predefined_providers + list(provider_models.keys())))
provider = select_choice(
"Select a provider to set up:", [*predefined_providers, "other"]
)
if provider is None: # User typed 'q'
return None
if provider == "other":
provider = select_choice("Select a provider from the full list:", all_providers)
if provider is None: # User typed 'q'
return None
return provider.lower() if provider else False
def select_model(provider: str, provider_models: dict[str, list[str]]) -> str | None:
"""Presents a list of models for a given provider to the user and prompts them to select one.
Args:
provider: The provider for which to select a model.
provider_models: A dictionary of provider models.
Returns:
The selected model, or None if the operation is aborted or an invalid selection is made.
"""
predefined_providers = [p.lower() for p in PROVIDERS]
if provider in predefined_providers:
available_models = MODELS.get(provider, [])
else:
available_models = provider_models.get(provider, [])
if not available_models:
click.secho(f"No models available for provider '{provider}'.", fg="red")
return None
return select_choice(
f"Select a model to use for {provider.capitalize()}:", available_models
)
def load_provider_data(cache_file: Path, cache_expiry: int) -> dict[str, Any] | None:
"""Loads provider data from a cache file if it exists and is not expired.
If the cache is expired or corrupted, it fetches the data from the web.
Args:
cache_file: The path to the cache file.
cache_expiry: The cache expiry time in seconds.
Returns:
The loaded provider data or None if the operation fails.
"""
current_time = time.time()
if (
cache_file.exists()
and (current_time - cache_file.stat().st_mtime) < cache_expiry
):
data = read_cache_file(cache_file)
if data:
return data
click.secho(
"Cache is corrupted. Fetching provider data from the web...", fg="yellow"
)
else:
click.secho(
"Cache expired or not found. Fetching provider data from the web...",
fg="cyan",
)
return fetch_provider_data(cache_file)
def read_cache_file(cache_file: Path) -> dict[str, Any] | None:
"""Reads and returns the JSON content from a cache file.
Args:
cache_file: The path to the cache file.
Returns:
The JSON content of the cache file or None if the JSON is invalid.
"""
try:
with open(cache_file, "r") as f:
data: dict[str, Any] = json.load(f)
return data
except json.JSONDecodeError:
return None
def fetch_provider_data(cache_file: Path) -> dict[str, Any] | None:
"""Fetches provider data from a specified URL and caches it to a file.
Args:
cache_file: The path to the cache file.
Returns:
The fetched provider data or None if the operation fails.
"""
ssl_config = os.environ["SSL_CERT_FILE"] = certifi.where()
try:
with httpx.stream("GET", JSON_URL, timeout=60, verify=ssl_config) as response:
response.raise_for_status()
data = download_data(response)
with open(cache_file, "w") as f:
json.dump(data, f)
return data
except httpx.HTTPError as e:
click.secho(f"Error fetching provider data: {e}", fg="red")
except json.JSONDecodeError:
click.secho("Error parsing provider data. Invalid JSON format.", fg="red")
return None
def download_data(response: httpx.Response) -> dict[str, Any]:
"""Downloads data from a given HTTP response and returns the JSON content.
Args:
response: The HTTP response object.
Returns:
The JSON content of the response.
"""
total_size = int(response.headers.get("content-length", 0))
block_size = 8192
data_chunks: list[bytes] = []
bar: Any
with click.progressbar(
length=total_size, label="Downloading", show_pos=True
) as bar:
for chunk in response.iter_bytes(block_size):
if chunk:
data_chunks.append(chunk)
bar.update(len(chunk))
data_content = b"".join(data_chunks)
result: dict[str, Any] = json.loads(data_content.decode("utf-8"))
return result
def get_provider_data() -> dict[str, list[str]] | None:
"""Retrieves provider data from a cache file.
Filters out models based on provider criteria, and returns a dictionary of providers
mapped to their models.
Returns:
A dictionary of providers mapped to their models or None if the operation fails.
"""
cache_dir = Path.home() / ".crewai"
cache_dir.mkdir(exist_ok=True)
cache_file = cache_dir / "provider_cache.json"
cache_expiry = 24 * 3600
data = load_provider_data(cache_file, cache_expiry)
if not data:
return None
provider_models = defaultdict(list)
for model_name, properties in data.items():
provider = properties.get("litellm_provider", "").strip().lower()
if "http" in provider or provider == "other":
continue
if provider:
provider_models[provider].append(model_name)
return provider_models

View File

@@ -1,3 +1,5 @@
"""CrewAI constants."""
from typing import Any

View File

@@ -8,17 +8,17 @@ import uuid
from rich.console import Console
from rich.panel import Panel
from crewai.cli.authentication.token import AuthError, get_auth_token
from crewai.cli.config import Settings
from crewai.cli.constants import DEFAULT_CREWAI_ENTERPRISE_URL
from crewai.cli.plus_api import PlusAPI
from crewai.cli.version import get_crewai_version
from crewai.auth.token import AuthError, get_auth_token
from crewai.constants import DEFAULT_CREWAI_ENTERPRISE_URL
from crewai.events.listeners.tracing.types import TraceEvent
from crewai.events.listeners.tracing.utils import (
get_user_id,
is_tracing_enabled_in_context,
should_auto_collect_first_time_traces,
)
from crewai.plus_api import PlusAPI
from crewai.settings import Settings
from crewai.version import get_crewai_version
logger = getLogger(__name__)

View File

@@ -6,8 +6,7 @@ import uuid
from typing_extensions import Self
from crewai.cli.authentication.token import AuthError, get_auth_token
from crewai.cli.version import get_crewai_version
from crewai.auth.token import AuthError, get_auth_token
from crewai.events.base_event_listener import BaseEventListener
from crewai.events.base_events import BaseEvent
from crewai.events.event_bus import CrewAIEventsBus
@@ -110,6 +109,7 @@ from crewai.events.types.tool_usage_events import (
ToolUsageStartedEvent,
)
from crewai.events.utils.console_formatter import ConsoleFormatter
from crewai.version import get_crewai_version
class TraceCollectionListener(BaseEventListener):

View File

@@ -8,7 +8,7 @@ from rich.live import Live
from rich.panel import Panel
from rich.text import Text
from crewai.cli.version import is_current_version_yanked, is_newer_version_available
from crewai.version import is_current_version_yanked, is_newer_version_available
_disable_version_check: ContextVar[bool] = ContextVar(

View File

@@ -195,7 +195,7 @@ class MCPToolResolver:
get_platform_integration_token,
)
from crewai.cli.plus_api import PlusAPI
from crewai.plus_api import PlusAPI
plus_api = PlusAPI(api_key=get_platform_integration_token())
response = plus_api.get_mcp_configs(slugs)
@@ -285,6 +285,7 @@ class MCPToolResolver:
independent transport so that parallel tool executions never share
state.
"""
transport: StdioTransport | HTTPTransport | SSETransport
if isinstance(mcp_config, MCPServerStdio):
transport = StdioTransport(
command=mcp_config.command,

View File

@@ -1,18 +1,18 @@
"""CrewAI+ API client."""
import os
from typing import Any
from urllib.parse import urljoin
import httpx
from crewai.cli.config import Settings
from crewai.cli.constants import DEFAULT_CREWAI_ENTERPRISE_URL
from crewai.cli.version import get_crewai_version
from crewai.constants import DEFAULT_CREWAI_ENTERPRISE_URL
from crewai.settings import Settings
from crewai.version import get_crewai_version
class PlusAPI:
"""
This class exposes methods for working with the CrewAI+ API.
"""
"""Client for working with the CrewAI+ API."""
TOOLS_RESOURCE = "/crewai_plus/api/v1/tools"
ORGANIZATIONS_RESOURCE = "/crewai_plus/api/v1/me/organizations"

View File

@@ -1,3 +1,5 @@
"""CrewAI platform settings management."""
import json
from logging import getLogger
from pathlib import Path
@@ -6,14 +8,14 @@ from typing import Any
from pydantic import BaseModel, Field
from crewai.cli.constants import (
from crewai.auth.token_manager import TokenManager
from crewai.constants import (
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER,
DEFAULT_CREWAI_ENTERPRISE_URL,
)
from crewai.cli.shared.token_manager import TokenManager
logger = getLogger(__name__)
@@ -22,8 +24,7 @@ DEFAULT_CONFIG_PATH = Path.home() / ".config" / "crewai" / "settings.json"
def get_writable_config_path() -> Path | None:
"""
Find a writable location for the config file with fallback options.
"""Find a writable location for the config file with fallback options.
Tries in order:
1. Default: ~/.config/crewai/settings.json
@@ -32,12 +33,12 @@ def get_writable_config_path() -> Path | None:
4. In-memory only (returns None)
Returns:
Path object for writable config location, or None if no writable location found
Path object for writable config location, or None if no writable location found.
"""
fallback_paths = [
DEFAULT_CONFIG_PATH, # Default location
Path(tempfile.gettempdir()) / "crewai_settings.json", # Temporary directory
Path.cwd() / "crewai_settings.json", # Current working directory
DEFAULT_CONFIG_PATH,
Path(tempfile.gettempdir()) / "crewai_settings.json",
Path.cwd() / "crewai_settings.json",
]
for config_path in fallback_paths:
@@ -46,7 +47,7 @@ def get_writable_config_path() -> Path | None:
test_file = config_path.parent / ".crewai_write_test"
try:
test_file.write_text("test")
test_file.unlink() # Clean up test file
test_file.unlink()
logger.info(f"Using config path: {config_path}")
return config_path
except Exception: # noqa: S112
@@ -58,7 +59,6 @@ def get_writable_config_path() -> Path | None:
return None
# Settings that are related to the user's account
USER_SETTINGS_KEYS = [
"tool_repository_username",
"tool_repository_password",
@@ -66,7 +66,6 @@ USER_SETTINGS_KEYS = [
"org_uuid",
]
# Settings that are related to the CLI
CLI_SETTINGS_KEYS = [
"enterprise_base_url",
"oauth2_provider",
@@ -76,7 +75,6 @@ CLI_SETTINGS_KEYS = [
"oauth2_extra",
]
# Default values for CLI settings
DEFAULT_CLI_SETTINGS = {
"enterprise_base_url": DEFAULT_CREWAI_ENTERPRISE_URL,
"oauth2_provider": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER,
@@ -86,13 +84,11 @@ DEFAULT_CLI_SETTINGS = {
"oauth2_extra": {},
}
# Readonly settings - cannot be set by the user
READONLY_SETTINGS_KEYS = [
"org_name",
"org_uuid",
]
# Hidden settings - not displayed by the 'list' command and cannot be set by the user
HIDDEN_SETTINGS_KEYS = [
"config_path",
"tool_repository_username",
@@ -101,8 +97,10 @@ HIDDEN_SETTINGS_KEYS = [
class Settings(BaseModel):
"""CrewAI platform settings."""
enterprise_base_url: str | None = Field(
default=DEFAULT_CLI_SETTINGS["enterprise_base_url"],
default=DEFAULT_CREWAI_ENTERPRISE_URL,
description="Base URL of the CrewAI AMP instance",
)
tool_repository_username: str | None = Field(
@@ -121,22 +119,22 @@ class Settings(BaseModel):
oauth2_provider: str = Field(
description="OAuth2 provider used for authentication (e.g., workos, okta, auth0).",
default=DEFAULT_CLI_SETTINGS["oauth2_provider"],
default=CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER,
)
oauth2_audience: str | None = Field(
description="OAuth2 audience value, typically used to identify the target API or resource.",
default=DEFAULT_CLI_SETTINGS["oauth2_audience"],
default=CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
)
oauth2_client_id: str = Field(
default=DEFAULT_CLI_SETTINGS["oauth2_client_id"],
default=CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
description="OAuth2 client ID issued by the provider, used during authentication requests.",
)
oauth2_domain: str = Field(
description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens.",
default=DEFAULT_CLI_SETTINGS["oauth2_domain"],
default=CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
)
oauth2_extra: dict[str, Any] = Field(
@@ -145,14 +143,12 @@ class Settings(BaseModel):
)
def __init__(self, config_path: Path | None = None, **data: dict[str, Any]) -> None:
"""Load Settings from config path with fallback support"""
"""Load Settings from config path with fallback support."""
if config_path is None:
config_path = get_writable_config_path()
# If config_path is None, we're in memory-only mode
if config_path is None:
merged_data = {**data}
# Dummy path for memory-only mode
super().__init__(config_path=Path("/dev/null"), **merged_data)
return
@@ -160,7 +156,6 @@ class Settings(BaseModel):
config_path.parent.mkdir(parents=True, exist_ok=True)
except Exception:
merged_data = {**data}
# Dummy path for memory-only mode
super().__init__(config_path=Path("/dev/null"), **merged_data)
return
@@ -176,19 +171,19 @@ class Settings(BaseModel):
super().__init__(config_path=config_path, **merged_data)
def clear_user_settings(self) -> None:
"""Clear all user settings"""
"""Clear all user settings."""
self._reset_user_settings()
self.dump()
def reset(self) -> None:
"""Reset all settings to default values"""
"""Reset all settings to default values."""
self._reset_user_settings()
self._reset_cli_settings()
self._clear_auth_tokens()
self.dump()
def dump(self) -> None:
"""Save current settings to settings.json"""
"""Save current settings to settings.json."""
if str(self.config_path) == "/dev/null":
return
@@ -207,15 +202,15 @@ class Settings(BaseModel):
pass
def _reset_user_settings(self) -> None:
"""Reset all user settings to default values"""
"""Reset all user settings to default values."""
for key in USER_SETTINGS_KEYS:
setattr(self, key, None)
def _reset_cli_settings(self) -> None:
"""Reset all CLI settings to default values"""
"""Reset all CLI settings to default values."""
for key in CLI_SETTINGS_KEYS:
setattr(self, key, DEFAULT_CLI_SETTINGS.get(key))
def _clear_auth_tokens(self) -> None:
"""Clear all authentication tokens"""
"""Clear all authentication tokens."""
TokenManager().clear_tokens()

View File

@@ -19,8 +19,8 @@ from crewai.agents.parser import (
OutputParserError,
parse,
)
from crewai.cli.config import Settings
from crewai.llms.base_llm import BaseLLM
from crewai.settings import Settings
from crewai.tools import BaseTool as CrewAITool
from crewai.tools.base_tool import BaseTool
from crewai.tools.structured_tool import CrewStructuredTool
@@ -1047,8 +1047,8 @@ def load_agent_from_repository(from_repository: str) -> dict[str, Any]:
if callable(_create_plus_client_hook):
client = _create_plus_client_hook()
else:
from crewai.cli.authentication.token import get_auth_token
from crewai.cli.plus_api import PlusAPI
from crewai.auth.token import get_auth_token
from crewai.plus_api import PlusAPI
client = PlusAPI(api_key=get_auth_token())
_print_current_organization()

View File

@@ -1,3 +1,5 @@
"""Interactive chat interface for CrewAI crews."""
import contextvars
import json
from pathlib import Path
@@ -12,15 +14,15 @@ import click
from packaging import version
import tomli
from crewai.cli.utils import read_toml
from crewai.cli.version import get_crewai_version
from crewai.crew import Crew
from crewai.llm import LLM
from crewai.llms.base_llm import BaseLLM
from crewai.types.crew_chat import ChatInputField, ChatInputs
from crewai.utilities.llm_utils import create_llm
from crewai.utilities.printer import Printer
from crewai.utilities.project_utils import read_toml
from crewai.utilities.types import LLMMessage
from crewai.version import get_crewai_version
_printer = Printer()
@@ -31,15 +33,14 @@ MIN_REQUIRED_VERSION: Final[Literal["0.98.0"]] = "0.98.0"
def check_conversational_crews_version(
crewai_version: str, pyproject_data: dict[str, Any]
) -> bool:
"""
Check if the installed crewAI version supports conversational crews.
"""Check if the installed crewAI version supports conversational crews.
Args:
crewai_version: The current version of crewAI.
pyproject_data: Dictionary containing pyproject.toml data.
Returns:
bool: True if version check passes, False otherwise.
True if version check passes, False otherwise.
"""
try:
if version.parse(crewai_version) < version.parse(MIN_REQUIRED_VERSION):
@@ -56,8 +57,8 @@ def check_conversational_crews_version(
def run_chat() -> None:
"""
Runs an interactive chat loop using the Crew's chat LLM with function calling.
"""Run an interactive chat loop using the Crew's chat LLM with function calling.
Incorporates crew_name, crew_description, and input fields to build a tool schema.
Exits if crew_name or crew_description are missing.
"""
@@ -72,14 +73,12 @@ def run_chat() -> None:
if not chat_llm:
return
# Indicate that the crew is being analyzed
click.secho(
"\nAnalyzing crew and required inputs - this may take 3 to 30 seconds "
"depending on the complexity of your crew.",
fg="white",
)
# Start loading indicator
loading_complete = threading.Event()
ctx = contextvars.copy_context()
loading_thread = threading.Thread(
@@ -92,16 +91,13 @@ def run_chat() -> None:
crew_tool_schema = generate_crew_tool_schema(crew_chat_inputs)
system_message = build_system_message(crew_chat_inputs)
# Call the LLM to generate the introductory message
introductory_message = chat_llm.call(
messages=[{"role": "system", "content": system_message}]
)
finally:
# Stop loading indicator
loading_complete.set()
loading_thread.join()
# Indicate that the analysis is complete
click.secho("\nFinished analyzing crew.\n", fg="white")
click.secho(f"Assistant: {introductory_message}\n", fg="green")
@@ -127,7 +123,7 @@ def show_loading(event: threading.Event) -> None:
def initialize_chat_llm(crew: Crew) -> LLM | BaseLLM | None:
"""Initializes the chat LLM and handles exceptions."""
"""Initialize the chat LLM and handle exceptions."""
try:
return create_llm(crew.chat_llm)
except Exception as e:
@@ -139,7 +135,7 @@ def initialize_chat_llm(crew: Crew) -> LLM | BaseLLM | None:
def build_system_message(crew_chat_inputs: ChatInputs) -> str:
"""Builds the initial system message for the chat."""
"""Build the initial system message for the chat."""
required_fields_str = (
", ".join(
f"{field.name} (desc: {field.description or 'n/a'})"
@@ -168,7 +164,7 @@ def build_system_message(crew_chat_inputs: ChatInputs) -> str:
def create_tool_function(crew: Crew, messages: list[LLMMessage]) -> Any:
"""Creates a wrapper function for running the crew tool with messages."""
"""Create a wrapper function for running the crew tool with messages."""
def run_crew_tool_with_messages(**kwargs: Any) -> str:
return run_crew_tool(crew, messages, **kwargs)
@@ -179,13 +175,11 @@ def create_tool_function(crew: Crew, messages: list[LLMMessage]) -> Any:
def flush_input() -> None:
"""Flush any pending input from the user."""
if platform.system() == "Windows":
# Windows platform
import msvcrt
while msvcrt.kbhit(): # type: ignore[attr-defined]
msvcrt.getch() # type: ignore[attr-defined]
else:
# Unix-like platforms (Linux, macOS)
import termios
termios.tcflush(sys.stdin, termios.TCIFLUSH)
@@ -200,7 +194,6 @@ def chat_loop(
"""Main chat loop for interacting with the user."""
while True:
try:
# Flush any pending input before accepting new input
flush_input()
user_input = get_user_input()
@@ -250,11 +243,9 @@ def handle_user_input(
messages.append({"role": "user", "content": user_input})
# Indicate that assistant is processing
click.echo()
click.secho("Assistant is processing your input. Please wait...", fg="green")
# Process assistant's response
final_response = chat_llm.call(
messages=messages,
tools=[crew_tool_schema],
@@ -266,12 +257,11 @@ def handle_user_input(
def generate_crew_tool_schema(crew_inputs: ChatInputs) -> dict[str, Any]:
"""
Dynamically build a Littellm 'function' schema for the given crew.
"""Dynamically build a Littellm 'function' schema for the given crew.
crew_name: The name of the crew (used for the function 'name').
crew_inputs: A ChatInputs object containing crew_description
and a list of input fields (each with a name & description).
Args:
crew_inputs: A ChatInputs object containing crew_description
and a list of input fields (each with a name & description).
"""
properties = {}
for field in crew_inputs.inputs:
@@ -297,70 +287,51 @@ def generate_crew_tool_schema(crew_inputs: ChatInputs) -> dict[str, Any]:
def run_crew_tool(crew: Crew, messages: list[LLMMessage], **kwargs: Any) -> str:
"""
Runs the crew using crew.kickoff(inputs=kwargs) and returns the output.
"""Run the crew using crew.kickoff(inputs=kwargs) and return the output.
Args:
crew (Crew): The crew instance to run.
messages (List[Dict[str, str]]): The chat messages up to this point.
crew: The crew instance to run.
messages: The chat messages up to this point.
**kwargs: The inputs collected from the user.
Returns:
str: The output from the crew's execution.
Raises:
SystemExit: Exits the chat if an error occurs during crew execution.
The output from the crew's execution.
"""
try:
# Serialize 'messages' to JSON string before adding to kwargs
kwargs["crew_chat_messages"] = json.dumps(messages)
# Run the crew with the provided inputs
crew_output = crew.kickoff(inputs=kwargs)
# Convert CrewOutput to a string to send back to the user
return str(crew_output)
except Exception as e:
# Exit the chat and show the error message
click.secho("An error occurred while running the crew:", fg="red")
click.secho(str(e), fg="red")
sys.exit(1)
def load_crew_and_name() -> tuple[Crew, str]:
"""
Loads the crew by importing the crew class from the user's project.
"""Load the crew by importing the crew class from the user's project.
Returns:
Tuple[Crew, str]: A tuple containing the Crew instance and the name of the crew.
A tuple containing the Crew instance and the name of the crew.
"""
# Get the current working directory
cwd = Path.cwd()
# Path to the pyproject.toml file
pyproject_path = cwd / "pyproject.toml"
if not pyproject_path.exists():
raise FileNotFoundError("pyproject.toml not found in the current directory.")
# Load the pyproject.toml file using 'tomli'
with pyproject_path.open("rb") as f:
pyproject_data = tomli.load(f)
# Get the project name from the 'project' section
project_name = pyproject_data["project"]["name"]
folder_name = project_name
# Derive the crew class name from the project name
# E.g., if project_name is 'my_project', crew_class_name is 'MyProject'
crew_class_name = project_name.replace("_", " ").title().replace(" ", "")
# Add the 'src' directory to sys.path
src_path = cwd / "src"
if str(src_path) not in sys.path:
sys.path.insert(0, str(src_path))
# Import the crew module
crew_module_name = f"{folder_name}.crew"
try:
crew_module = __import__(crew_module_name, fromlist=[crew_class_name])
@@ -369,7 +340,6 @@ def load_crew_and_name() -> tuple[Crew, str]:
f"Failed to import crew module {crew_module_name}: {e}"
) from e
# Get the crew class from the module
try:
crew_class = getattr(crew_module, crew_class_name)
except AttributeError as e:
@@ -377,7 +347,6 @@ def load_crew_and_name() -> tuple[Crew, str]:
f"Crew class {crew_class_name} not found in module {crew_module_name}"
) from e
# Instantiate the crew
crew_instance = crew_class().crew()
return crew_instance, crew_class_name
@@ -385,27 +354,23 @@ def load_crew_and_name() -> tuple[Crew, str]:
def generate_crew_chat_inputs(
crew: Crew, crew_name: str, chat_llm: LLM | BaseLLM
) -> ChatInputs:
"""
Generates the ChatInputs required for the crew by analyzing the tasks and agents.
"""Generate the ChatInputs required for the crew by analyzing the tasks and agents.
Args:
crew (Crew): The crew object containing tasks and agents.
crew_name (str): The name of the crew.
crew: The crew object containing tasks and agents.
crew_name: The name of the crew.
chat_llm: The chat language model to use for AI calls.
Returns:
ChatInputs: An object containing the crew's name, description, and input fields.
An object containing the crew's name, description, and input fields.
"""
# Extract placeholders from tasks and agents
required_inputs = fetch_required_inputs(crew)
# Generate descriptions for each input using AI
input_fields = []
for input_name in required_inputs:
description = generate_input_description_with_ai(input_name, crew, chat_llm)
input_fields.append(ChatInputField(name=input_name, description=description))
# Generate crew description using AI
crew_description = generate_crew_description_with_ai(crew, chat_llm)
return ChatInputs(
@@ -414,13 +379,13 @@ def generate_crew_chat_inputs(
def fetch_required_inputs(crew: Crew) -> set[str]:
"""Extracts placeholders from the crew's tasks and agents.
"""Extract placeholders from the crew's tasks and agents.
Args:
crew (Crew): The crew object.
crew: The crew object.
Returns:
Set[str]: A set of placeholder names.
A set of placeholder names.
"""
return crew.fetch_inputs()
@@ -428,18 +393,16 @@ def fetch_required_inputs(crew: Crew) -> set[str]:
def generate_input_description_with_ai(
input_name: str, crew: Crew, chat_llm: LLM | BaseLLM
) -> str:
"""
Generates an input description using AI based on the context of the crew.
"""Generate an input description using AI based on the context of the crew.
Args:
input_name (str): The name of the input placeholder.
crew (Crew): The crew object.
input_name: The name of the input placeholder.
crew: The crew object.
chat_llm: The chat language model to use for AI calls.
Returns:
str: A concise description of the input.
A concise description of the input.
"""
# Gather context from tasks and agents where the input is used
context_texts = []
placeholder_pattern = re.compile(r"\{(.+?)}")
@@ -448,7 +411,6 @@ def generate_input_description_with_ai(
f"{{{input_name}}}" in task.description
or f"{{{input_name}}}" in task.expected_output
):
# Replace placeholders with input names
task_description = placeholder_pattern.sub(
lambda m: m.group(1), task.description or ""
)
@@ -463,7 +425,6 @@ def generate_input_description_with_ai(
or f"{{{input_name}}}" in agent.goal
or f"{{{input_name}}}" in agent.backstory
):
# Replace placeholders with input names
agent_role = placeholder_pattern.sub(lambda m: m.group(1), agent.role or "")
agent_goal = placeholder_pattern.sub(lambda m: m.group(1), agent.goal or "")
agent_backstory = placeholder_pattern.sub(
@@ -475,7 +436,6 @@ def generate_input_description_with_ai(
context = "\n".join(context_texts)
if not context:
# If no context is found for the input, raise an exception as per instruction
raise ValueError(f"No context found for input '{input_name}'.")
prompt = (
@@ -489,22 +449,19 @@ def generate_input_description_with_ai(
def generate_crew_description_with_ai(crew: Crew, chat_llm: LLM | BaseLLM) -> str:
"""
Generates a brief description of the crew using AI.
"""Generate a brief description of the crew using AI.
Args:
crew (Crew): The crew object.
crew: The crew object.
chat_llm: The chat language model to use for AI calls.
Returns:
str: A concise description of the crew's purpose (15 words or less).
A concise description of the crew's purpose (15 words or less).
"""
# Gather context from tasks and agents
context_texts = []
placeholder_pattern = re.compile(r"\{(.+?)}")
for task in crew.tasks:
# Replace placeholders with input names
task_description = placeholder_pattern.sub(
lambda m: m.group(1), task.description or ""
)
@@ -514,7 +471,6 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm: LLM | BaseLLM) -> st
context_texts.append(f"Task Description: {task_description}")
context_texts.append(f"Expected Output: {expected_output}")
for agent in crew.agents:
# Replace placeholders with input names
agent_role = placeholder_pattern.sub(lambda m: m.group(1), agent.role or "")
agent_goal = placeholder_pattern.sub(lambda m: m.group(1), agent.goal or "")
agent_backstory = placeholder_pattern.sub(

View File

@@ -2,7 +2,7 @@ import logging
import os
from typing import Any, Final
from crewai.cli.constants import DEFAULT_LLM_MODEL, ENV_VARS, LITELLM_PARAMS
from crewai.constants import DEFAULT_LLM_MODEL, ENV_VARS, LITELLM_PARAMS
from crewai.llm import LLM
from crewai.llms.base_llm import BaseLLM

View File

@@ -1,20 +1,19 @@
"""Project utility functions for discovering crews, flows, and tools."""
from functools import reduce
import importlib.util
from inspect import getmro, isclass, isfunction, ismethod
import os
from pathlib import Path
import shutil
import sys
from typing import Any, cast, get_type_hints
import click
from rich.console import Console
import tomli
from crewai.cli.config import Settings
from crewai.cli.constants import ENV_VARS
from crewai.crew import Crew
from crewai.flow import Flow
from crewai.settings import Settings
if sys.version_info >= (3, 11):
@@ -23,25 +22,6 @@ if sys.version_info >= (3, 11):
console = Console()
def copy_template(
src: Path, dst: Path, name: str, class_name: str, folder_name: str
) -> None:
"""Copy a file from src to dst."""
with open(src, "r") as file:
content = file.read()
# Interpolate the content
content = content.replace("{{name}}", name)
content = content.replace("{{crew_name}}", class_name)
content = content.replace("{{folder_name}}", folder_name)
# Write the interpolated content to the new file
with open(dst, "w") as file:
file.write(content)
click.secho(f" - Created {dst}", fg="green")
def read_toml(file_path: str = "pyproject.toml") -> dict[str, Any]:
"""Read the content of a TOML file and return it as a dictionary."""
with open(file_path, "rb") as f:
@@ -49,6 +29,7 @@ def read_toml(file_path: str = "pyproject.toml") -> dict[str, Any]:
def parse_toml(content: str) -> dict[str, Any]:
"""Parse a TOML string and return it as a dictionary."""
if sys.version_info >= (3, 11):
return tomllib.loads(content)
return tomli.loads(content)
@@ -104,7 +85,6 @@ def _get_project_attribute(
style="bold red",
)
except Exception as e:
# Handle TOML decode errors for Python 3.11+
if sys.version_info >= (3, 11) and isinstance(e, tomllib.TOMLDecodeError):
console.print(
f"Error: {pyproject_path} is not a valid TOML file.", style="bold red"
@@ -128,164 +108,18 @@ def _get_nested_value(data: dict[str, Any], keys: list[str]) -> Any:
return reduce(dict.__getitem__, keys, data)
def fetch_and_json_env_file(env_file_path: str = ".env") -> dict[str, Any]:
"""Fetch the environment variables from a .env file and return them as a dictionary."""
try:
# Read the .env file
with open(env_file_path, "r") as f:
env_content = f.read()
# Parse the .env file content to a dictionary
env_dict = {}
for line in env_content.splitlines():
if line.strip() and not line.strip().startswith("#"):
key, value = line.split("=", 1)
env_dict[key.strip()] = value.strip()
return env_dict
except FileNotFoundError:
console.print(f"Error: {env_file_path} not found.", style="bold red")
except Exception as e:
console.print(f"Error reading the .env file: {e}", style="bold red")
return {}
def tree_copy(source: Path, destination: Path) -> None:
"""Copies the entire directory structure from the source to the destination."""
for item in os.listdir(source):
source_item = os.path.join(source, item)
destination_item = os.path.join(destination, item)
if os.path.isdir(source_item):
shutil.copytree(source_item, destination_item)
else:
shutil.copy2(source_item, destination_item)
def tree_find_and_replace(directory: Path, find: str, replace: str) -> None:
"""Recursively searches through a directory, replacing a target string in
both file contents and filenames with a specified replacement string.
"""
for path, dirs, files in os.walk(os.path.abspath(directory), topdown=False):
for filename in files:
filepath = os.path.join(path, filename)
with open(filepath, "r", encoding="utf-8", errors="ignore") as file:
contents = file.read()
with open(filepath, "w") as file:
file.write(contents.replace(find, replace))
if find in filename:
new_filename = filename.replace(find, replace)
new_filepath = os.path.join(path, new_filename)
os.rename(filepath, new_filepath)
for dirname in dirs:
if find in dirname:
new_dirname = dirname.replace(find, replace)
new_dirpath = os.path.join(path, new_dirname)
old_dirpath = os.path.join(path, dirname)
os.rename(old_dirpath, new_dirpath)
def load_env_vars(folder_path: Path) -> dict[str, Any]:
"""
Loads environment variables from a .env file in the specified folder path.
Args:
- folder_path (Path): The path to the folder containing the .env file.
Returns:
- dict: A dictionary of environment variables.
"""
env_file_path = folder_path / ".env"
env_vars = {}
if env_file_path.exists():
with open(env_file_path, "r") as file:
for line in file:
key, _, value = line.strip().partition("=")
if key and value:
env_vars[key] = value
return env_vars
def update_env_vars(
env_vars: dict[str, Any], provider: str, model: str
) -> dict[str, Any] | None:
"""
Updates environment variables with the API key for the selected provider and model.
Args:
- env_vars (dict): Environment variables dictionary.
- provider (str): Selected provider.
- model (str): Selected model.
Returns:
- None
"""
provider_config = cast(
list[str],
ENV_VARS.get(
provider,
[
click.prompt(
f"Enter the environment variable name for your {provider.capitalize()} API key",
type=str,
)
],
),
)
api_key_var = provider_config[0]
if api_key_var not in env_vars:
try:
env_vars[api_key_var] = click.prompt(
f"Enter your {provider.capitalize()} API key", type=str, hide_input=True
)
except click.exceptions.Abort:
click.secho("Operation aborted by the user.", fg="red")
return None
else:
click.secho(f"API key already exists for {provider.capitalize()}.", fg="yellow")
env_vars["MODEL"] = model
click.secho(f"Selected model: {model}", fg="green")
return env_vars
def write_env_file(folder_path: Path, env_vars: dict[str, Any]) -> None:
"""
Writes environment variables to a .env file in the specified folder.
Args:
- folder_path (Path): The path to the folder where the .env file will be written.
- env_vars (dict): A dictionary of environment variables to write.
"""
env_file_path = folder_path / ".env"
with open(env_file_path, "w") as file:
for key, value in env_vars.items():
file.write(f"{key.upper()}={value}\n")
def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
"""Get the crew instances from a file."""
crew_instances = []
try:
import importlib.util
# Add the current directory to sys.path to ensure imports resolve correctly
current_dir = os.getcwd()
if current_dir not in sys.path:
sys.path.insert(0, current_dir)
# If we're not in src directory but there's a src directory, add it to path
src_dir = os.path.join(current_dir, "src")
if os.path.isdir(src_dir) and src_dir not in sys.path:
sys.path.insert(0, src_dir)
# Search in both current directory and src directory if it exists
search_paths = [".", "src"] if os.path.isdir("src") else ["."]
for search_path in search_paths:
@@ -316,7 +150,6 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
)
continue
# If we found crew instances, break out of the loop
if crew_instances:
break
@@ -334,7 +167,6 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
)
continue
# If we found crew instances in this search path, break out of the search paths loop
if crew_instances:
break
@@ -352,6 +184,7 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
def get_crew_instance(module_attr: Any) -> Crew | None:
"""Get a Crew instance from a module attribute."""
if (
callable(module_attr)
and hasattr(module_attr, "is_crew_class")
@@ -372,6 +205,7 @@ def get_crew_instance(module_attr: Any) -> Crew | None:
def fetch_crews(module_attr: Any) -> list[Crew]:
"""Fetch crew instances from a module attribute."""
crew_instances: list[Crew] = []
if crew_instance := get_crew_instance(module_attr):
@@ -386,7 +220,7 @@ def fetch_crews(module_attr: Any) -> list[Crew]:
return crew_instances
def get_flow_instance(module_attr: Any) -> Flow | None:
def get_flow_instance(module_attr: Any) -> Flow[Any] | None:
"""Check if a module attribute is a user-defined Flow subclass and return an instance.
Args:
@@ -413,13 +247,12 @@ _SKIP_DIRS = frozenset(
)
def get_flows(flow_path: str = "main.py") -> list[Flow]:
def get_flows(flow_path: str = "main.py") -> list[Flow[Any]]:
"""Get the flow instances from project files.
Walks the project directory looking for files matching ``flow_path``
(default ``main.py``), loads each module, and extracts Flow subclass
instances. Directories that are clearly not user source code (virtual
environments, ``.git``, etc.) are pruned to avoid noisy import errors.
instances.
Args:
flow_path: Filename to search for (default ``main.py``).
@@ -427,7 +260,7 @@ def get_flows(flow_path: str = "main.py") -> list[Flow]:
Returns:
A list of discovered Flow instances.
"""
flow_instances: list[Flow] = []
flow_instances: list[Flow[Any]] = []
try:
current_dir = os.getcwd()
if current_dir not in sys.path:
@@ -486,6 +319,7 @@ def get_flows(flow_path: str = "main.py") -> list[Flow]:
def is_valid_tool(obj: Any) -> bool:
"""Check if an object is a valid CrewAI tool."""
from crewai.tools.base_tool import Tool
if isclass(obj):
@@ -498,12 +332,12 @@ def is_valid_tool(obj: Any) -> bool:
def extract_available_exports(dir_path: str = "src") -> list[dict[str, Any]]:
"""
Extract available tool classes from the project's __init__.py files.
"""Extract available tool classes from the project's __init__.py files.
Only includes classes that inherit from BaseTool or functions decorated with @tool.
Returns:
list: A list of valid tool class names or ["BaseTool"] if none found
A list of valid tool class names or ["BaseTool"] if none found.
"""
try:
init_files = Path(dir_path).glob("**/__init__.py")
@@ -530,6 +364,7 @@ def extract_available_exports(dir_path: str = "src") -> list[dict[str, Any]]:
def build_env_with_tool_repository_credentials(
repository_handle: str,
) -> dict[str, Any]:
"""Build environment variables with tool repository credentials."""
repository_handle = repository_handle.upper().replace("-", "_")
settings = Settings()
@@ -545,9 +380,7 @@ def build_env_with_tool_repository_credentials(
def _load_tools_from_init(init_file: Path) -> list[dict[str, Any]]:
"""
Load and validate tools from a given __init__.py file.
"""
"""Load and validate tools from a given __init__.py file."""
spec = importlib.util.spec_from_file_location("temp_module", init_file)
if not spec or not spec.loader:
@@ -583,9 +416,7 @@ def _load_tools_from_init(init_file: Path) -> list[dict[str, Any]]:
def _print_no_tools_warning() -> None:
"""
Display warning and usage instructions if no tools were found.
"""
"""Display warning and usage instructions if no tools were found."""
console.print(
"\n[bold yellow]Warning: No valid tools were exposed in your __init__.py file![/bold yellow]"
)

View File

@@ -1,12 +1,15 @@
"""Memory reset utilities for CrewAI crews and flows."""
import subprocess
from typing import Any
import click
from crewai.cli.utils import get_crews, get_flows
from crewai.flow import Flow
from crewai.utilities.project_utils import get_crews, get_flows
def _reset_flow_memory(flow: Flow) -> None:
def _reset_flow_memory(flow: Flow[Any]) -> None:
"""Reset memory for a single flow instance.
Handles Memory, MemoryScope (both have .reset()), and MemorySlice

View File

@@ -1,4 +1,4 @@
"""Version utilities for CrewAI CLI."""
"""Version utilities for CrewAI."""
from collections.abc import Mapping
from datetime import datetime, timedelta
@@ -26,7 +26,7 @@ def _get_cache_file() -> Path:
def get_crewai_version() -> str:
"""Get the version number of CrewAI running the CLI."""
"""Get the version number of the installed CrewAI package."""
return importlib.metadata.version("crewai")

View File

@@ -6,7 +6,7 @@ from unittest import mock
from unittest.mock import MagicMock, patch
from crewai.agents.crew_agent_executor import AgentFinish, CrewAgentExecutor
from crewai.cli.constants import DEFAULT_LLM_MODEL
from crewai.constants import DEFAULT_LLM_MODEL
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.tool_usage_events import ToolUsageFinishedEvent
from crewai.knowledge.knowledge import Knowledge
@@ -2046,12 +2046,12 @@ def test_get_knowledge_search_query():
@pytest.fixture
def mock_get_auth_token():
with patch(
"crewai.cli.authentication.token.get_auth_token", return_value="test_token"
"crewai.auth.token.get_auth_token", return_value="test_token"
):
yield
@patch("crewai.cli.plus_api.PlusAPI.get_agent")
@patch("crewai.plus_api.PlusAPI.get_agent")
def test_agent_from_repository(mock_get_agent, mock_get_auth_token):
from crewai_tools import (
FileReadTool,
@@ -2092,7 +2092,7 @@ def test_agent_from_repository(mock_get_agent, mock_get_auth_token):
assert agent.tools[1].file_path == "test.txt"
@patch("crewai.cli.plus_api.PlusAPI.get_agent")
@patch("crewai.plus_api.PlusAPI.get_agent")
def test_agent_from_repository_override_attributes(mock_get_agent, mock_get_auth_token):
from crewai_tools import SerperDevTool
@@ -2116,7 +2116,7 @@ def test_agent_from_repository_override_attributes(mock_get_agent, mock_get_auth
assert isinstance(agent.tools[0], SerperDevTool)
@patch("crewai.cli.plus_api.PlusAPI.get_agent")
@patch("crewai.plus_api.PlusAPI.get_agent")
def test_agent_from_repository_with_invalid_tools(mock_get_agent, mock_get_auth_token):
mock_get_response = MagicMock()
mock_get_response.status_code = 200
@@ -2139,7 +2139,7 @@ def test_agent_from_repository_with_invalid_tools(mock_get_agent, mock_get_auth_
Agent(from_repository="test_agent")
@patch("crewai.cli.plus_api.PlusAPI.get_agent")
@patch("crewai.plus_api.PlusAPI.get_agent")
def test_agent_from_repository_internal_error(mock_get_agent, mock_get_auth_token):
mock_get_response = MagicMock()
mock_get_response.status_code = 500
@@ -2152,7 +2152,7 @@ def test_agent_from_repository_internal_error(mock_get_agent, mock_get_auth_toke
Agent(from_repository="test_agent")
@patch("crewai.cli.plus_api.PlusAPI.get_agent")
@patch("crewai.plus_api.PlusAPI.get_agent")
def test_agent_from_repository_agent_not_found(mock_get_agent, mock_get_auth_token):
mock_get_response = MagicMock()
mock_get_response.status_code = 404
@@ -2165,7 +2165,7 @@ def test_agent_from_repository_agent_not_found(mock_get_agent, mock_get_auth_tok
Agent(from_repository="test_agent")
@patch("crewai.cli.plus_api.PlusAPI.get_agent")
@patch("crewai.plus_api.PlusAPI.get_agent")
@patch("crewai.utilities.agent_utils.Settings")
@patch("crewai.utilities.agent_utils.console")
def test_agent_from_repository_displays_org_info(
@@ -2198,7 +2198,7 @@ def test_agent_from_repository_displays_org_info(
assert agent.backstory == "test backstory"
@patch("crewai.cli.plus_api.PlusAPI.get_agent")
@patch("crewai.plus_api.PlusAPI.get_agent")
@patch("crewai.utilities.agent_utils.Settings")
@patch("crewai.utilities.agent_utils.console")
def test_agent_from_repository_without_org_set(

View File

@@ -1,6 +1,6 @@
import pytest
from crewai.cli.authentication.main import Oauth2Settings
from crewai.cli.authentication.providers.auth0 import Auth0Provider
from crewai.auth.oauth2 import Oauth2Settings
from crewai.auth.providers.auth0 import Auth0Provider

View File

@@ -1,7 +1,7 @@
import pytest
from crewai.cli.authentication.main import Oauth2Settings
from crewai.cli.authentication.providers.entra_id import EntraIdProvider
from crewai.auth.oauth2 import Oauth2Settings
from crewai.auth.providers.entra_id import EntraIdProvider
class TestEntraIdProvider:

View File

@@ -1,7 +1,7 @@
import pytest
from crewai.cli.authentication.main import Oauth2Settings
from crewai.cli.authentication.providers.keycloak import KeycloakProvider
from crewai.auth.oauth2 import Oauth2Settings
from crewai.auth.providers.keycloak import KeycloakProvider
class TestKeycloakProvider:

View File

@@ -1,7 +1,7 @@
import pytest
from crewai.cli.authentication.main import Oauth2Settings
from crewai.cli.authentication.providers.okta import OktaProvider
from crewai.auth.oauth2 import Oauth2Settings
from crewai.auth.providers.okta import OktaProvider
class TestOktaProvider:

View File

@@ -1,6 +1,6 @@
import pytest
from crewai.cli.authentication.main import Oauth2Settings
from crewai.cli.authentication.providers.workos import WorkosProvider
from crewai.auth.oauth2 import Oauth2Settings
from crewai.auth.providers.workos import WorkosProvider
class TestWorkosProvider:

View File

@@ -3,8 +3,8 @@ from unittest.mock import MagicMock, call, patch
import pytest
import httpx
from crewai.cli.authentication.main import AuthenticationCommand
from crewai.cli.constants import (
from crewai.auth.oauth2 import AuthenticationCommand
from crewai.constants import (
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
@@ -14,7 +14,7 @@ from crewai.cli.constants import (
class TestAuthenticationCommand:
def setup_method(self):
# Mock Settings so we always use default constants regardless of local config.
with patch("crewai.cli.authentication.main.Settings") as mock_settings:
with patch("crewai.auth.oauth2.Settings") as mock_settings:
instance = mock_settings.return_value
instance.oauth2_provider = "workos"
instance.oauth2_domain = CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN
@@ -38,12 +38,12 @@ class TestAuthenticationCommand:
),
],
)
@patch("crewai.cli.authentication.main.AuthenticationCommand._get_device_code")
@patch("crewai.auth.oauth2.AuthenticationCommand._get_device_code")
@patch(
"crewai.cli.authentication.main.AuthenticationCommand._display_auth_instructions"
"crewai.auth.oauth2.AuthenticationCommand._display_auth_instructions"
)
@patch("crewai.cli.authentication.main.AuthenticationCommand._poll_for_token")
@patch("crewai.cli.authentication.main.console.print")
@patch("crewai.auth.oauth2.AuthenticationCommand._poll_for_token")
@patch("crewai.auth.oauth2.console.print")
def test_login(
self,
mock_console_print,
@@ -82,8 +82,8 @@ class TestAuthenticationCommand:
self.auth_command.oauth2_provider._get_domain() == expected_urls["domain"]
)
@patch("crewai.cli.authentication.main.webbrowser")
@patch("crewai.cli.authentication.main.console.print")
@patch("crewai.auth.oauth2.webbrowser")
@patch("crewai.auth.oauth2.console.print")
def test_display_auth_instructions(self, mock_console_print, mock_webbrowser):
device_code_data = {
"verification_uri_complete": "https://example.com/auth",
@@ -113,8 +113,8 @@ class TestAuthenticationCommand:
],
)
@pytest.mark.parametrize("has_expiration", [True, False])
@patch("crewai.cli.authentication.main.validate_jwt_token")
@patch("crewai.cli.authentication.main.TokenManager.save_tokens")
@patch("crewai.auth.oauth2.validate_jwt_token")
@patch("crewai.auth.oauth2.TokenManager.save_tokens")
def test_validate_and_save_token(
self,
mock_save_tokens,
@@ -123,8 +123,8 @@ class TestAuthenticationCommand:
jwt_config,
has_expiration,
):
from crewai.cli.authentication.main import Oauth2Settings
from crewai.cli.authentication.providers.workos import WorkosProvider
from crewai.auth.oauth2 import Oauth2Settings
from crewai.auth.providers.workos import WorkosProvider
if user_provider == "workos":
self.auth_command.oauth2_provider = WorkosProvider(
@@ -163,8 +163,8 @@ class TestAuthenticationCommand:
mock_save_tokens.assert_called_once_with("test_access_token", 0)
@patch("crewai_cli.tools.main.ToolCommand")
@patch("crewai.cli.authentication.main.Settings")
@patch("crewai.cli.authentication.main.console.print")
@patch("crewai.auth.oauth2.Settings")
@patch("crewai.auth.oauth2.console.print")
def test_login_to_tool_repository_success(
self, mock_console_print, mock_settings, mock_tool_command
):
@@ -196,7 +196,7 @@ class TestAuthenticationCommand:
mock_console_print.assert_has_calls(expected_calls)
@patch("crewai_cli.tools.main.ToolCommand")
@patch("crewai.cli.authentication.main.console.print")
@patch("crewai.auth.oauth2.console.print")
def test_login_to_tool_repository_error(
self, mock_console_print, mock_tool_command
):
@@ -226,7 +226,7 @@ class TestAuthenticationCommand:
]
mock_console_print.assert_has_calls(expected_calls)
@patch("crewai.cli.authentication.main.httpx.post")
@patch("crewai.auth.oauth2.httpx.post")
def test_get_device_code(self, mock_post):
mock_response = MagicMock()
mock_response.json.return_value = {
@@ -262,8 +262,8 @@ class TestAuthenticationCommand:
"verification_uri_complete": "https://example.com/auth",
}
@patch("crewai.cli.authentication.main.httpx.post")
@patch("crewai.cli.authentication.main.console.print")
@patch("crewai.auth.oauth2.httpx.post")
@patch("crewai.auth.oauth2.console.print")
def test_poll_for_token_success(self, mock_console_print, mock_post):
mock_response_success = MagicMock()
mock_response_success.status_code = 200
@@ -311,8 +311,8 @@ class TestAuthenticationCommand:
]
mock_console_print.assert_has_calls(expected_calls)
@patch("crewai.cli.authentication.main.httpx.post")
@patch("crewai.cli.authentication.main.console.print")
@patch("crewai.auth.oauth2.httpx.post")
@patch("crewai.auth.oauth2.console.print")
def test_poll_for_token_timeout(self, mock_console_print, mock_post):
mock_response_pending = MagicMock()
mock_response_pending.status_code = 400
@@ -330,7 +330,7 @@ class TestAuthenticationCommand:
"Timeout: Failed to get the token. Please try again.", style="bold red"
)
@patch("crewai.cli.authentication.main.httpx.post")
@patch("crewai.auth.oauth2.httpx.post")
def test_poll_for_token_error(self, mock_post):
"""Test the method to poll for token (error path)."""
# Setup mock to return error

View File

@@ -3,11 +3,11 @@ from unittest.mock import MagicMock, patch
import jwt
from crewai.cli.authentication.utils import validate_jwt_token
from crewai.auth.utils import validate_jwt_token
@patch("crewai.cli.authentication.utils.PyJWKClient", return_value=MagicMock())
@patch("crewai.cli.authentication.utils.jwt")
@patch("crewai.auth.utils.PyJWKClient", return_value=MagicMock())
@patch("crewai.auth.utils.jwt")
class TestUtils(unittest.TestCase):
def test_validate_jwt_token(self, mock_jwt, mock_pyjwkclient):
mock_jwt.decode.return_value = {"exp": 1719859200}

View File

@@ -27,9 +27,9 @@ def mock_crew():
@pytest.fixture
def mock_get_crews(mock_crew):
with mock.patch(
"crewai.cli.reset_memories_command.get_crews", return_value=[mock_crew]
"crewai.utilities.reset_memories.get_crews", return_value=[mock_crew]
) as mock_get_crew, mock.patch(
"crewai.cli.reset_memories_command.get_flows", return_value=[]
"crewai.utilities.reset_memories.get_flows", return_value=[]
):
yield mock_get_crew
@@ -169,9 +169,9 @@ def mock_flow():
@pytest.fixture
def mock_get_flows(mock_flow):
with mock.patch(
"crewai.cli.reset_memories_command.get_flows", return_value=[mock_flow]
"crewai.utilities.reset_memories.get_flows", return_value=[mock_flow]
) as mock_get_flow, mock.patch(
"crewai.cli.reset_memories_command.get_crews", return_value=[]
"crewai.utilities.reset_memories.get_crews", return_value=[]
):
yield mock_get_flow
@@ -196,9 +196,9 @@ def test_reset_flow_knowledge_no_effect(mock_get_flows, mock_flow, runner):
def test_reset_no_crew_or_flow_found(runner):
with mock.patch(
"crewai.cli.reset_memories_command.get_crews", return_value=[]
"crewai.utilities.reset_memories.get_crews", return_value=[]
), mock.patch(
"crewai.cli.reset_memories_command.get_flows", return_value=[]
"crewai.utilities.reset_memories.get_flows", return_value=[]
):
result = runner.invoke(reset_memories, ["-m"])
assert "No crew or flow found." in result.output
@@ -206,9 +206,9 @@ def test_reset_no_crew_or_flow_found(runner):
def test_reset_crew_and_flow_memory(mock_crew, mock_flow, runner):
with mock.patch(
"crewai.cli.reset_memories_command.get_crews", return_value=[mock_crew]
"crewai.utilities.reset_memories.get_crews", return_value=[mock_crew]
), mock.patch(
"crewai.cli.reset_memories_command.get_flows", return_value=[mock_flow]
"crewai.utilities.reset_memories.get_flows", return_value=[mock_flow]
):
result = runner.invoke(reset_memories, ["-m"])
mock_crew.reset_memories.assert_called_once_with(command_type="memory")
@@ -222,9 +222,9 @@ def test_reset_flow_memory_none(runner):
mock_flow.name = "NoMemFlow"
mock_flow.memory = None
with mock.patch(
"crewai.cli.reset_memories_command.get_crews", return_value=[]
"crewai.utilities.reset_memories.get_crews", return_value=[]
), mock.patch(
"crewai.cli.reset_memories_command.get_flows", return_value=[mock_flow]
"crewai.utilities.reset_memories.get_flows", return_value=[mock_flow]
):
result = runner.invoke(reset_memories, ["-m"])
assert "[Flow (NoMemFlow)] Memory has been reset." in result.output

View File

@@ -6,13 +6,13 @@ from datetime import datetime, timedelta
from pathlib import Path
from unittest.mock import MagicMock, patch
from crewai.cli.config import (
from crewai.settings import (
CLI_SETTINGS_KEYS,
DEFAULT_CLI_SETTINGS,
USER_SETTINGS_KEYS,
Settings,
)
from crewai.cli.shared.token_manager import TokenManager
from crewai.auth.token_manager import TokenManager
class TestSettings(unittest.TestCase):
@@ -69,7 +69,7 @@ class TestSettings(unittest.TestCase):
for key in user_settings.keys():
self.assertEqual(getattr(settings, key), None)
@patch("crewai.cli.config.TokenManager")
@patch("crewai.settings.TokenManager")
def test_reset_settings(self, mock_token_manager):
user_settings = {key: f"value_for_{key}" for key in USER_SETTINGS_KEYS}
cli_settings = {key: f"value_for_{key}" for key in CLI_SETTINGS_KEYS if key != "oauth2_extra"}

View File

@@ -1,4 +1,4 @@
from crewai.cli.constants import ENV_VARS, MODELS, PROVIDERS
from crewai.constants import ENV_VARS, MODELS, PROVIDERS
def test_huggingface_in_providers():

View File

@@ -1,101 +0,0 @@
import pytest
from crewai.cli.git import Repository
@pytest.fixture()
def repository(fp):
fp.register(["git", "--version"], stdout="git version 2.30.0\n")
fp.register(["git", "rev-parse", "--is-inside-work-tree"], stdout="true\n")
fp.register(["git", "fetch"], stdout="")
return Repository(path=".")
def test_init_with_invalid_git_repo(fp):
fp.register(["git", "--version"], stdout="git version 2.30.0\n")
fp.register(
["git", "rev-parse", "--is-inside-work-tree"],
returncode=1,
stderr="fatal: not a git repository\n",
)
with pytest.raises(ValueError):
Repository(path="invalid/path")
def test_is_git_not_installed(fp):
fp.register(["git", "--version"], returncode=1)
with pytest.raises(
ValueError, match="Git is not installed or not found in your PATH."
):
Repository(path=".")
def test_status(fp, repository):
fp.register(
["git", "status", "--branch", "--porcelain"],
stdout="## main...origin/main [ahead 1]\n",
)
assert repository.status() == "## main...origin/main [ahead 1]"
def test_has_uncommitted_changes(fp, repository):
fp.register(
["git", "status", "--branch", "--porcelain"],
stdout="## main...origin/main\n M somefile.txt\n",
)
assert repository.has_uncommitted_changes() is True
def test_is_ahead_or_behind(fp, repository):
fp.register(
["git", "status", "--branch", "--porcelain"],
stdout="## main...origin/main [ahead 1]\n",
)
assert repository.is_ahead_or_behind() is True
def test_is_synced_when_synced(fp, repository):
fp.register(
["git", "status", "--branch", "--porcelain"], stdout="## main...origin/main\n"
)
fp.register(
["git", "status", "--branch", "--porcelain"], stdout="## main...origin/main\n"
)
assert repository.is_synced() is True
def test_is_synced_with_uncommitted_changes(fp, repository):
fp.register(
["git", "status", "--branch", "--porcelain"],
stdout="## main...origin/main\n M somefile.txt\n",
)
assert repository.is_synced() is False
def test_is_synced_when_ahead_or_behind(fp, repository):
fp.register(
["git", "status", "--branch", "--porcelain"],
stdout="## main...origin/main [ahead 1]\n",
)
fp.register(
["git", "status", "--branch", "--porcelain"],
stdout="## main...origin/main [ahead 1]\n",
)
assert repository.is_synced() is False
def test_is_synced_with_uncommitted_changes_and_ahead(fp, repository):
fp.register(
["git", "status", "--branch", "--porcelain"],
stdout="## main...origin/main [ahead 1]\n M somefile.txt\n",
)
assert repository.is_synced() is False
def test_origin_url(fp, repository):
fp.register(
["git", "remote", "get-url", "origin"],
stdout="https://github.com/user/repo.git\n",
)
assert repository.origin_url() == "https://github.com/user/repo.git"

View File

@@ -4,7 +4,7 @@ from unittest.mock import ANY, AsyncMock, MagicMock, patch
import pytest
from crewai.cli.plus_api import PlusAPI
from crewai.plus_api import PlusAPI
class TestPlusAPI(unittest.TestCase):
@@ -20,7 +20,7 @@ class TestPlusAPI(unittest.TestCase):
self.assertTrue("CrewAI-CLI/" in self.api.headers["User-Agent"])
self.assertTrue(self.api.headers["X-Crewai-Version"])
@patch("crewai.cli.plus_api.PlusAPI._make_request")
@patch("crewai.plus_api.PlusAPI._make_request")
def test_login_to_tool_repository(self, mock_make_request):
mock_response = MagicMock()
mock_make_request.return_value = mock_response
@@ -32,7 +32,7 @@ class TestPlusAPI(unittest.TestCase):
)
self.assertEqual(response, mock_response)
@patch("crewai.cli.plus_api.PlusAPI._make_request")
@patch("crewai.plus_api.PlusAPI._make_request")
def test_login_to_tool_repository_with_user_identifier(self, mock_make_request):
mock_response = MagicMock()
mock_make_request.return_value = mock_response
@@ -60,8 +60,8 @@ class TestPlusAPI(unittest.TestCase):
**kwargs,
)
@patch("crewai.cli.plus_api.Settings")
@patch("crewai.cli.plus_api.httpx.Client")
@patch("crewai.plus_api.Settings")
@patch("crewai.plus_api.httpx.Client")
def test_login_to_tool_repository_with_org_uuid(
self, mock_client_class, mock_settings_class
):
@@ -83,7 +83,7 @@ class TestPlusAPI(unittest.TestCase):
)
self.assertEqual(response, mock_response)
@patch("crewai.cli.plus_api.PlusAPI._make_request")
@patch("crewai.plus_api.PlusAPI._make_request")
def test_get_tool(self, mock_make_request):
mock_response = MagicMock()
mock_make_request.return_value = mock_response
@@ -94,8 +94,8 @@ class TestPlusAPI(unittest.TestCase):
)
self.assertEqual(response, mock_response)
@patch("crewai.cli.plus_api.Settings")
@patch("crewai.cli.plus_api.httpx.Client")
@patch("crewai.plus_api.Settings")
@patch("crewai.plus_api.httpx.Client")
def test_get_tool_with_org_uuid(self, mock_client_class, mock_settings_class):
mock_settings = MagicMock()
mock_settings.org_uuid = self.org_uuid
@@ -115,7 +115,7 @@ class TestPlusAPI(unittest.TestCase):
)
self.assertEqual(response, mock_response)
@patch("crewai.cli.plus_api.PlusAPI._make_request")
@patch("crewai.plus_api.PlusAPI._make_request")
def test_publish_tool(self, mock_make_request):
mock_response = MagicMock()
mock_make_request.return_value = mock_response
@@ -142,8 +142,8 @@ class TestPlusAPI(unittest.TestCase):
)
self.assertEqual(response, mock_response)
@patch("crewai.cli.plus_api.Settings")
@patch("crewai.cli.plus_api.httpx.Client")
@patch("crewai.plus_api.Settings")
@patch("crewai.plus_api.httpx.Client")
def test_publish_tool_with_org_uuid(self, mock_client_class, mock_settings_class):
mock_settings = MagicMock()
mock_settings.org_uuid = self.org_uuid
@@ -180,7 +180,7 @@ class TestPlusAPI(unittest.TestCase):
)
self.assertEqual(response, mock_response)
@patch("crewai.cli.plus_api.PlusAPI._make_request")
@patch("crewai.plus_api.PlusAPI._make_request")
def test_publish_tool_without_description(self, mock_make_request):
mock_response = MagicMock()
mock_make_request.return_value = mock_response
@@ -207,7 +207,7 @@ class TestPlusAPI(unittest.TestCase):
)
self.assertEqual(response, mock_response)
@patch("crewai.cli.plus_api.httpx.Client")
@patch("crewai.plus_api.httpx.Client")
def test_make_request(self, mock_client_class):
mock_client_instance = MagicMock()
mock_response = MagicMock()
@@ -222,35 +222,35 @@ class TestPlusAPI(unittest.TestCase):
)
self.assertEqual(response, mock_response)
@patch("crewai.cli.plus_api.PlusAPI._make_request")
@patch("crewai.plus_api.PlusAPI._make_request")
def test_deploy_by_name(self, mock_make_request):
self.api.deploy_by_name("test_project")
mock_make_request.assert_called_once_with(
"POST", "/crewai_plus/api/v1/crews/by-name/test_project/deploy"
)
@patch("crewai.cli.plus_api.PlusAPI._make_request")
@patch("crewai.plus_api.PlusAPI._make_request")
def test_deploy_by_uuid(self, mock_make_request):
self.api.deploy_by_uuid("test_uuid")
mock_make_request.assert_called_once_with(
"POST", "/crewai_plus/api/v1/crews/test_uuid/deploy"
)
@patch("crewai.cli.plus_api.PlusAPI._make_request")
@patch("crewai.plus_api.PlusAPI._make_request")
def test_crew_status_by_name(self, mock_make_request):
self.api.crew_status_by_name("test_project")
mock_make_request.assert_called_once_with(
"GET", "/crewai_plus/api/v1/crews/by-name/test_project/status"
)
@patch("crewai.cli.plus_api.PlusAPI._make_request")
@patch("crewai.plus_api.PlusAPI._make_request")
def test_crew_status_by_uuid(self, mock_make_request):
self.api.crew_status_by_uuid("test_uuid")
mock_make_request.assert_called_once_with(
"GET", "/crewai_plus/api/v1/crews/test_uuid/status"
)
@patch("crewai.cli.plus_api.PlusAPI._make_request")
@patch("crewai.plus_api.PlusAPI._make_request")
def test_crew_by_name(self, mock_make_request):
self.api.crew_by_name("test_project")
mock_make_request.assert_called_once_with(
@@ -262,7 +262,7 @@ class TestPlusAPI(unittest.TestCase):
"GET", "/crewai_plus/api/v1/crews/by-name/test_project/logs/custom_log"
)
@patch("crewai.cli.plus_api.PlusAPI._make_request")
@patch("crewai.plus_api.PlusAPI._make_request")
def test_crew_by_uuid(self, mock_make_request):
self.api.crew_by_uuid("test_uuid")
mock_make_request.assert_called_once_with(
@@ -274,26 +274,26 @@ class TestPlusAPI(unittest.TestCase):
"GET", "/crewai_plus/api/v1/crews/test_uuid/logs/custom_log"
)
@patch("crewai.cli.plus_api.PlusAPI._make_request")
@patch("crewai.plus_api.PlusAPI._make_request")
def test_delete_crew_by_name(self, mock_make_request):
self.api.delete_crew_by_name("test_project")
mock_make_request.assert_called_once_with(
"DELETE", "/crewai_plus/api/v1/crews/by-name/test_project"
)
@patch("crewai.cli.plus_api.PlusAPI._make_request")
@patch("crewai.plus_api.PlusAPI._make_request")
def test_delete_crew_by_uuid(self, mock_make_request):
self.api.delete_crew_by_uuid("test_uuid")
mock_make_request.assert_called_once_with(
"DELETE", "/crewai_plus/api/v1/crews/test_uuid"
)
@patch("crewai.cli.plus_api.PlusAPI._make_request")
@patch("crewai.plus_api.PlusAPI._make_request")
def test_list_crews(self, mock_make_request):
self.api.list_crews()
mock_make_request.assert_called_once_with("GET", "/crewai_plus/api/v1/crews")
@patch("crewai.cli.plus_api.PlusAPI._make_request")
@patch("crewai.plus_api.PlusAPI._make_request")
def test_create_crew(self, mock_make_request):
payload = {"name": "test_crew"}
self.api.create_crew(payload)
@@ -301,7 +301,7 @@ class TestPlusAPI(unittest.TestCase):
"POST", "/crewai_plus/api/v1/crews", json=payload
)
@patch("crewai.cli.plus_api.Settings")
@patch("crewai.plus_api.Settings")
@patch.dict(os.environ, {"CREWAI_PLUS_URL": ""})
def test_custom_base_url(self, mock_settings_class):
mock_settings = MagicMock()
@@ -342,7 +342,7 @@ async def test_get_agent(mock_async_client_class):
@pytest.mark.asyncio
@patch("httpx.AsyncClient")
@patch("crewai.cli.plus_api.Settings")
@patch("crewai.plus_api.Settings")
async def test_get_agent_with_org_uuid(mock_settings_class, mock_async_client_class):
org_uuid = "test-org-uuid"
mock_settings = MagicMock()

View File

@@ -10,20 +10,20 @@ from unittest.mock import patch
from cryptography.fernet import Fernet
from crewai.cli.shared.token_manager import TokenManager
from crewai.auth.token_manager import TokenManager
class TestTokenManager(unittest.TestCase):
"""Test cases for TokenManager."""
@patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key")
@patch("crewai.auth.token_manager.TokenManager._get_or_create_key")
def setUp(self, mock_get_key: unittest.mock.MagicMock) -> None:
"""Set up test fixtures."""
mock_get_key.return_value = Fernet.generate_key()
self.token_manager = TokenManager()
@patch("crewai.cli.shared.token_manager.TokenManager._read_secure_file")
@patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key")
@patch("crewai.auth.token_manager.TokenManager._read_secure_file")
@patch("crewai.auth.token_manager.TokenManager._get_or_create_key")
def test_get_or_create_key_existing(
self,
mock_get_or_create: unittest.mock.MagicMock,
@@ -45,7 +45,7 @@ class TestTokenManager(unittest.TestCase):
with (
patch.object(self.token_manager, "_read_secure_file", return_value=None) as mock_read,
patch.object(self.token_manager, "_atomic_create_secure_file", return_value=True) as mock_atomic_create,
patch("crewai.cli.shared.token_manager.Fernet.generate_key", return_value=mock_key) as mock_generate,
patch("crewai.auth.token_manager.Fernet.generate_key", return_value=mock_key) as mock_generate,
):
result = self.token_manager._get_or_create_key()
@@ -62,14 +62,14 @@ class TestTokenManager(unittest.TestCase):
with (
patch.object(self.token_manager, "_read_secure_file", side_effect=[None, their_key]) as mock_read,
patch.object(self.token_manager, "_atomic_create_secure_file", return_value=False) as mock_atomic_create,
patch("crewai.cli.shared.token_manager.Fernet.generate_key", return_value=our_key),
patch("crewai.auth.token_manager.Fernet.generate_key", return_value=our_key),
):
result = self.token_manager._get_or_create_key()
self.assertEqual(result, their_key)
self.assertEqual(mock_read.call_count, 2)
@patch("crewai.cli.shared.token_manager.TokenManager._atomic_write_secure_file")
@patch("crewai.auth.token_manager.TokenManager._atomic_write_secure_file")
def test_save_tokens(
self, mock_write: unittest.mock.MagicMock
) -> None:
@@ -88,7 +88,7 @@ class TestTokenManager(unittest.TestCase):
expiration = datetime.fromisoformat(data["expiration"])
self.assertEqual(expiration, datetime.fromtimestamp(expires_at))
@patch("crewai.cli.shared.token_manager.TokenManager._read_secure_file")
@patch("crewai.auth.token_manager.TokenManager._read_secure_file")
def test_get_token_valid(
self, mock_read: unittest.mock.MagicMock
) -> None:
@@ -103,7 +103,7 @@ class TestTokenManager(unittest.TestCase):
self.assertEqual(result, access_token)
@patch("crewai.cli.shared.token_manager.TokenManager._read_secure_file")
@patch("crewai.auth.token_manager.TokenManager._read_secure_file")
def test_get_token_expired(
self, mock_read: unittest.mock.MagicMock
) -> None:
@@ -118,7 +118,7 @@ class TestTokenManager(unittest.TestCase):
self.assertIsNone(result)
@patch("crewai.cli.shared.token_manager.TokenManager._read_secure_file")
@patch("crewai.auth.token_manager.TokenManager._read_secure_file")
def test_get_token_not_found(
self, mock_read: unittest.mock.MagicMock
) -> None:
@@ -129,7 +129,7 @@ class TestTokenManager(unittest.TestCase):
self.assertIsNone(result)
@patch("crewai.cli.shared.token_manager.TokenManager._delete_secure_file")
@patch("crewai.auth.token_manager.TokenManager._delete_secure_file")
def test_clear_tokens(
self, mock_delete: unittest.mock.MagicMock
) -> None:
@@ -159,7 +159,7 @@ class TestAtomicFileOperations(unittest.TestCase):
import shutil
shutil.rmtree(self.temp_dir, ignore_errors=True)
@patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key")
@patch("crewai.auth.token_manager.TokenManager._get_or_create_key")
def test_atomic_create_new_file(
self, mock_get_key: unittest.mock.MagicMock
) -> None:
@@ -175,7 +175,7 @@ class TestAtomicFileOperations(unittest.TestCase):
self.assertEqual(file_path.read_bytes(), b"content")
self.assertEqual(file_path.stat().st_mode & 0o777, 0o600)
@patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key")
@patch("crewai.auth.token_manager.TokenManager._get_or_create_key")
def test_atomic_create_existing_file(
self, mock_get_key: unittest.mock.MagicMock
) -> None:
@@ -192,7 +192,7 @@ class TestAtomicFileOperations(unittest.TestCase):
self.assertFalse(result)
self.assertEqual(file_path.read_bytes(), b"original")
@patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key")
@patch("crewai.auth.token_manager.TokenManager._get_or_create_key")
def test_atomic_write_new_file(
self, mock_get_key: unittest.mock.MagicMock
) -> None:
@@ -207,7 +207,7 @@ class TestAtomicFileOperations(unittest.TestCase):
self.assertEqual(file_path.read_bytes(), b"content")
self.assertEqual(file_path.stat().st_mode & 0o777, 0o600)
@patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key")
@patch("crewai.auth.token_manager.TokenManager._get_or_create_key")
def test_atomic_write_overwrites(
self, mock_get_key: unittest.mock.MagicMock
) -> None:
@@ -222,7 +222,7 @@ class TestAtomicFileOperations(unittest.TestCase):
self.assertEqual(file_path.read_bytes(), b"new content")
@patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key")
@patch("crewai.auth.token_manager.TokenManager._get_or_create_key")
def test_atomic_write_no_temp_file_on_success(
self, mock_get_key: unittest.mock.MagicMock
) -> None:
@@ -236,7 +236,7 @@ class TestAtomicFileOperations(unittest.TestCase):
temp_files = list(Path(self.temp_dir).glob(".test.txt.*"))
self.assertEqual(len(temp_files), 0)
@patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key")
@patch("crewai.auth.token_manager.TokenManager._get_or_create_key")
def test_read_secure_file_exists(
self, mock_get_key: unittest.mock.MagicMock
) -> None:
@@ -251,7 +251,7 @@ class TestAtomicFileOperations(unittest.TestCase):
self.assertEqual(result, b"content")
@patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key")
@patch("crewai.auth.token_manager.TokenManager._get_or_create_key")
def test_read_secure_file_not_exists(
self, mock_get_key: unittest.mock.MagicMock
) -> None:
@@ -263,7 +263,7 @@ class TestAtomicFileOperations(unittest.TestCase):
self.assertIsNone(result)
@patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key")
@patch("crewai.auth.token_manager.TokenManager._get_or_create_key")
def test_delete_secure_file_exists(
self, mock_get_key: unittest.mock.MagicMock
) -> None:
@@ -278,7 +278,7 @@ class TestAtomicFileOperations(unittest.TestCase):
self.assertFalse(file_path.exists())
@patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key")
@patch("crewai.auth.token_manager.TokenManager._get_or_create_key")
def test_delete_secure_file_not_exists(
self, mock_get_key: unittest.mock.MagicMock
) -> None:

View File

@@ -4,7 +4,7 @@ import tempfile
from pathlib import Path
import pytest
from crewai.cli import utils
from crewai.utilities import project_utils as utils
@pytest.fixture

View File

@@ -6,7 +6,7 @@ from pathlib import Path
from unittest.mock import MagicMock, patch
from crewai import __version__
from crewai.cli.version import (
from crewai.version import (
_find_latest_non_yanked_version,
_get_cache_file,
_is_cache_valid,
@@ -60,8 +60,8 @@ class TestVersionChecking:
cache_data = {"version": "1.0.0"}
assert _is_cache_valid(cache_data) is False
@patch("crewai.cli.version.Path.exists")
@patch("crewai.cli.version.request.urlopen")
@patch("crewai.version.Path.exists")
@patch("crewai.version.request.urlopen")
def test_get_latest_version_from_pypi_success(
self, mock_urlopen: MagicMock, mock_exists: MagicMock
) -> None:
@@ -82,8 +82,8 @@ class TestVersionChecking:
version = get_latest_version_from_pypi()
assert version == "2.0.0"
@patch("crewai.cli.version.Path.exists")
@patch("crewai.cli.version.request.urlopen")
@patch("crewai.version.Path.exists")
@patch("crewai.version.request.urlopen")
def test_get_latest_version_from_pypi_failure(
self, mock_urlopen: MagicMock, mock_exists: MagicMock
) -> None:
@@ -97,8 +97,8 @@ class TestVersionChecking:
version = get_latest_version_from_pypi()
assert version is None
@patch("crewai.cli.version.get_crewai_version")
@patch("crewai.cli.version.get_latest_version_from_pypi")
@patch("crewai.version.get_crewai_version")
@patch("crewai.version.get_latest_version_from_pypi")
def test_is_newer_version_available_true(
self, mock_latest: MagicMock, mock_current: MagicMock
) -> None:
@@ -111,8 +111,8 @@ class TestVersionChecking:
assert current == "1.0.0"
assert latest == "2.0.0"
@patch("crewai.cli.version.get_crewai_version")
@patch("crewai.cli.version.get_latest_version_from_pypi")
@patch("crewai.version.get_crewai_version")
@patch("crewai.version.get_latest_version_from_pypi")
def test_is_newer_version_available_false(
self, mock_latest: MagicMock, mock_current: MagicMock
) -> None:
@@ -125,8 +125,8 @@ class TestVersionChecking:
assert current == "2.0.0"
assert latest == "2.0.0"
@patch("crewai.cli.version.get_crewai_version")
@patch("crewai.cli.version.get_latest_version_from_pypi")
@patch("crewai.version.get_crewai_version")
@patch("crewai.version.get_latest_version_from_pypi")
def test_is_newer_version_available_with_none_latest(
self, mock_latest: MagicMock, mock_current: MagicMock
) -> None:
@@ -260,8 +260,8 @@ class TestIsVersionYanked:
class TestIsCurrentVersionYanked:
"""Test is_current_version_yanked public function."""
@patch("crewai.cli.version.get_crewai_version")
@patch("crewai.cli.version._get_cache_file")
@patch("crewai.version.get_crewai_version")
@patch("crewai.version._get_cache_file")
def test_reads_from_valid_cache(
self, mock_cache_file: MagicMock, mock_version: MagicMock, tmp_path: Path
) -> None:
@@ -282,8 +282,8 @@ class TestIsCurrentVersionYanked:
assert is_yanked is True
assert reason == "bad release"
@patch("crewai.cli.version.get_crewai_version")
@patch("crewai.cli.version._get_cache_file")
@patch("crewai.version.get_crewai_version")
@patch("crewai.version._get_cache_file")
def test_not_yanked_from_cache(
self, mock_cache_file: MagicMock, mock_version: MagicMock, tmp_path: Path
) -> None:
@@ -304,9 +304,9 @@ class TestIsCurrentVersionYanked:
assert is_yanked is False
assert reason == ""
@patch("crewai.cli.version.get_latest_version_from_pypi")
@patch("crewai.cli.version.get_crewai_version")
@patch("crewai.cli.version._get_cache_file")
@patch("crewai.version.get_latest_version_from_pypi")
@patch("crewai.version.get_crewai_version")
@patch("crewai.version._get_cache_file")
def test_triggers_fetch_on_stale_cache(
self,
mock_cache_file: MagicMock,
@@ -346,9 +346,9 @@ class TestIsCurrentVersionYanked:
assert is_yanked is False
mock_fetch.assert_called_once()
@patch("crewai.cli.version.get_latest_version_from_pypi")
@patch("crewai.cli.version.get_crewai_version")
@patch("crewai.cli.version._get_cache_file")
@patch("crewai.version.get_latest_version_from_pypi")
@patch("crewai.version.get_crewai_version")
@patch("crewai.version._get_cache_file")
def test_returns_false_on_fetch_failure(
self,
mock_cache_file: MagicMock,

View File

@@ -11,7 +11,7 @@ from crewai.llms.providers.openai.completion import OpenAICompletion, ResponsesA
from crewai.crew import Crew
from crewai.agent import Agent
from crewai.task import Task
from crewai.cli.constants import DEFAULT_LLM_MODEL
from crewai.constants import DEFAULT_LLM_MODEL
def test_openai_completion_is_used_when_openai_provider():
"""

View File

@@ -102,7 +102,7 @@ class TestBuildMCPConfigFromDict:
class TestFetchAmpMCPConfigs:
@patch("crewai.cli.plus_api.PlusAPI")
@patch("crewai.plus_api.PlusAPI")
@patch("crewai_tools.tools.crewai_platform_tools.misc.get_platform_integration_token", return_value="test-api-key")
def test_fetches_configs_successfully(self, mock_get_token, mock_plus_api_class, resolver):
mock_response = MagicMock()
@@ -133,7 +133,7 @@ class TestFetchAmpMCPConfigs:
mock_plus_api_class.assert_called_once_with(api_key="test-api-key")
mock_plus_api.get_mcp_configs.assert_called_once_with(["notion", "github"])
@patch("crewai.cli.plus_api.PlusAPI")
@patch("crewai.plus_api.PlusAPI")
@patch("crewai_tools.tools.crewai_platform_tools.misc.get_platform_integration_token", return_value="test-api-key")
def test_omits_missing_slugs(self, mock_get_token, mock_plus_api_class, resolver):
mock_response = MagicMock()
@@ -150,7 +150,7 @@ class TestFetchAmpMCPConfigs:
assert "notion" in result
assert "missing-server" not in result
@patch("crewai.cli.plus_api.PlusAPI")
@patch("crewai.plus_api.PlusAPI")
@patch("crewai_tools.tools.crewai_platform_tools.misc.get_platform_integration_token", return_value="test-api-key")
def test_returns_empty_on_http_error(self, mock_get_token, mock_plus_api_class, resolver):
mock_response = MagicMock()
@@ -163,7 +163,7 @@ class TestFetchAmpMCPConfigs:
assert result == {}
@patch("crewai.cli.plus_api.PlusAPI")
@patch("crewai.plus_api.PlusAPI")
@patch("crewai_tools.tools.crewai_platform_tools.misc.get_platform_integration_token", return_value="test-api-key")
def test_returns_empty_on_network_error(self, mock_get_token, mock_plus_api_class, resolver):
import httpx

View File

@@ -35,7 +35,7 @@ class TestTraceListenerSetup:
# Need to patch all the places where get_auth_token is imported/used
with (
patch(
"crewai.cli.authentication.token.get_auth_token",
"crewai.auth.token.get_auth_token",
return_value="mock_token_12345",
),
patch(

View File

@@ -2,7 +2,7 @@ import os
from typing import Any
from unittest.mock import patch
from crewai.cli.constants import DEFAULT_LLM_MODEL
from crewai.constants import DEFAULT_LLM_MODEL
from crewai.llm import LLM
from crewai.llms.base_llm import BaseLLM
from crewai.utilities.llm_utils import create_llm