chore: apply ruff linting fixes to CLI module

fix: apply ruff fixes to CLI and update Okta provider test
This commit is contained in:
Greyson LaLonde
2025-09-19 19:55:55 -04:00
committed by GitHub
parent de5d3c3ad1
commit f4abc41235
18 changed files with 207 additions and 168 deletions

View File

@@ -1,5 +1,6 @@
from crewai.cli.authentication.providers.base_provider import BaseProvider from crewai.cli.authentication.providers.base_provider import BaseProvider
class Auth0Provider(BaseProvider): class Auth0Provider(BaseProvider):
def get_authorize_url(self) -> str: def get_authorize_url(self) -> str:
return f"https://{self._get_domain()}/oauth/device/code" return f"https://{self._get_domain()}/oauth/device/code"
@@ -14,13 +15,20 @@ class Auth0Provider(BaseProvider):
return f"https://{self._get_domain()}/" return f"https://{self._get_domain()}/"
def get_audience(self) -> str: def get_audience(self) -> str:
assert self.settings.audience is not None, "Audience is required" if self.settings.audience is None:
raise ValueError(
"Audience is required. Please set it in the configuration."
)
return self.settings.audience return self.settings.audience
def get_client_id(self) -> str: def get_client_id(self) -> str:
assert self.settings.client_id is not None, "Client ID is required" if self.settings.client_id is None:
raise ValueError(
"Client ID is required. Please set it in the configuration."
)
return self.settings.client_id return self.settings.client_id
def _get_domain(self) -> str: def _get_domain(self) -> str:
assert self.settings.domain is not None, "Domain is required" if self.settings.domain is None:
raise ValueError("Domain is required. Please set it in the configuration.")
return self.settings.domain return self.settings.domain

View File

@@ -1,30 +1,26 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from crewai.cli.authentication.main import Oauth2Settings from crewai.cli.authentication.main import Oauth2Settings
class BaseProvider(ABC): class BaseProvider(ABC):
def __init__(self, settings: Oauth2Settings): def __init__(self, settings: Oauth2Settings):
self.settings = settings self.settings = settings
@abstractmethod @abstractmethod
def get_authorize_url(self) -> str: def get_authorize_url(self) -> str: ...
...
@abstractmethod @abstractmethod
def get_token_url(self) -> str: def get_token_url(self) -> str: ...
...
@abstractmethod @abstractmethod
def get_jwks_url(self) -> str: def get_jwks_url(self) -> str: ...
...
@abstractmethod @abstractmethod
def get_issuer(self) -> str: def get_issuer(self) -> str: ...
...
@abstractmethod @abstractmethod
def get_audience(self) -> str: def get_audience(self) -> str: ...
...
@abstractmethod @abstractmethod
def get_client_id(self) -> str: def get_client_id(self) -> str: ...
...

View File

@@ -1,5 +1,6 @@
from crewai.cli.authentication.providers.base_provider import BaseProvider from crewai.cli.authentication.providers.base_provider import BaseProvider
class OktaProvider(BaseProvider): class OktaProvider(BaseProvider):
def get_authorize_url(self) -> str: def get_authorize_url(self) -> str:
return f"https://{self.settings.domain}/oauth2/default/v1/device/authorize" return f"https://{self.settings.domain}/oauth2/default/v1/device/authorize"
@@ -14,9 +15,15 @@ class OktaProvider(BaseProvider):
return f"https://{self.settings.domain}/oauth2/default" return f"https://{self.settings.domain}/oauth2/default"
def get_audience(self) -> str: def get_audience(self) -> str:
assert self.settings.audience is not None if self.settings.audience is None:
raise ValueError(
"Audience is required. Please set it in the configuration."
)
return self.settings.audience return self.settings.audience
def get_client_id(self) -> str: def get_client_id(self) -> str:
assert self.settings.client_id is not None if self.settings.client_id is None:
raise ValueError(
"Client ID is required. Please set it in the configuration."
)
return self.settings.client_id return self.settings.client_id

View File

@@ -1,5 +1,6 @@
from crewai.cli.authentication.providers.base_provider import BaseProvider from crewai.cli.authentication.providers.base_provider import BaseProvider
class WorkosProvider(BaseProvider): class WorkosProvider(BaseProvider):
def get_authorize_url(self) -> str: def get_authorize_url(self) -> str:
return f"https://{self._get_domain()}/oauth2/device_authorization" return f"https://{self._get_domain()}/oauth2/device_authorization"
@@ -17,9 +18,13 @@ class WorkosProvider(BaseProvider):
return self.settings.audience or "" return self.settings.audience or ""
def get_client_id(self) -> str: def get_client_id(self) -> str:
assert self.settings.client_id is not None, "Client ID is required" if self.settings.client_id is None:
raise ValueError(
"Client ID is required. Please set it in the configuration."
)
return self.settings.client_id return self.settings.client_id
def _get_domain(self) -> str: def _get_domain(self) -> str:
assert self.settings.domain is not None, "Domain is required" if self.settings.domain is None:
raise ValueError("Domain is required. Please set it in the configuration.")
return self.settings.domain return self.settings.domain

View File

