diff --git a/src/crewai/cli/authentication/providers/__init__.py b/src/crewai/cli/authentication/providers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/crewai/cli/authentication/providers/auth0.py b/src/crewai/cli/authentication/providers/auth0.py index 8538550db..b27e3d168 100644 --- a/src/crewai/cli/authentication/providers/auth0.py +++ b/src/crewai/cli/authentication/providers/auth0.py @@ -1,5 +1,6 @@ from crewai.cli.authentication.providers.base_provider import BaseProvider + class Auth0Provider(BaseProvider): def get_authorize_url(self) -> str: return f"https://{self._get_domain()}/oauth/device/code" @@ -14,13 +15,20 @@ class Auth0Provider(BaseProvider): return f"https://{self._get_domain()}/" def get_audience(self) -> str: - assert self.settings.audience is not None, "Audience is required" + if self.settings.audience is None: + raise ValueError( + "Audience is required. Please set it in the configuration." + ) return self.settings.audience def get_client_id(self) -> str: - assert self.settings.client_id is not None, "Client ID is required" + if self.settings.client_id is None: + raise ValueError( + "Client ID is required. Please set it in the configuration." + ) return self.settings.client_id def _get_domain(self) -> str: - assert self.settings.domain is not None, "Domain is required" + if self.settings.domain is None: + raise ValueError("Domain is required. Please set it in the configuration.") return self.settings.domain diff --git a/src/crewai/cli/authentication/providers/base_provider.py b/src/crewai/cli/authentication/providers/base_provider.py index c321de9f7..2b7a0140e 100644 --- a/src/crewai/cli/authentication/providers/base_provider.py +++ b/src/crewai/cli/authentication/providers/base_provider.py @@ -1,30 +1,26 @@ from abc import ABC, abstractmethod + from crewai.cli.authentication.main import Oauth2Settings + class BaseProvider(ABC): def __init__(self, settings: Oauth2Settings): self.settings = settings @abstractmethod - def get_authorize_url(self) -> str: - ... + def get_authorize_url(self) -> str: ... @abstractmethod - def get_token_url(self) -> str: - ... + def get_token_url(self) -> str: ... @abstractmethod - def get_jwks_url(self) -> str: - ... + def get_jwks_url(self) -> str: ... @abstractmethod - def get_issuer(self) -> str: - ... + def get_issuer(self) -> str: ... @abstractmethod - def get_audience(self) -> str: - ... + def get_audience(self) -> str: ... @abstractmethod - def get_client_id(self) -> str: - ... + def get_client_id(self) -> str: ... diff --git a/src/crewai/cli/authentication/providers/okta.py b/src/crewai/cli/authentication/providers/okta.py index 14227ae2b..d13087e7d 100644 --- a/src/crewai/cli/authentication/providers/okta.py +++ b/src/crewai/cli/authentication/providers/okta.py @@ -1,5 +1,6 @@ from crewai.cli.authentication.providers.base_provider import BaseProvider + class OktaProvider(BaseProvider): def get_authorize_url(self) -> str: return f"https://{self.settings.domain}/oauth2/default/v1/device/authorize" @@ -14,9 +15,15 @@ class OktaProvider(BaseProvider): return f"https://{self.settings.domain}/oauth2/default" def get_audience(self) -> str: - assert self.settings.audience is not None + if self.settings.audience is None: + raise ValueError( + "Audience is required. Please set it in the configuration." + ) return self.settings.audience def get_client_id(self) -> str: - assert self.settings.client_id is not None + if self.settings.client_id is None: + raise ValueError( + "Client ID is required. Please set it in the configuration." + ) return self.settings.client_id diff --git a/src/crewai/cli/authentication/providers/workos.py b/src/crewai/cli/authentication/providers/workos.py index 8cf475a4d..7cffdf890 100644 --- a/src/crewai/cli/authentication/providers/workos.py +++ b/src/crewai/cli/authentication/providers/workos.py @@ -1,5 +1,6 @@ from crewai.cli.authentication.providers.base_provider import BaseProvider + class WorkosProvider(BaseProvider): def get_authorize_url(self) -> str: return f"https://{self._get_domain()}/oauth2/device_authorization" @@ -17,9 +18,13 @@ class WorkosProvider(BaseProvider): return self.settings.audience or "" def get_client_id(self) -> str: - assert self.settings.client_id is not None, "Client ID is required" + if self.settings.client_id is None: + raise ValueError( + "Client ID is required. Please set it in the configuration." + ) return self.settings.client_id def _get_domain(self) -> str: - assert self.settings.domain is not None, "Domain is required" + if self.settings.domain is None: + raise ValueError("Domain is required. Please set it in the configuration.") return self.settings.domain diff --git a/src/crewai/cli/authentication/utils.py b/src/crewai/cli/authentication/utils.py index a788cccc0..849dda594 100644 --- a/src/crewai/cli/authentication/utils.py +++ b/src/crewai/cli/authentication/utils.py @@ -17,8 +17,6 @@ def validate_jwt_token( missing required claims). """ - decoded_token = None - try: jwk_client = PyJWKClient(jwks_url) signing_key = jwk_client.get_signing_key_from_jwt(jwt_token) @@ -26,7 +24,7 @@ def validate_jwt_token( _unverified_decoded_token = jwt.decode( jwt_token, options={"verify_signature": False} ) - decoded_token = jwt.decode( + return jwt.decode( jwt_token, signing_key.key, algorithms=["RS256"], @@ -40,23 +38,22 @@ def validate_jwt_token( "require": ["exp", "iat", "iss", "aud", "sub"], }, ) - return decoded_token - except jwt.ExpiredSignatureError: - raise Exception("Token has expired.") - except jwt.InvalidAudienceError: + except jwt.ExpiredSignatureError as e: + raise Exception("Token has expired.") from e + except jwt.InvalidAudienceError as e: actual_audience = _unverified_decoded_token.get("aud", "[no audience found]") raise Exception( f"Invalid token audience. Got: '{actual_audience}'. Expected: '{audience}'" - ) - except jwt.InvalidIssuerError: + ) from e + except jwt.InvalidIssuerError as e: actual_issuer = _unverified_decoded_token.get("iss", "[no issuer found]") raise Exception( f"Invalid token issuer. Got: '{actual_issuer}'. Expected: '{issuer}'" - ) + ) from e except jwt.MissingRequiredClaimError as e: - raise Exception(f"Token is missing required claims: {str(e)}") + raise Exception(f"Token is missing required claims: {e!s}") from e except jwt.exceptions.PyJWKClientError as e: - raise Exception(f"JWKS or key processing error: {str(e)}") + raise Exception(f"JWKS or key processing error: {e!s}") from e except jwt.InvalidTokenError as e: - raise Exception(f"Invalid token: {str(e)}") + raise Exception(f"Invalid token: {e!s}") from e diff --git a/src/crewai/cli/cli.py b/src/crewai/cli/cli.py index 07b8a0696..b9bf7147b 100644 --- a/src/crewai/cli/cli.py +++ b/src/crewai/cli/cli.py @@ -1,13 +1,13 @@ from importlib.metadata import version as get_version -from typing import Optional import click -from crewai.cli.config import Settings -from crewai.cli.settings.main import SettingsCommand + from crewai.cli.add_crew_to_flow import add_crew_to_flow +from crewai.cli.config import Settings from crewai.cli.create_crew import create_crew from crewai.cli.create_flow import create_flow from crewai.cli.crew_chat import run_chat +from crewai.cli.settings.main import SettingsCommand from crewai.memory.storage.kickoff_task_outputs_storage import ( KickoffTaskOutputsSQLiteStorage, ) @@ -237,13 +237,11 @@ def login(): @crewai.group() def deploy(): """Deploy the Crew CLI group.""" - pass @crewai.group() def tool(): """Tool Repository related commands.""" - pass @deploy.command(name="create") @@ -263,7 +261,7 @@ def deploy_list(): @deploy.command(name="push") @click.option("-u", "--uuid", type=str, help="Crew UUID parameter") -def deploy_push(uuid: Optional[str]): +def deploy_push(uuid: str | None): """Deploy the Crew.""" deploy_cmd = DeployCommand() deploy_cmd.deploy(uuid=uuid) @@ -271,7 +269,7 @@ def deploy_push(uuid: Optional[str]): @deploy.command(name="status") @click.option("-u", "--uuid", type=str, help="Crew UUID parameter") -def deply_status(uuid: Optional[str]): +def deply_status(uuid: str | None): """Get the status of a deployment.""" deploy_cmd = DeployCommand() deploy_cmd.get_crew_status(uuid=uuid) @@ -279,7 +277,7 @@ def deply_status(uuid: Optional[str]): @deploy.command(name="logs") @click.option("-u", "--uuid", type=str, help="Crew UUID parameter") -def deploy_logs(uuid: Optional[str]): +def deploy_logs(uuid: str | None): """Get the logs of a deployment.""" deploy_cmd = DeployCommand() deploy_cmd.get_crew_logs(uuid=uuid) @@ -287,7 +285,7 @@ def deploy_logs(uuid: Optional[str]): @deploy.command(name="remove") @click.option("-u", "--uuid", type=str, help="Crew UUID parameter") -def deploy_remove(uuid: Optional[str]): +def deploy_remove(uuid: str | None): """Remove a deployment.""" deploy_cmd = DeployCommand() deploy_cmd.remove_crew(uuid=uuid) @@ -327,7 +325,6 @@ def tool_publish(is_public: bool, force: bool): @crewai.group() def flow(): """Flow related commands.""" - pass @flow.command(name="kickoff") @@ -359,7 +356,7 @@ def chat(): and using the Chat LLM to generate responses. """ click.secho( - "\nStarting a conversation with the Crew\n" "Type 'exit' or Ctrl+C to quit.\n", + "\nStarting a conversation with the Crew\nType 'exit' or Ctrl+C to quit.\n", ) run_chat() @@ -368,7 +365,6 @@ def chat(): @crewai.group(invoke_without_command=True) def org(): """Organization management commands.""" - pass @org.command("list") @@ -396,7 +392,6 @@ def current(): @crewai.group() def enterprise(): """Enterprise Configuration commands.""" - pass @enterprise.command("configure") @@ -410,7 +405,6 @@ def enterprise_configure(enterprise_url: str): @crewai.group() def config(): """CLI Configuration commands.""" - pass @config.command("list") diff --git a/src/crewai/cli/config.py b/src/crewai/cli/config.py index 8eccbbb05..d1c2ba725 100644 --- a/src/crewai/cli/config.py +++ b/src/crewai/cli/config.py @@ -1,15 +1,14 @@ import json from pathlib import Path -from typing import Optional from pydantic import BaseModel, Field from crewai.cli.constants import ( - DEFAULT_CREWAI_ENTERPRISE_URL, - CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER, CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE, CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID, CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN, + CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER, + DEFAULT_CREWAI_ENTERPRISE_URL, ) from crewai.cli.shared.token_manager import TokenManager @@ -56,20 +55,20 @@ HIDDEN_SETTINGS_KEYS = [ class Settings(BaseModel): - enterprise_base_url: Optional[str] = Field( + enterprise_base_url: str | None = Field( default=DEFAULT_CLI_SETTINGS["enterprise_base_url"], description="Base URL of the CrewAI Enterprise instance", ) - tool_repository_username: Optional[str] = Field( + tool_repository_username: str | None = Field( None, description="Username for interacting with the Tool Repository" ) - tool_repository_password: Optional[str] = Field( + tool_repository_password: str | None = Field( None, description="Password for interacting with the Tool Repository" ) - org_name: Optional[str] = Field( + org_name: str | None = Field( None, description="Name of the currently active organization" ) - org_uuid: Optional[str] = Field( + org_uuid: str | None = Field( None, description="UUID of the currently active organization" ) config_path: Path = Field(default=DEFAULT_CONFIG_PATH, frozen=True, exclude=True) @@ -79,7 +78,7 @@ class Settings(BaseModel): default=DEFAULT_CLI_SETTINGS["oauth2_provider"], ) - oauth2_audience: Optional[str] = Field( + oauth2_audience: str | None = Field( description="OAuth2 audience value, typically used to identify the target API or resource.", default=DEFAULT_CLI_SETTINGS["oauth2_audience"], ) diff --git a/src/crewai/cli/create_crew.py b/src/crewai/cli/create_crew.py index 1d3e3ddce..3c0408637 100644 --- a/src/crewai/cli/create_crew.py +++ b/src/crewai/cli/create_crew.py @@ -16,48 +16,72 @@ from crewai.cli.utils import copy_template, load_env_vars, write_env_file def create_folder_structure(name, parent_folder=None): import keyword import re - - name = name.rstrip('/') - + + name = name.rstrip("/") + if not name.strip(): raise ValueError("Project name cannot be empty or contain only whitespace") - + folder_name = name.replace(" ", "_").replace("-", "_").lower() - folder_name = re.sub(r'[^a-zA-Z0-9_]', '', folder_name) - + folder_name = re.sub(r"[^a-zA-Z0-9_]", "", folder_name) + # Check if the name starts with invalid characters or is primarily invalid - if re.match(r'^[^a-zA-Z0-9_-]+', name): - raise ValueError(f"Project name '{name}' contains no valid characters for a Python module name") - + if re.match(r"^[^a-zA-Z0-9_-]+", name): + raise ValueError( + f"Project name '{name}' contains no valid characters for a Python module name" + ) + if not folder_name: - raise ValueError(f"Project name '{name}' contains no valid characters for a Python module name") - + raise ValueError( + f"Project name '{name}' contains no valid characters for a Python module name" + ) + if folder_name[0].isdigit(): - raise ValueError(f"Project name '{name}' would generate folder name '{folder_name}' which cannot start with a digit (invalid Python module name)") - + raise ValueError( + f"Project name '{name}' would generate folder name '{folder_name}' which cannot start with a digit (invalid Python module name)" + ) + if keyword.iskeyword(folder_name): - raise ValueError(f"Project name '{name}' would generate folder name '{folder_name}' which is a reserved Python keyword") - + raise ValueError( + f"Project name '{name}' would generate folder name '{folder_name}' which is a reserved Python keyword" + ) + if not folder_name.isidentifier(): - raise ValueError(f"Project name '{name}' would generate invalid Python module name '{folder_name}'") - + raise ValueError( + f"Project name '{name}' would generate invalid Python module name '{folder_name}'" + ) + class_name = name.replace("_", " ").replace("-", " ").title().replace(" ", "") - - class_name = re.sub(r'[^a-zA-Z0-9_]', '', class_name) - + + class_name = re.sub(r"[^a-zA-Z0-9_]", "", class_name) + if not class_name: - raise ValueError(f"Project name '{name}' contains no valid characters for a Python class name") - + raise ValueError( + f"Project name '{name}' contains no valid characters for a Python class name" + ) + if class_name[0].isdigit(): - raise ValueError(f"Project name '{name}' would generate class name '{class_name}' which cannot start with a digit") - + raise ValueError( + f"Project name '{name}' would generate class name '{class_name}' which cannot start with a digit" + ) + # Check if the original name (before title casing) is a keyword - original_name_clean = re.sub(r'[^a-zA-Z0-9_]', '', name.replace("_", "").replace("-", "").lower()) - if keyword.iskeyword(original_name_clean) or keyword.iskeyword(class_name) or class_name in ('True', 'False', 'None'): - raise ValueError(f"Project name '{name}' would generate class name '{class_name}' which is a reserved Python keyword") - + original_name_clean = re.sub( + r"[^a-zA-Z0-9_]", "", name.replace("_", "").replace("-", "").lower() + ) + if ( + keyword.iskeyword(original_name_clean) + or keyword.iskeyword(class_name) + or class_name in ("True", "False", "None") + ): + raise ValueError( + f"Project name '{name}' would generate class name '{class_name}' which is a reserved Python keyword" + ) + if not class_name.isidentifier(): - raise ValueError(f"Project name '{name}' would generate invalid Python class name '{class_name}'") + raise ValueError( + f"Project name '{name}' would generate invalid Python class name '{class_name}'" + ) if parent_folder: folder_path = Path(parent_folder) / folder_name @@ -172,7 +196,7 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None): ) # Check if the selected provider has predefined models - if selected_provider in MODELS and MODELS[selected_provider]: + if MODELS.get(selected_provider): while True: selected_model = select_model(selected_provider, provider_models) if selected_model is None: # User typed 'q' diff --git a/src/crewai/cli/crew_chat.py b/src/crewai/cli/crew_chat.py index 1b4e18c78..6fe9d87c8 100644 --- a/src/crewai/cli/crew_chat.py +++ b/src/crewai/cli/crew_chat.py @@ -5,7 +5,7 @@ import sys import threading import time from pathlib import Path -from typing import Any, Dict, List, Optional, Set, Tuple +from typing import Any import click import tomli @@ -116,7 +116,7 @@ def show_loading(event: threading.Event): print() -def initialize_chat_llm(crew: Crew) -> Optional[LLM | BaseLLM]: +def initialize_chat_llm(crew: Crew) -> LLM | BaseLLM | None: """Initializes the chat LLM and handles exceptions.""" try: return create_llm(crew.chat_llm) @@ -157,7 +157,7 @@ def build_system_message(crew_chat_inputs: ChatInputs) -> str: ) -def create_tool_function(crew: Crew, messages: List[Dict[str, str]]) -> Any: +def create_tool_function(crew: Crew, messages: list[dict[str, str]]) -> Any: """Creates a wrapper function for running the crew tool with messages.""" def run_crew_tool_with_messages(**kwargs): @@ -193,7 +193,7 @@ def chat_loop(chat_llm, messages, crew_tool_schema, available_functions): user_input, chat_llm, messages, crew_tool_schema, available_functions ) - except KeyboardInterrupt: + except KeyboardInterrupt: # noqa: PERF203 click.echo("\nExiting chat. Goodbye!") break except Exception as e: @@ -221,9 +221,9 @@ def get_user_input() -> str: def handle_user_input( user_input: str, chat_llm: LLM, - messages: List[Dict[str, str]], - crew_tool_schema: Dict[str, Any], - available_functions: Dict[str, Any], + messages: list[dict[str, str]], + crew_tool_schema: dict[str, Any], + available_functions: dict[str, Any], ) -> None: if user_input.strip().lower() == "exit": click.echo("Exiting chat. Goodbye!") @@ -281,7 +281,7 @@ def generate_crew_tool_schema(crew_inputs: ChatInputs) -> dict: } -def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs): +def run_crew_tool(crew: Crew, messages: list[dict[str, str]], **kwargs): """ Runs the crew using crew.kickoff(inputs=kwargs) and returns the output. @@ -304,9 +304,8 @@ def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs): crew_output = crew.kickoff(inputs=kwargs) # Convert CrewOutput to a string to send back to the user - result = str(crew_output) + return str(crew_output) - return result except Exception as e: # Exit the chat and show the error message click.secho("An error occurred while running the crew:", fg="red") @@ -314,7 +313,7 @@ def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs): sys.exit(1) -def load_crew_and_name() -> Tuple[Crew, str]: +def load_crew_and_name() -> tuple[Crew, str]: """ Loads the crew by importing the crew class from the user's project. @@ -351,15 +350,17 @@ def load_crew_and_name() -> Tuple[Crew, str]: try: crew_module = __import__(crew_module_name, fromlist=[crew_class_name]) except ImportError as e: - raise ImportError(f"Failed to import crew module {crew_module_name}: {e}") + raise ImportError( + f"Failed to import crew module {crew_module_name}: {e}" + ) from e # Get the crew class from the module try: crew_class = getattr(crew_module, crew_class_name) - except AttributeError: + except AttributeError as e: raise AttributeError( f"Crew class {crew_class_name} not found in module {crew_module_name}" - ) + ) from e # Instantiate the crew crew_instance = crew_class().crew() @@ -395,7 +396,7 @@ def generate_crew_chat_inputs(crew: Crew, crew_name: str, chat_llm) -> ChatInput ) -def fetch_required_inputs(crew: Crew) -> Set[str]: +def fetch_required_inputs(crew: Crew) -> set[str]: """ Extracts placeholders from the crew's tasks and agents. @@ -405,8 +406,8 @@ def fetch_required_inputs(crew: Crew) -> Set[str]: Returns: Set[str]: A set of placeholder names. """ - placeholder_pattern = re.compile(r"\{(.+?)\}") - required_inputs: Set[str] = set() + placeholder_pattern = re.compile(r"\{(.+?)}") + required_inputs: set[str] = set() # Scan tasks for task in crew.tasks: @@ -435,7 +436,7 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) -> """ # Gather context from tasks and agents where the input is used context_texts = [] - placeholder_pattern = re.compile(r"\{(.+?)\}") + placeholder_pattern = re.compile(r"\{(.+?)}") for task in crew.tasks: if ( @@ -479,9 +480,7 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) -> f"{context}" ) response = chat_llm.call(messages=[{"role": "user", "content": prompt}]) - description = response.strip() - - return description + return response.strip() def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str: @@ -497,7 +496,7 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str: """ # Gather context from tasks and agents context_texts = [] - placeholder_pattern = re.compile(r"\{(.+?)\}") + placeholder_pattern = re.compile(r"\{(.+?)}") for task in crew.tasks: # Replace placeholders with input names @@ -531,6 +530,4 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str: f"{context}" ) response = chat_llm.call(messages=[{"role": "user", "content": prompt}]) - crew_description = response.strip() - - return crew_description + return response.strip() diff --git a/src/crewai/cli/git.py b/src/crewai/cli/git.py index 58836e733..5828a717c 100644 --- a/src/crewai/cli/git.py +++ b/src/crewai/cli/git.py @@ -14,11 +14,15 @@ class Repository: self.fetch() - def is_git_installed(self) -> bool: + @staticmethod + def is_git_installed() -> bool: """Check if Git is installed and available in the system.""" try: subprocess.run( - ["git", "--version"], capture_output=True, check=True, text=True + ["git", "--version"], # noqa: S607 + capture_output=True, + check=True, + text=True, ) return True except (subprocess.CalledProcessError, FileNotFoundError): @@ -26,22 +30,26 @@ class Repository: def fetch(self) -> None: """Fetch latest updates from the remote.""" - subprocess.run(["git", "fetch"], cwd=self.path, check=True) + subprocess.run(["git", "fetch"], cwd=self.path, check=True) # noqa: S607 def status(self) -> str: """Get the git status in porcelain format.""" return subprocess.check_output( - ["git", "status", "--branch", "--porcelain"], + ["git", "status", "--branch", "--porcelain"], # noqa: S607 cwd=self.path, encoding="utf-8", ).strip() - @lru_cache(maxsize=None) + @lru_cache(maxsize=None) # noqa: B019 def is_git_repo(self) -> bool: - """Check if the current directory is a git repository.""" + """Check if the current directory is a git repository. + + Notes: + - TODO: This method is cached to avoid redundant checks, but using lru_cache on methods can lead to memory leaks + """ try: subprocess.check_output( - ["git", "rev-parse", "--is-inside-work-tree"], + ["git", "rev-parse", "--is-inside-work-tree"], # noqa: S607 cwd=self.path, encoding="utf-8", ) @@ -64,14 +72,13 @@ class Repository: """Return True if the Git repository is fully synced with the remote, False otherwise.""" if self.has_uncommitted_changes() or self.is_ahead_or_behind(): return False - else: - return True + return True def origin_url(self) -> str | None: """Get the Git repository's remote URL.""" try: result = subprocess.run( - ["git", "remote", "get-url", "origin"], + ["git", "remote", "get-url", "origin"], # noqa: S607 cwd=self.path, capture_output=True, text=True, diff --git a/src/crewai/cli/install_crew.py b/src/crewai/cli/install_crew.py index bd0f35879..aa10902aa 100644 --- a/src/crewai/cli/install_crew.py +++ b/src/crewai/cli/install_crew.py @@ -12,8 +12,8 @@ def install_crew(proxy_options: list[str]) -> None: Install the crew by running the UV command to lock and install. """ try: - command = ["uv", "sync"] + proxy_options - subprocess.run(command, check=True, capture_output=False, text=True) + command = ["uv", "sync", *proxy_options] + subprocess.run(command, check=True, capture_output=False, text=True) # noqa: S603 except subprocess.CalledProcessError as e: click.echo(f"An error occurred while running the crew: {e}", err=True) diff --git a/src/crewai/cli/plus_api.py b/src/crewai/cli/plus_api.py index c4d4422d2..2b0e4a7d7 100644 --- a/src/crewai/cli/plus_api.py +++ b/src/crewai/cli/plus_api.py @@ -1,11 +1,10 @@ -from typing import List, Optional from urllib.parse import urljoin import requests from crewai.cli.config import Settings -from crewai.cli.version import get_crewai_version from crewai.cli.constants import DEFAULT_CREWAI_ENTERPRISE_URL +from crewai.cli.version import get_crewai_version class PlusAPI: @@ -56,9 +55,9 @@ class PlusAPI: handle: str, is_public: bool, version: str, - description: Optional[str], + description: str | None, encoded_file: str, - available_exports: Optional[List[str]] = None, + available_exports: list[str] | None = None, ): params = { "handle": handle, diff --git a/src/crewai/cli/provider.py b/src/crewai/cli/provider.py index 0a26ef809..3374fef00 100644 --- a/src/crewai/cli/provider.py +++ b/src/crewai/cli/provider.py @@ -1,10 +1,10 @@ -import os -import certifi import json +import os import time from collections import defaultdict from pathlib import Path +import certifi import click import requests @@ -25,7 +25,7 @@ def select_choice(prompt_message, choices): provider_models = get_provider_data() if not provider_models: - return + return None click.secho(prompt_message, fg="cyan") for idx, choice in enumerate(choices, start=1): click.secho(f"{idx}. {choice}", fg="cyan") @@ -67,7 +67,7 @@ def select_provider(provider_models): all_providers = sorted(set(predefined_providers + list(provider_models.keys()))) provider = select_choice( - "Select a provider to set up:", predefined_providers + ["other"] + "Select a provider to set up:", [*predefined_providers, "other"] ) if provider is None: # User typed 'q' return None @@ -102,10 +102,9 @@ def select_model(provider, provider_models): click.secho(f"No models available for provider '{provider}'.", fg="red") return None - selected_model = select_choice( + return select_choice( f"Select a model to use for {provider.capitalize()}:", available_models ) - return selected_model def load_provider_data(cache_file, cache_expiry): @@ -165,7 +164,7 @@ def fetch_provider_data(cache_file): Returns: - dict or None: The fetched provider data or None if the operation fails. """ - ssl_config = os.environ['SSL_CERT_FILE'] = certifi.where() + ssl_config = os.environ["SSL_CERT_FILE"] = certifi.where() try: response = requests.get(JSON_URL, stream=True, timeout=60, verify=ssl_config) diff --git a/src/crewai/cli/run_crew.py b/src/crewai/cli/run_crew.py index 62241a4b5..d7fd96db5 100644 --- a/src/crewai/cli/run_crew.py +++ b/src/crewai/cli/run_crew.py @@ -1,6 +1,5 @@ import subprocess from enum import Enum -from typing import List, Optional import click from packaging import version @@ -57,7 +56,7 @@ def execute_command(crew_type: CrewType) -> None: command = ["uv", "run", "kickoff" if crew_type == CrewType.FLOW else "run_crew"] try: - subprocess.run(command, capture_output=False, text=True, check=True) + subprocess.run(command, capture_output=False, text=True, check=True) # noqa: S603 except subprocess.CalledProcessError as e: handle_error(e, crew_type) diff --git a/src/crewai/cli/shared/token_manager.py b/src/crewai/cli/shared/token_manager.py index c0e69dc43..89d44c573 100644 --- a/src/crewai/cli/shared/token_manager.py +++ b/src/crewai/cli/shared/token_manager.py @@ -3,7 +3,7 @@ import os import sys from datetime import datetime from pathlib import Path -from typing import Optional + from cryptography.fernet import Fernet @@ -49,7 +49,7 @@ class TokenManager: encrypted_data = self.fernet.encrypt(json.dumps(data).encode()) self.save_secure_file(self.file_path, encrypted_data) - def get_token(self) -> Optional[str]: + def get_token(self) -> str | None: """ Get the access token if it is valid and not expired. @@ -113,7 +113,7 @@ class TokenManager: # Set appropriate permissions (read/write for owner only) os.chmod(file_path, 0o600) - def read_secure_file(self, filename: str) -> Optional[bytes]: + def read_secure_file(self, filename: str) -> bytes | None: """ Read the content of a secure file. diff --git a/src/crewai/cli/utils.py b/src/crewai/cli/utils.py index 3fff637de..764af9d2f 100644 --- a/src/crewai/cli/utils.py +++ b/src/crewai/cli/utils.py @@ -5,7 +5,7 @@ import sys from functools import reduce from inspect import getmro, isclass, isfunction, ismethod from pathlib import Path -from typing import Any, Dict, List, get_type_hints +from typing import Any, get_type_hints import click import tomli @@ -41,8 +41,7 @@ def copy_template(src, dst, name, class_name, folder_name): def read_toml(file_path: str = "pyproject.toml"): """Read the content of a TOML file and return it as a dictionary.""" with open(file_path, "rb") as f: - toml_dict = tomli.load(f) - return toml_dict + return tomli.load(f) def parse_toml(content): @@ -77,7 +76,7 @@ def get_project_description( def _get_project_attribute( - pyproject_path: str, keys: List[str], require: bool + pyproject_path: str, keys: list[str], require: bool ) -> Any | None: """Get an attribute from the pyproject.toml file.""" attribute = None @@ -96,16 +95,20 @@ def _get_project_attribute( except FileNotFoundError: console.print(f"Error: {pyproject_path} not found.", style="bold red") except KeyError: - console.print(f"Error: {pyproject_path} is not a valid pyproject.toml file.", style="bold red") - except tomllib.TOMLDecodeError if sys.version_info >= (3, 11) else Exception as e: # type: ignore console.print( - f"Error: {pyproject_path} is not a valid TOML file." - if sys.version_info >= (3, 11) - else f"Error reading the pyproject.toml file: {e}", + f"Error: {pyproject_path} is not a valid pyproject.toml file.", style="bold red", ) except Exception as e: - console.print(f"Error reading the pyproject.toml file: {e}", style="bold red") + # Handle TOML decode errors for Python 3.11+ + if sys.version_info >= (3, 11) and isinstance(e, tomllib.TOMLDecodeError): # type: ignore + console.print( + f"Error: {pyproject_path} is not a valid TOML file.", style="bold red" + ) + else: + console.print( + f"Error reading the pyproject.toml file: {e}", style="bold red" + ) if require and not attribute: console.print( @@ -117,7 +120,7 @@ def _get_project_attribute( return attribute -def _get_nested_value(data: Dict[str, Any], keys: List[str]) -> Any: +def _get_nested_value(data: dict[str, Any], keys: list[str]) -> Any: return reduce(dict.__getitem__, keys, data) @@ -296,7 +299,10 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]: try: crew_instances.extend(fetch_crews(module_attr)) except Exception as e: - console.print(f"Error processing attribute {attr_name}: {e}", style="bold red") + console.print( + f"Error processing attribute {attr_name}: {e}", + style="bold red", + ) continue # If we found crew instances, break out of the loop @@ -304,12 +310,15 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]: break except Exception as exec_error: - console.print(f"Error executing module: {exec_error}", style="bold red") + console.print( + f"Error executing module: {exec_error}", + style="bold red", + ) except (ImportError, AttributeError) as e: if require: console.print( - f"Error importing crew from {crew_path}: {str(e)}", + f"Error importing crew from {crew_path}: {e!s}", style="bold red", ) continue @@ -325,9 +334,9 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]: except Exception as e: if require: console.print( - f"Unexpected error while loading crew: {str(e)}", style="bold red" + f"Unexpected error while loading crew: {e!s}", style="bold red" ) - raise SystemExit + raise SystemExit from e return crew_instances @@ -348,8 +357,7 @@ def get_crew_instance(module_attr) -> Crew | None: if isinstance(module_attr, Crew): return module_attr - else: - return None + return None def fetch_crews(module_attr) -> list[Crew]: @@ -402,11 +410,11 @@ def extract_available_exports(dir_path: str = "src"): return available_exports except Exception as e: - console.print(f"[red]Error: Could not extract tool classes: {str(e)}[/red]") + console.print(f"[red]Error: Could not extract tool classes: {e!s}[/red]") console.print( "Please ensure your project contains valid tools (classes inheriting from BaseTool or functions with @tool decorator)." ) - raise SystemExit(1) + raise SystemExit(1) from e def _load_tools_from_init(init_file: Path) -> list[dict[str, Any]]: @@ -440,8 +448,8 @@ def _load_tools_from_init(init_file: Path) -> list[dict[str, Any]]: ] except Exception as e: - console.print(f"[red]Warning: Could not load {init_file}: {str(e)}[/red]") - raise SystemExit(1) + console.print(f"[red]Warning: Could not load {init_file}: {e!s}[/red]") + raise SystemExit(1) from e finally: sys.modules.pop("temp_module", None) diff --git a/tests/cli/authentication/providers/test_okta.py b/tests/cli/authentication/providers/test_okta.py index b952464ba..5ceb441bf 100644 --- a/tests/cli/authentication/providers/test_okta.py +++ b/tests/cli/authentication/providers/test_okta.py @@ -1,17 +1,17 @@ import pytest + from crewai.cli.authentication.main import Oauth2Settings from crewai.cli.authentication.providers.okta import OktaProvider class TestOktaProvider: - @pytest.fixture(autouse=True) def setup_method(self): self.valid_settings = Oauth2Settings( provider="okta", domain="test-domain.okta.com", client_id="test-client-id", - audience="test-audience" + audience="test-audience", ) self.provider = OktaProvider(self.valid_settings) @@ -32,7 +32,7 @@ class TestOktaProvider: provider="okta", domain="my-company.okta.com", client_id="test-client", - audience="test-audience" + audience="test-audience", ) provider = OktaProvider(settings) expected_url = "https://my-company.okta.com/oauth2/default/v1/device/authorize" @@ -47,7 +47,7 @@ class TestOktaProvider: provider="okta", domain="another-domain.okta.com", client_id="test-client", - audience="test-audience" + audience="test-audience", ) provider = OktaProvider(settings) expected_url = "https://another-domain.okta.com/oauth2/default/v1/token" @@ -62,7 +62,7 @@ class TestOktaProvider: provider="okta", domain="dev.okta.com", client_id="test-client", - audience="test-audience" + audience="test-audience", ) provider = OktaProvider(settings) expected_url = "https://dev.okta.com/oauth2/default/v1/keys" @@ -77,7 +77,7 @@ class TestOktaProvider: provider="okta", domain="prod.okta.com", client_id="test-client", - audience="test-audience" + audience="test-audience", ) provider = OktaProvider(settings) expected_issuer = "https://prod.okta.com/oauth2/default" @@ -91,11 +91,11 @@ class TestOktaProvider: provider="okta", domain="test-domain.okta.com", client_id="test-client-id", - audience=None + audience=None, ) provider = OktaProvider(settings) - with pytest.raises(AssertionError): + with pytest.raises(ValueError, match="Audience is required"): provider.get_audience() def test_get_client_id(self):