@@ -17,8 +17,6 @@ def validate_jwt_token(
missing required claims). missing required claims).
""" """
decoded_token = None
try: try:
jwk_client = PyJWKClient(jwks_url) jwk_client = PyJWKClient(jwks_url)
signing_key = jwk_client.get_signing_key_from_jwt(jwt_token) signing_key = jwk_client.get_signing_key_from_jwt(jwt_token)
@@ -26,7 +24,7 @@ def validate_jwt_token(
_unverified_decoded_token = jwt.decode( _unverified_decoded_token = jwt.decode(
jwt_token, options={"verify_signature": False} jwt_token, options={"verify_signature": False}
) )
decoded_token = jwt.decode( return jwt.decode(
jwt_token, jwt_token,
signing_key.key, signing_key.key,
algorithms=["RS256"], algorithms=["RS256"],
@@ -40,23 +38,22 @@ def validate_jwt_token(
"require": ["exp", "iat", "iss", "aud", "sub"], "require": ["exp", "iat", "iss", "aud", "sub"],
}, },
) )
return decoded_token
except jwt.ExpiredSignatureError: except jwt.ExpiredSignatureError as e:
raise Exception("Token has expired.") raise Exception("Token has expired.") from e
except jwt.InvalidAudienceError: except jwt.InvalidAudienceError as e:
actual_audience = _unverified_decoded_token.get("aud", "[no audience found]") actual_audience = _unverified_decoded_token.get("aud", "[no audience found]")
raise Exception( raise Exception(
f"Invalid token audience. Got: '{actual_audience}'. Expected: '{audience}'" f"Invalid token audience. Got: '{actual_audience}'. Expected: '{audience}'"
) ) from e
except jwt.InvalidIssuerError: except jwt.InvalidIssuerError as e:
actual_issuer = _unverified_decoded_token.get("iss", "[no issuer found]") actual_issuer = _unverified_decoded_token.get("iss", "[no issuer found]")
raise Exception( raise Exception(
f"Invalid token issuer. Got: '{actual_issuer}'. Expected: '{issuer}'" f"Invalid token issuer. Got: '{actual_issuer}'. Expected: '{issuer}'"
) ) from e
except jwt.MissingRequiredClaimError as e: except jwt.MissingRequiredClaimError as e:
raise Exception(f"Token is missing required claims: {str(e)}") raise Exception(f"Token is missing required claims: {e!s}") from e
except jwt.exceptions.PyJWKClientError as e: except jwt.exceptions.PyJWKClientError as e:
raise Exception(f"JWKS or key processing error: {str(e)}") raise Exception(f"JWKS or key processing error: {e!s}") from e
except jwt.InvalidTokenError as e: except jwt.InvalidTokenError as e:
raise Exception(f"Invalid token: {str(e)}") raise Exception(f"Invalid token: {e!s}") from e

View File

@@ -1,13 +1,13 @@
from importlib.metadata import version as get_version from importlib.metadata import version as get_version
from typing import Optional
import click import click
from crewai.cli.config import Settings
from crewai.cli.settings.main import SettingsCommand
from crewai.cli.add_crew_to_flow import add_crew_to_flow from crewai.cli.add_crew_to_flow import add_crew_to_flow
from crewai.cli.config import Settings
from crewai.cli.create_crew import create_crew from crewai.cli.create_crew import create_crew
from crewai.cli.create_flow import create_flow from crewai.cli.create_flow import create_flow
from crewai.cli.crew_chat import run_chat from crewai.cli.crew_chat import run_chat
from crewai.cli.settings.main import SettingsCommand
from crewai.memory.storage.kickoff_task_outputs_storage import ( from crewai.memory.storage.kickoff_task_outputs_storage import (
KickoffTaskOutputsSQLiteStorage, KickoffTaskOutputsSQLiteStorage,
) )
@@ -237,13 +237,11 @@ def login():
@crewai.group() @crewai.group()
def deploy(): def deploy():
"""Deploy the Crew CLI group.""" """Deploy the Crew CLI group."""
pass
@crewai.group() @crewai.group()
def tool(): def tool():
"""Tool Repository related commands.""" """Tool Repository related commands."""
pass
@deploy.command(name="create") @deploy.command(name="create")
@@ -263,7 +261,7 @@ def deploy_list():
@deploy.command(name="push") @deploy.command(name="push")
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter") @click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
def deploy_push(uuid: Optional[str]): def deploy_push(uuid: str | None):
"""Deploy the Crew.""" """Deploy the Crew."""
deploy_cmd = DeployCommand() deploy_cmd = DeployCommand()
deploy_cmd.deploy(uuid=uuid) deploy_cmd.deploy(uuid=uuid)
@@ -271,7 +269,7 @@ def deploy_push(uuid: Optional[str]):
@deploy.command(name="status") @deploy.command(name="status")
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter") @click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
def deply_status(uuid: Optional[str]): def deply_status(uuid: str | None):
"""Get the status of a deployment.""" """Get the status of a deployment."""
deploy_cmd = DeployCommand() deploy_cmd = DeployCommand()
deploy_cmd.get_crew_status(uuid=uuid) deploy_cmd.get_crew_status(uuid=uuid)
@@ -279,7 +277,7 @@ def deply_status(uuid: Optional[str]):
@deploy.command(name="logs") @deploy.command(name="logs")
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter") @click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
def deploy_logs(uuid: Optional[str]): def deploy_logs(uuid: str | None):
"""Get the logs of a deployment.""" """Get the logs of a deployment."""
deploy_cmd = DeployCommand() deploy_cmd = DeployCommand()
deploy_cmd.get_crew_logs(uuid=uuid) deploy_cmd.get_crew_logs(uuid=uuid)
@@ -287,7 +285,7 @@ def deploy_logs(uuid: Optional[str]):
@deploy.command(name="remove") @deploy.command(name="remove")
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter") @click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
def deploy_remove(uuid: Optional[str]): def deploy_remove(uuid: str | None):
"""Remove a deployment.""" """Remove a deployment."""
deploy_cmd = DeployCommand() deploy_cmd = DeployCommand()
deploy_cmd.remove_crew(uuid=uuid) deploy_cmd.remove_crew(uuid=uuid)
@@ -327,7 +325,6 @@ def tool_publish(is_public: bool, force: bool):
@crewai.group() @crewai.group()
def flow(): def flow():
"""Flow related commands.""" """Flow related commands."""
pass
@flow.command(name="kickoff") @flow.command(name="kickoff")
@@ -359,7 +356,7 @@ def chat():
and using the Chat LLM to generate responses. and using the Chat LLM to generate responses.
""" """
click.secho( click.secho(
"\nStarting a conversation with the Crew\n" "Type 'exit' or Ctrl+C to quit.\n", "\nStarting a conversation with the Crew\nType 'exit' or Ctrl+C to quit.\n",
) )
run_chat() run_chat()
@@ -368,7 +365,6 @@ def chat():
@crewai.group(invoke_without_command=True) @crewai.group(invoke_without_command=True)
def org(): def org():
"""Organization management commands.""" """Organization management commands."""
pass
@org.command("list") @org.command("list")
@@ -396,7 +392,6 @@ def current():
@crewai.group() @crewai.group()
def enterprise(): def enterprise():
"""Enterprise Configuration commands.""" """Enterprise Configuration commands."""
pass
@enterprise.command("configure") @enterprise.command("configure")
@@ -410,7 +405,6 @@ def enterprise_configure(enterprise_url: str):
@crewai.group() @crewai.group()
def config(): def config():
"""CLI Configuration commands.""" """CLI Configuration commands."""
pass
@config.command("list") @config.command("list")

View File

@@ -1,15 +1,14 @@
import json import json
from pathlib import Path from pathlib import Path
from typing import Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from crewai.cli.constants import ( from crewai.cli.constants import (
DEFAULT_CREWAI_ENTERPRISE_URL,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE, CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID, CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN, CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER,
DEFAULT_CREWAI_ENTERPRISE_URL,
) )
from crewai.cli.shared.token_manager import TokenManager from crewai.cli.shared.token_manager import TokenManager
@@ -56,20 +55,20 @@ HIDDEN_SETTINGS_KEYS = [
class Settings(BaseModel): class Settings(BaseModel):
enterprise_base_url: Optional[str] = Field( enterprise_base_url: str | None = Field(
default=DEFAULT_CLI_SETTINGS["enterprise_base_url"], default=DEFAULT_CLI_SETTINGS["enterprise_base_url"],
description="Base URL of the CrewAI Enterprise instance", description="Base URL of the CrewAI Enterprise instance",
) )
tool_repository_username: Optional[str] = Field( tool_repository_username: str | None = Field(
None, description="Username for interacting with the Tool Repository" None, description="Username for interacting with the Tool Repository"
) )
tool_repository_password: Optional[str] = Field( tool_repository_password: str | None = Field(
None, description="Password for interacting with the Tool Repository" None, description="Password for interacting with the Tool Repository"
) )
org_name: Optional[str] = Field( org_name: str | None = Field(
None, description="Name of the currently active organization" None, description="Name of the currently active organization"
) )
org_uuid: Optional[str] = Field( org_uuid: str | None = Field(
None, description="UUID of the currently active organization" None, description="UUID of the currently active organization"
) )
config_path: Path = Field(default=DEFAULT_CONFIG_PATH, frozen=True, exclude=True) config_path: Path = Field(default=DEFAULT_CONFIG_PATH, frozen=True, exclude=True)
@@ -79,7 +78,7 @@ class Settings(BaseModel):
default=DEFAULT_CLI_SETTINGS["oauth2_provider"], default=DEFAULT_CLI_SETTINGS["oauth2_provider"],
) )
oauth2_audience: Optional[str] = Field( oauth2_audience: str | None = Field(
description="OAuth2 audience value, typically used to identify the target API or resource.", description="OAuth2 audience value, typically used to identify the target API or resource.",
default=DEFAULT_CLI_SETTINGS["oauth2_audience"], default=DEFAULT_CLI_SETTINGS["oauth2_audience"],
) )

View File

@@ -16,48 +16,72 @@ from crewai.cli.utils import copy_template, load_env_vars, write_env_file
def create_folder_structure(name, parent_folder=None): def create_folder_structure(name, parent_folder=None):
import keyword import keyword
import re import re
name = name.rstrip('/') name = name.rstrip("/")
if not name.strip(): if not name.strip():
raise ValueError("Project name cannot be empty or contain only whitespace") raise ValueError("Project name cannot be empty or contain only whitespace")
folder_name = name.replace(" ", "_").replace("-", "_").lower() folder_name = name.replace(" ", "_").replace("-", "_").lower()
folder_name = re.sub(r'[^a-zA-Z0-9_]', '', folder_name) folder_name = re.sub(r"[^a-zA-Z0-9_]", "", folder_name)
# Check if the name starts with invalid characters or is primarily invalid # Check if the name starts with invalid characters or is primarily invalid
if re.match(r'^[^a-zA-Z0-9_-]+', name): if re.match(r"^[^a-zA-Z0-9_-]+", name):
raise ValueError(f"Project name '{name}' contains no valid characters for a Python module name") raise ValueError(
f"Project name '{name}' contains no valid characters for a Python module name"
)
if not folder_name: if not folder_name:
raise ValueError(f"Project name '{name}' contains no valid characters for a Python module name") raise ValueError(
f"Project name '{name}' contains no valid characters for a Python module name"
)
if folder_name[0].isdigit(): if folder_name[0].isdigit():
raise ValueError(f"Project name '{name}' would generate folder name '{folder_name}' which cannot start with a digit (invalid Python module name)") raise ValueError(
f"Project name '{name}' would generate folder name '{folder_name}' which cannot start with a digit (invalid Python module name)"
)
if keyword.iskeyword(folder_name): if keyword.iskeyword(folder_name):
raise ValueError(f"Project name '{name}' would generate folder name '{folder_name}' which is a reserved Python keyword") raise ValueError(
f"Project name '{name}' would generate folder name '{folder_name}' which is a reserved Python keyword"
)
if not folder_name.isidentifier(): if not folder_name.isidentifier():
raise ValueError(f"Project name '{name}' would generate invalid Python module name '{folder_name}'") raise ValueError(
f"Project name '{name}' would generate invalid Python module name '{folder_name}'"
)
class_name = name.replace("_", " ").replace("-", " ").title().replace(" ", "") class_name = name.replace("_", " ").replace("-", " ").title().replace(" ", "")
class_name = re.sub(r'[^a-zA-Z0-9_]', '', class_name) class_name = re.sub(r"[^a-zA-Z0-9_]", "", class_name)
if not class_name: if not class_name:
raise ValueError(f"Project name '{name}' contains no valid characters for a Python class name") raise ValueError(
f"Project name '{name}' contains no valid characters for a Python class name"
)
if class_name[0].isdigit(): if class_name[0].isdigit():
raise ValueError(f"Project name '{name}' would generate class name '{class_name}' which cannot start with a digit") raise ValueError(
f"Project name '{name}' would generate class name '{class_name}' which cannot start with a digit"
)
# Check if the original name (before title casing) is a keyword # Check if the original name (before title casing) is a keyword
original_name_clean = re.sub(r'[^a-zA-Z0-9_]', '', name.replace("_", "").replace("-", "").lower()) original_name_clean = re.sub(
if keyword.iskeyword(original_name_clean) or keyword.iskeyword(class_name) or class_name in ('True', 'False', 'None'): r"[^a-zA-Z0-9_]", "", name.replace("_", "").replace("-", "").lower()
raise ValueError(f"Project name '{name}' would generate class name '{class_name}' which is a reserved Python keyword") )
if (
keyword.iskeyword(original_name_clean)
or keyword.iskeyword(class_name)
or class_name in ("True", "False", "None")
):
raise ValueError(
f"Project name '{name}' would generate class name '{class_name}' which is a reserved Python keyword"
)
if not class_name.isidentifier(): if not class_name.isidentifier():
raise ValueError(f"Project name '{name}' would generate invalid Python class name '{class_name}'") raise ValueError(
f"Project name '{name}' would generate invalid Python class name '{class_name}'"
)
if parent_folder: if parent_folder:
folder_path = Path(parent_folder) / folder_name folder_path = Path(parent_folder) / folder_name
@@ -172,7 +196,7 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None):
) )
# Check if the selected provider has predefined models # Check if the selected provider has predefined models
if selected_provider in MODELS and MODELS[selected_provider]: if MODELS.get(selected_provider):
while True: while True:
selected_model = select_model(selected_provider, provider_models) selected_model = select_model(selected_provider, provider_models)
if selected_model is None: # User typed 'q' if selected_model is None: # User typed 'q'

View File

@@ -5,7 +5,7 @@ import sys
import threading import threading
import time import time
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple from typing import Any
import click import click
import tomli import tomli
@@ -116,7 +116,7 @@ def show_loading(event: threading.Event):
print() print()
def initialize_chat_llm(crew: Crew) -> Optional[LLM | BaseLLM]: def initialize_chat_llm(crew: Crew) -> LLM | BaseLLM | None:
"""Initializes the chat LLM and handles exceptions.""" """Initializes the chat LLM and handles exceptions."""
try: try:
return create_llm(crew.chat_llm) return create_llm(crew.chat_llm)
@@ -157,7 +157,7 @@ def build_system_message(crew_chat_inputs: ChatInputs) -> str:
) )
def create_tool_function(crew: Crew, messages: List[Dict[str, str]]) -> Any: def create_tool_function(crew: Crew, messages: list[dict[str, str]]) -> Any:
"""Creates a wrapper function for running the crew tool with messages.""" """Creates a wrapper function for running the crew tool with messages."""
def run_crew_tool_with_messages(**kwargs): def run_crew_tool_with_messages(**kwargs):
@@ -193,7 +193,7 @@ def chat_loop(chat_llm, messages, crew_tool_schema, available_functions):
user_input, chat_llm, messages, crew_tool_schema, available_functions user_input, chat_llm, messages, crew_tool_schema, available_functions
) )
except KeyboardInterrupt: except KeyboardInterrupt: # noqa: PERF203
click.echo("\nExiting chat. Goodbye!") click.echo("\nExiting chat. Goodbye!")
break break
except Exception as e: except Exception as e:
@@ -221,9 +221,9 @@ def get_user_input() -> str:
def handle_user_input( def handle_user_input(
user_input: str, user_input: str,
chat_llm: LLM, chat_llm: LLM,
messages: List[Dict[str, str]], messages: list[dict[str, str]],
crew_tool_schema: Dict[str, Any], crew_tool_schema: dict[str, Any],
available_functions: Dict[str, Any], available_functions: dict[str, Any],
) -> None: ) -> None:
if user_input.strip().lower() == "exit": if user_input.strip().lower() == "exit":
click.echo("Exiting chat. Goodbye!") click.echo("Exiting chat. Goodbye!")
@@ -281,7 +281,7 @@ def generate_crew_tool_schema(crew_inputs: ChatInputs) -> dict:
} }
def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs): def run_crew_tool(crew: Crew, messages: list[dict[str, str]], **kwargs):
""" """
Runs the crew using crew.kickoff(inputs=kwargs) and returns the output. Runs the crew using crew.kickoff(inputs=kwargs) and returns the output.
@@ -304,9 +304,8 @@ def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs):
crew_output = crew.kickoff(inputs=kwargs) crew_output = crew.kickoff(inputs=kwargs)
# Convert CrewOutput to a string to send back to the user # Convert CrewOutput to a string to send back to the user
result = str(crew_output) return str(crew_output)
return result
except Exception as e: except Exception as e:
# Exit the chat and show the error message # Exit the chat and show the error message
click.secho("An error occurred while running the crew:", fg="red") click.secho("An error occurred while running the crew:", fg="red")
@@ -314,7 +313,7 @@ def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs):
sys.exit(1) sys.exit(1)
def load_crew_and_name() -> Tuple[Crew, str]: def load_crew_and_name() -> tuple[Crew, str]:
""" """
Loads the crew by importing the crew class from the user's project. Loads the crew by importing the crew class from the user's project.
@@ -351,15 +350,17 @@ def load_crew_and_name() -> Tuple[Crew, str]:
try: try:
crew_module = __import__(crew_module_name, fromlist=[crew_class_name]) crew_module = __import__(crew_module_name, fromlist=[crew_class_name])
except ImportError as e: except ImportError as e:
raise ImportError(f"Failed to import crew module {crew_module_name}: {e}") raise ImportError(
f"Failed to import crew module {crew_module_name}: {e}"
) from e
# Get the crew class from the module # Get the crew class from the module
try: try:
crew_class = getattr(crew_module, crew_class_name) crew_class = getattr(crew_module, crew_class_name)
except AttributeError: except AttributeError as e:
raise AttributeError( raise AttributeError(
f"Crew class {crew_class_name} not found in module {crew_module_name}" f"Crew class {crew_class_name} not found in module {crew_module_name}"
) ) from e
# Instantiate the crew # Instantiate the crew
crew_instance = crew_class().crew() crew_instance = crew_class().crew()
@@ -395,7 +396,7 @@ def generate_crew_chat_inputs(crew: Crew, crew_name: str, chat_llm) -> ChatInput
) )
def fetch_required_inputs(crew: Crew) -> Set[str]: def fetch_required_inputs(crew: Crew) -> set[str]:
""" """
Extracts placeholders from the crew's tasks and agents. Extracts placeholders from the crew's tasks and agents.
@@ -405,8 +406,8 @@ def fetch_required_inputs(crew: Crew) -> Set[str]:
Returns: Returns:
Set[str]: A set of placeholder names. Set[str]: A set of placeholder names.
""" """
placeholder_pattern = re.compile(r"\{(.+?)\}") placeholder_pattern = re.compile(r"\{(.+?)}")
required_inputs: Set[str] = set() required_inputs: set[str] = set()
# Scan tasks # Scan tasks
for task in crew.tasks: for task in crew.tasks:
@@ -435,7 +436,7 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) ->
""" """
# Gather context from tasks and agents where the input is used # Gather context from tasks and agents where the input is used
context_texts = [] context_texts = []
placeholder_pattern = re.compile(r"\{(.+?)\}") placeholder_pattern = re.compile(r"\{(.+?)}")
for task in crew.tasks: for task in crew.tasks:
if ( if (
@@ -479,9 +480,7 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) ->
f"{context}" f"{context}"
) )
response = chat_llm.call(messages=[{"role": "user", "content": prompt}]) response = chat_llm.call(messages=[{"role": "user", "content": prompt}])
description = response.strip() return response.strip()
return description
def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str: def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
@@ -497,7 +496,7 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
""" """
# Gather context from tasks and agents # Gather context from tasks and agents
context_texts = [] context_texts = []
placeholder_pattern = re.compile(r"\{(.+?)\}") placeholder_pattern = re.compile(r"\{(.+?)}")
for task in crew.tasks: for task in crew.tasks:
# Replace placeholders with input names # Replace placeholders with input names
@@ -531,6 +530,4 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
f"{context}" f"{context}"
) )
response = chat_llm.call(messages=[{"role": "user", "content": prompt}]) response = chat_llm.call(messages=[{"role": "user", "content": prompt}])
crew_description = response.strip() return response.strip()
return crew_description

View File

@@ -14,11 +14,15 @@ class Repository:
self.fetch() self.fetch()
def is_git_installed(self) -> bool: @staticmethod
def is_git_installed() -> bool:
"""Check if Git is installed and available in the system.""" """Check if Git is installed and available in the system."""
try: try:
subprocess.run( subprocess.run(
["git", "--version"], capture_output=True, check=True, text=True ["git", "--version"], # noqa: S607
capture_output=True,
check=True,
text=True,
) )
return True return True
except (subprocess.CalledProcessError, FileNotFoundError): except (subprocess.CalledProcessError, FileNotFoundError):
@@ -26,22 +30,26 @@ class Repository:
def fetch(self) -> None: def fetch(self) -> None:
"""Fetch latest updates from the remote.""" """Fetch latest updates from the remote."""
subprocess.run(["git", "fetch"], cwd=self.path, check=True) subprocess.run(["git", "fetch"], cwd=self.path, check=True) # noqa: S607
def status(self) -> str: def status(self) -> str:
"""Get the git status in porcelain format.""" """Get the git status in porcelain format."""
return subprocess.check_output( return subprocess.check_output(
["git", "status", "--branch", "--porcelain"], ["git", "status", "--branch", "--porcelain"], # noqa: S607
cwd=self.path, cwd=self.path,
encoding="utf-8", encoding="utf-8",
).strip() ).strip()
@lru_cache(maxsize=None) @lru_cache(maxsize=None) # noqa: B019
def is_git_repo(self) -> bool: def is_git_repo(self) -> bool:
"""Check if the current directory is a git repository.""" """Check if the current directory is a git repository.
Notes:
- TODO: This method is cached to avoid redundant checks, but using lru_cache on methods can lead to memory leaks
"""
try: try:
subprocess.check_output( subprocess.check_output(
["git", "rev-parse", "--is-inside-work-tree"], ["git", "rev-parse", "--is-inside-work-tree"], # noqa: S607
cwd=self.path, cwd=self.path,
encoding="utf-8", encoding="utf-8",
) )
@@ -64,14 +72,13 @@ class Repository:
"""Return True if the Git repository is fully synced with the remote, False otherwise.""" """Return True if the Git repository is fully synced with the remote, False otherwise."""
if self.has_uncommitted_changes() or self.is_ahead_or_behind(): if self.has_uncommitted_changes() or self.is_ahead_or_behind():
return False return False
else: return True
return True
def origin_url(self) -> str | None: def origin_url(self) -> str | None:
"""Get the Git repository's remote URL.""" """Get the Git repository's remote URL."""
try: try:
result = subprocess.run( result = subprocess.run(
["git", "remote", "get-url", "origin"], ["git", "remote", "get-url", "origin"], # noqa: S607
cwd=self.path, cwd=self.path,
capture_output=True, capture_output=True,
text=True, text=True,

View File

@@ -12,8 +12,8 @@ def install_crew(proxy_options: list[str]) -> None:
Install the crew by running the UV command to lock and install. Install the crew by running the UV command to lock and install.
""" """
try: try:
command = ["uv", "sync"] + proxy_options command = ["uv", "sync", *proxy_options]
subprocess.run(command, check=True, capture_output=False, text=True) subprocess.run(command, check=True, capture_output=False, text=True) # noqa: S603
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
click.echo(f"An error occurred while running the crew: {e}", err=True) click.echo(f"An error occurred while running the crew: {e}", err=True)

View File

@@ -1,11 +1,10 @@
from typing import List, Optional
from urllib.parse import urljoin from urllib.parse import urljoin
import requests import requests
from crewai.cli.config import Settings from crewai.cli.config import Settings
from crewai.cli.version import get_crewai_version
from crewai.cli.constants import DEFAULT_CREWAI_ENTERPRISE_URL from crewai.cli.constants import DEFAULT_CREWAI_ENTERPRISE_URL
from crewai.cli.version import get_crewai_version
class PlusAPI: class PlusAPI:
@@ -56,9 +55,9 @@ class PlusAPI:
handle: str, handle: str,
is_public: bool, is_public: bool,
version: str, version: str,
description: Optional[str], description: str | None,
encoded_file: str, encoded_file: str,
available_exports: Optional[List[str]] = None, available_exports: list[str] | None = None,
): ):
params = { params = {
"handle": handle, "handle": handle,

View File

@@ -1,10 +1,10 @@
import os
import certifi
import json import json
import os
import time import time
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
import certifi
import click import click
import requests import requests
@@ -25,7 +25,7 @@ def select_choice(prompt_message, choices):
provider_models = get_provider_data() provider_models = get_provider_data()
if not provider_models: if not provider_models:
return return None
click.secho(prompt_message, fg="cyan") click.secho(prompt_message, fg="cyan")
for idx, choice in enumerate(choices, start=1): for idx, choice in enumerate(choices, start=1):
click.secho(f"{idx}. {choice}", fg="cyan") click.secho(f"{idx}. {choice}", fg="cyan")
@@ -67,7 +67,7 @@ def select_provider(provider_models):
all_providers = sorted(set(predefined_providers + list(provider_models.keys()))) all_providers = sorted(set(predefined_providers + list(provider_models.keys())))
provider = select_choice( provider = select_choice(
"Select a provider to set up:", predefined_providers + ["other"] "Select a provider to set up:", [*predefined_providers, "other"]
) )
if provider is None: # User typed 'q' if provider is None: # User typed 'q'
return None return None
@@ -102,10 +102,9 @@ def select_model(provider, provider_models):
click.secho(f"No models available for provider '{provider}'.", fg="red") click.secho(f"No models available for provider '{provider}'.", fg="red")
return None return None
selected_model = select_choice( return select_choice(
f"Select a model to use for {provider.capitalize()}:", available_models f"Select a model to use for {provider.capitalize()}:", available_models
) )
return selected_model
def load_provider_data(cache_file, cache_expiry): def load_provider_data(cache_file, cache_expiry):
@@ -165,7 +164,7 @@ def fetch_provider_data(cache_file):
Returns: Returns:
- dict or None: The fetched provider data or None if the operation fails. - dict or None: The fetched provider data or None if the operation fails.
""" """
ssl_config = os.environ['SSL_CERT_FILE'] = certifi.where() ssl_config = os.environ["SSL_CERT_FILE"] = certifi.where()
try: try:
response = requests.get(JSON_URL, stream=True, timeout=60, verify=ssl_config) response = requests.get(JSON_URL, stream=True, timeout=60, verify=ssl_config)

View File

@@ -1,6 +1,5 @@
import subprocess import subprocess
from enum import Enum from enum import Enum
from typing import List, Optional
import click import click
from packaging import version from packaging import version
@@ -57,7 +56,7 @@ def execute_command(crew_type: CrewType) -> None:
command = ["uv", "run", "kickoff" if crew_type == CrewType.FLOW else "run_crew"] command = ["uv", "run", "kickoff" if crew_type == CrewType.FLOW else "run_crew"]
try: try:
subprocess.run(command, capture_output=False, text=True, check=True) subprocess.run(command, capture_output=False, text=True, check=True) # noqa: S603
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
handle_error(e, crew_type) handle_error(e, crew_type)

View File

@@ -3,7 +3,7 @@ import os
import sys import sys
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Optional
from cryptography.fernet import Fernet from cryptography.fernet import Fernet
@@ -49,7 +49,7 @@ class TokenManager:
encrypted_data = self.fernet.encrypt(json.dumps(data).encode()) encrypted_data = self.fernet.encrypt(json.dumps(data).encode())
self.save_secure_file(self.file_path, encrypted_data) self.save_secure_file(self.file_path, encrypted_data)
def get_token(self) -> Optional[str]: def get_token(self) -> str | None:
""" """
Get the access token if it is valid and not expired. Get the access token if it is valid and not expired.
@@ -113,7 +113,7 @@ class TokenManager:
# Set appropriate permissions (read/write for owner only) # Set appropriate permissions (read/write for owner only)
os.chmod(file_path, 0o600) os.chmod(file_path, 0o600)
def read_secure_file(self, filename: str) -> Optional[bytes]: def read_secure_file(self, filename: str) -> bytes | None:
""" """
Read the content of a secure file. Read the content of a secure file.

View File

@@ -5,7 +5,7 @@ import sys
from functools import reduce from functools import reduce
from inspect import getmro, isclass, isfunction, ismethod from inspect import getmro, isclass, isfunction, ismethod
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, get_type_hints from typing import Any, get_type_hints
import click import click
import tomli import tomli
@@ -41,8 +41,7 @@ def copy_template(src, dst, name, class_name, folder_name):
def read_toml(file_path: str = "pyproject.toml"): def read_toml(file_path: str = "pyproject.toml"):
"""Read the content of a TOML file and return it as a dictionary.""" """Read the content of a TOML file and return it as a dictionary."""
with open(file_path, "rb") as f: with open(file_path, "rb") as f:
toml_dict = tomli.load(f) return tomli.load(f)
return toml_dict
def parse_toml(content): def parse_toml(content):
@@ -77,7 +76,7 @@ def get_project_description(
def _get_project_attribute( def _get_project_attribute(
pyproject_path: str, keys: List[str], require: bool pyproject_path: str, keys: list[str], require: bool
) -> Any | None: ) -> Any | None:
"""Get an attribute from the pyproject.toml file.""" """Get an attribute from the pyproject.toml file."""
attribute = None attribute = None
@@ -96,16 +95,20 @@ def _get_project_attribute(
except FileNotFoundError: except FileNotFoundError:
console.print(f"Error: {pyproject_path} not found.", style="bold red") console.print(f"Error: {pyproject_path} not found.", style="bold red")
except KeyError: except KeyError:
console.print(f"Error: {pyproject_path} is not a valid pyproject.toml file.", style="bold red")
except tomllib.TOMLDecodeError if sys.version_info >= (3, 11) else Exception as e: # type: ignore
console.print( console.print(
f"Error: {pyproject_path} is not a valid TOML file." f"Error: {pyproject_path} is not a valid pyproject.toml file.",
if sys.version_info >= (3, 11)
else f"Error reading the pyproject.toml file: {e}",
style="bold red", style="bold red",
) )
except Exception as e: except Exception as e:
console.print(f"Error reading the pyproject.toml file: {e}", style="bold red") # Handle TOML decode errors for Python 3.11+
if sys.version_info >= (3, 11) and isinstance(e, tomllib.TOMLDecodeError): # type: ignore
console.print(
f"Error: {pyproject_path} is not a valid TOML file.", style="bold red"
)
else:
console.print(
f"Error reading the pyproject.toml file: {e}", style="bold red"
)
if require and not attribute: if require and not attribute:
console.print( console.print(
@@ -117,7 +120,7 @@ def _get_project_attribute(
return attribute return attribute
def _get_nested_value(data: Dict[str, Any], keys: List[str]) -> Any: def _get_nested_value(data: dict[str, Any], keys: list[str]) -> Any:
return reduce(dict.__getitem__, keys, data) return reduce(dict.__getitem__, keys, data)
@@ -296,7 +299,10 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
try: try:
crew_instances.extend(fetch_crews(module_attr)) crew_instances.extend(fetch_crews(module_attr))
except Exception as e: except Exception as e:
console.print(f"Error processing attribute {attr_name}: {e}", style="bold red") console.print(
f"Error processing attribute {attr_name}: {e}",
style="bold red",
)
continue continue
# If we found crew instances, break out of the loop # If we found crew instances, break out of the loop
@@ -304,12 +310,15 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
break break
except Exception as exec_error: except Exception as exec_error:
console.print(f"Error executing module: {exec_error}", style="bold red") console.print(
f"Error executing module: {exec_error}",
style="bold red",
)
except (ImportError, AttributeError) as e: except (ImportError, AttributeError) as e:
if require: if require:
console.print( console.print(
f"Error importing crew from {crew_path}: {str(e)}", f"Error importing crew from {crew_path}: {e!s}",
style="bold red", style="bold red",
) )
continue continue
@@ -325,9 +334,9 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
except Exception as e: except Exception as e:
if require: if require:
console.print( console.print(
f"Unexpected error while loading crew: {str(e)}", style="bold red" f"Unexpected error while loading crew: {e!s}", style="bold red"
) )
raise SystemExit raise SystemExit from e
return crew_instances return crew_instances
@@ -348,8 +357,7 @@ def get_crew_instance(module_attr) -> Crew | None:
if isinstance(module_attr, Crew): if isinstance(module_attr, Crew):
return module_attr return module_attr
else: return None
return None
def fetch_crews(module_attr) -> list[Crew]: def fetch_crews(module_attr) -> list[Crew]:
@@ -402,11 +410,11 @@ def extract_available_exports(dir_path: str = "src"):
return available_exports return available_exports
except Exception as e: except Exception as e:
console.print(f"[red]Error: Could not extract tool classes: {str(e)}[/red]") console.print(f"[red]Error: Could not extract tool classes: {e!s}[/red]")
console.print( console.print(
"Please ensure your project contains valid tools (classes inheriting from BaseTool or functions with @tool decorator)." "Please ensure your project contains valid tools (classes inheriting from BaseTool or functions with @tool decorator)."
) )
raise SystemExit(1) raise SystemExit(1) from e
def _load_tools_from_init(init_file: Path) -> list[dict[str, Any]]: def _load_tools_from_init(init_file: Path) -> list[dict[str, Any]]:
@@ -440,8 +448,8 @@ def _load_tools_from_init(init_file: Path) -> list[dict[str, Any]]:
] ]
except Exception as e: except Exception as e:
console.print(f"[red]Warning: Could not load {init_file}: {str(e)}[/red]") console.print(f"[red]Warning: Could not load {init_file}: {e!s}[/red]")
raise SystemExit(1) raise SystemExit(1) from e
finally: finally:
sys.modules.pop("temp_module", None) sys.modules.pop("temp_module", None)

View File

@@ -1,17 +1,17 @@
import pytest import pytest
from crewai.cli.authentication.main import Oauth2Settings from crewai.cli.authentication.main import Oauth2Settings
from crewai.cli.authentication.providers.okta import OktaProvider from crewai.cli.authentication.providers.okta import OktaProvider
class TestOktaProvider: class TestOktaProvider:
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def setup_method(self): def setup_method(self):
self.valid_settings = Oauth2Settings( self.valid_settings = Oauth2Settings(
provider="okta", provider="okta",
domain="test-domain.okta.com", domain="test-domain.okta.com",
client_id="test-client-id", client_id="test-client-id",
audience="test-audience" audience="test-audience",
) )
self.provider = OktaProvider(self.valid_settings) self.provider = OktaProvider(self.valid_settings)
@@ -32,7 +32,7 @@ class TestOktaProvider:
provider="okta", provider="okta",
domain="my-company.okta.com", domain="my-company.okta.com",
client_id="test-client", client_id="test-client",
audience="test-audience" audience="test-audience",
) )
provider = OktaProvider(settings) provider = OktaProvider(settings)
expected_url = "https://my-company.okta.com/oauth2/default/v1/device/authorize" expected_url = "https://my-company.okta.com/oauth2/default/v1/device/authorize"
@@ -47,7 +47,7 @@ class TestOktaProvider:
provider="okta", provider="okta",
domain="another-domain.okta.com", domain="another-domain.okta.com",
client_id="test-client", client_id="test-client",
audience="test-audience" audience="test-audience",
) )
provider = OktaProvider(settings) provider = OktaProvider(settings)
expected_url = "https://another-domain.okta.com/oauth2/default/v1/token" expected_url = "https://another-domain.okta.com/oauth2/default/v1/token"
@@ -62,7 +62,7 @@ class TestOktaProvider:
provider="okta", provider="okta",
domain="dev.okta.com", domain="dev.okta.com",
client_id="test-client", client_id="test-client",
audience="test-audience" audience="test-audience",
) )
provider = OktaProvider(settings) provider = OktaProvider(settings)
expected_url = "https://dev.okta.com/oauth2/default/v1/keys" expected_url = "https://dev.okta.com/oauth2/default/v1/keys"
@@ -77,7 +77,7 @@ class TestOktaProvider:
provider="okta", provider="okta",
domain="prod.okta.com", domain="prod.okta.com",
client_id="test-client", client_id="test-client",
audience="test-audience" audience="test-audience",
) )
provider = OktaProvider(settings) provider = OktaProvider(settings)
expected_issuer = "https://prod.okta.com/oauth2/default" expected_issuer = "https://prod.okta.com/oauth2/default"
@@ -91,11 +91,11 @@ class TestOktaProvider:
provider="okta", provider="okta",
domain="test-domain.okta.com", domain="test-domain.okta.com",
client_id="test-client-id", client_id="test-client-id",
audience=None audience=None,
) )
provider = OktaProvider(settings) provider = OktaProvider(settings)
with pytest.raises(AssertionError): with pytest.raises(ValueError, match="Audience is required"):
provider.get_audience() provider.get_audience()
def test_get_client_id(self): def test_get_client_id(self):