mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 00:58:30 +00:00
chore: apply ruff linting fixes to CLI module
fix: apply ruff fixes to CLI and update Okta provider test
This commit is contained in:
0
src/crewai/cli/authentication/providers/__init__.py
Normal file
0
src/crewai/cli/authentication/providers/__init__.py
Normal file
@@ -1,5 +1,6 @@
|
|||||||
from crewai.cli.authentication.providers.base_provider import BaseProvider
|
from crewai.cli.authentication.providers.base_provider import BaseProvider
|
||||||
|
|
||||||
|
|
||||||
class Auth0Provider(BaseProvider):
|
class Auth0Provider(BaseProvider):
|
||||||
def get_authorize_url(self) -> str:
|
def get_authorize_url(self) -> str:
|
||||||
return f"https://{self._get_domain()}/oauth/device/code"
|
return f"https://{self._get_domain()}/oauth/device/code"
|
||||||
@@ -14,13 +15,20 @@ class Auth0Provider(BaseProvider):
|
|||||||
return f"https://{self._get_domain()}/"
|
return f"https://{self._get_domain()}/"
|
||||||
|
|
||||||
def get_audience(self) -> str:
|
def get_audience(self) -> str:
|
||||||
assert self.settings.audience is not None, "Audience is required"
|
if self.settings.audience is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Audience is required. Please set it in the configuration."
|
||||||
|
)
|
||||||
return self.settings.audience
|
return self.settings.audience
|
||||||
|
|
||||||
def get_client_id(self) -> str:
|
def get_client_id(self) -> str:
|
||||||
assert self.settings.client_id is not None, "Client ID is required"
|
if self.settings.client_id is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Client ID is required. Please set it in the configuration."
|
||||||
|
)
|
||||||
return self.settings.client_id
|
return self.settings.client_id
|
||||||
|
|
||||||
def _get_domain(self) -> str:
|
def _get_domain(self) -> str:
|
||||||
assert self.settings.domain is not None, "Domain is required"
|
if self.settings.domain is None:
|
||||||
|
raise ValueError("Domain is required. Please set it in the configuration.")
|
||||||
return self.settings.domain
|
return self.settings.domain
|
||||||
|
|||||||
@@ -1,30 +1,26 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
from crewai.cli.authentication.main import Oauth2Settings
|
from crewai.cli.authentication.main import Oauth2Settings
|
||||||
|
|
||||||
|
|
||||||
class BaseProvider(ABC):
|
class BaseProvider(ABC):
|
||||||
def __init__(self, settings: Oauth2Settings):
|
def __init__(self, settings: Oauth2Settings):
|
||||||
self.settings = settings
|
self.settings = settings
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_authorize_url(self) -> str:
|
def get_authorize_url(self) -> str: ...
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_token_url(self) -> str:
|
def get_token_url(self) -> str: ...
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_jwks_url(self) -> str:
|
def get_jwks_url(self) -> str: ...
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_issuer(self) -> str:
|
def get_issuer(self) -> str: ...
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_audience(self) -> str:
|
def get_audience(self) -> str: ...
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_client_id(self) -> str:
|
def get_client_id(self) -> str: ...
|
||||||
...
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from crewai.cli.authentication.providers.base_provider import BaseProvider
|
from crewai.cli.authentication.providers.base_provider import BaseProvider
|
||||||
|
|
||||||
|
|
||||||
class OktaProvider(BaseProvider):
|
class OktaProvider(BaseProvider):
|
||||||
def get_authorize_url(self) -> str:
|
def get_authorize_url(self) -> str:
|
||||||
return f"https://{self.settings.domain}/oauth2/default/v1/device/authorize"
|
return f"https://{self.settings.domain}/oauth2/default/v1/device/authorize"
|
||||||
@@ -14,9 +15,15 @@ class OktaProvider(BaseProvider):
|
|||||||
return f"https://{self.settings.domain}/oauth2/default"
|
return f"https://{self.settings.domain}/oauth2/default"
|
||||||
|
|
||||||
def get_audience(self) -> str:
|
def get_audience(self) -> str:
|
||||||
assert self.settings.audience is not None
|
if self.settings.audience is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Audience is required. Please set it in the configuration."
|
||||||
|
)
|
||||||
return self.settings.audience
|
return self.settings.audience
|
||||||
|
|
||||||
def get_client_id(self) -> str:
|
def get_client_id(self) -> str:
|
||||||
assert self.settings.client_id is not None
|
if self.settings.client_id is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Client ID is required. Please set it in the configuration."
|
||||||
|
)
|
||||||
return self.settings.client_id
|
return self.settings.client_id
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from crewai.cli.authentication.providers.base_provider import BaseProvider
|
from crewai.cli.authentication.providers.base_provider import BaseProvider
|
||||||
|
|
||||||
|
|
||||||
class WorkosProvider(BaseProvider):
|
class WorkosProvider(BaseProvider):
|
||||||
def get_authorize_url(self) -> str:
|
def get_authorize_url(self) -> str:
|
||||||
return f"https://{self._get_domain()}/oauth2/device_authorization"
|
return f"https://{self._get_domain()}/oauth2/device_authorization"
|
||||||
@@ -17,9 +18,13 @@ class WorkosProvider(BaseProvider):
|
|||||||
return self.settings.audience or ""
|
return self.settings.audience or ""
|
||||||
|
|
||||||
def get_client_id(self) -> str:
|
def get_client_id(self) -> str:
|
||||||
assert self.settings.client_id is not None, "Client ID is required"
|
if self.settings.client_id is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Client ID is required. Please set it in the configuration."
|
||||||
|
)
|
||||||
return self.settings.client_id
|
return self.settings.client_id
|
||||||
|
|
||||||
def _get_domain(self) -> str:
|
def _get_domain(self) -> str:
|
||||||
assert self.settings.domain is not None, "Domain is required"
|
if self.settings.domain is None:
|
||||||
|
raise ValueError("Domain is required. Please set it in the configuration.")
|
||||||
return self.settings.domain
|
return self.settings.domain
|
||||||
|
|||||||
@@ -17,8 +17,6 @@ def validate_jwt_token(
|
|||||||
missing required claims).
|
missing required claims).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
decoded_token = None
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
jwk_client = PyJWKClient(jwks_url)
|
jwk_client = PyJWKClient(jwks_url)
|
||||||
signing_key = jwk_client.get_signing_key_from_jwt(jwt_token)
|
signing_key = jwk_client.get_signing_key_from_jwt(jwt_token)
|
||||||
@@ -26,7 +24,7 @@ def validate_jwt_token(
|
|||||||
_unverified_decoded_token = jwt.decode(
|
_unverified_decoded_token = jwt.decode(
|
||||||
jwt_token, options={"verify_signature": False}
|
jwt_token, options={"verify_signature": False}
|
||||||
)
|
)
|
||||||
decoded_token = jwt.decode(
|
return jwt.decode(
|
||||||
jwt_token,
|
jwt_token,
|
||||||
signing_key.key,
|
signing_key.key,
|
||||||
algorithms=["RS256"],
|
algorithms=["RS256"],
|
||||||
@@ -40,23 +38,22 @@ def validate_jwt_token(
|
|||||||
"require": ["exp", "iat", "iss", "aud", "sub"],
|
"require": ["exp", "iat", "iss", "aud", "sub"],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
return decoded_token
|
|
||||||
|
|
||||||
except jwt.ExpiredSignatureError:
|
except jwt.ExpiredSignatureError as e:
|
||||||
raise Exception("Token has expired.")
|
raise Exception("Token has expired.") from e
|
||||||
except jwt.InvalidAudienceError:
|
except jwt.InvalidAudienceError as e:
|
||||||
actual_audience = _unverified_decoded_token.get("aud", "[no audience found]")
|
actual_audience = _unverified_decoded_token.get("aud", "[no audience found]")
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Invalid token audience. Got: '{actual_audience}'. Expected: '{audience}'"
|
f"Invalid token audience. Got: '{actual_audience}'. Expected: '{audience}'"
|
||||||
)
|
) from e
|
||||||
except jwt.InvalidIssuerError:
|
except jwt.InvalidIssuerError as e:
|
||||||
actual_issuer = _unverified_decoded_token.get("iss", "[no issuer found]")
|
actual_issuer = _unverified_decoded_token.get("iss", "[no issuer found]")
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Invalid token issuer. Got: '{actual_issuer}'. Expected: '{issuer}'"
|
f"Invalid token issuer. Got: '{actual_issuer}'. Expected: '{issuer}'"
|
||||||
)
|
) from e
|
||||||
except jwt.MissingRequiredClaimError as e:
|
except jwt.MissingRequiredClaimError as e:
|
||||||
raise Exception(f"Token is missing required claims: {str(e)}")
|
raise Exception(f"Token is missing required claims: {e!s}") from e
|
||||||
except jwt.exceptions.PyJWKClientError as e:
|
except jwt.exceptions.PyJWKClientError as e:
|
||||||
raise Exception(f"JWKS or key processing error: {str(e)}")
|
raise Exception(f"JWKS or key processing error: {e!s}") from e
|
||||||
except jwt.InvalidTokenError as e:
|
except jwt.InvalidTokenError as e:
|
||||||
raise Exception(f"Invalid token: {str(e)}")
|
raise Exception(f"Invalid token: {e!s}") from e
|
||||||
|
|||||||
@@ -1,13 +1,13 @@
|
|||||||
from importlib.metadata import version as get_version
|
from importlib.metadata import version as get_version
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import click
|
import click
|
||||||
from crewai.cli.config import Settings
|
|
||||||
from crewai.cli.settings.main import SettingsCommand
|
|
||||||
from crewai.cli.add_crew_to_flow import add_crew_to_flow
|
from crewai.cli.add_crew_to_flow import add_crew_to_flow
|
||||||
|
from crewai.cli.config import Settings
|
||||||
from crewai.cli.create_crew import create_crew
|
from crewai.cli.create_crew import create_crew
|
||||||
from crewai.cli.create_flow import create_flow
|
from crewai.cli.create_flow import create_flow
|
||||||
from crewai.cli.crew_chat import run_chat
|
from crewai.cli.crew_chat import run_chat
|
||||||
|
from crewai.cli.settings.main import SettingsCommand
|
||||||
from crewai.memory.storage.kickoff_task_outputs_storage import (
|
from crewai.memory.storage.kickoff_task_outputs_storage import (
|
||||||
KickoffTaskOutputsSQLiteStorage,
|
KickoffTaskOutputsSQLiteStorage,
|
||||||
)
|
)
|
||||||
@@ -237,13 +237,11 @@ def login():
|
|||||||
@crewai.group()
|
@crewai.group()
|
||||||
def deploy():
|
def deploy():
|
||||||
"""Deploy the Crew CLI group."""
|
"""Deploy the Crew CLI group."""
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@crewai.group()
|
@crewai.group()
|
||||||
def tool():
|
def tool():
|
||||||
"""Tool Repository related commands."""
|
"""Tool Repository related commands."""
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@deploy.command(name="create")
|
@deploy.command(name="create")
|
||||||
@@ -263,7 +261,7 @@ def deploy_list():
|
|||||||
|
|
||||||
@deploy.command(name="push")
|
@deploy.command(name="push")
|
||||||
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
||||||
def deploy_push(uuid: Optional[str]):
|
def deploy_push(uuid: str | None):
|
||||||
"""Deploy the Crew."""
|
"""Deploy the Crew."""
|
||||||
deploy_cmd = DeployCommand()
|
deploy_cmd = DeployCommand()
|
||||||
deploy_cmd.deploy(uuid=uuid)
|
deploy_cmd.deploy(uuid=uuid)
|
||||||
@@ -271,7 +269,7 @@ def deploy_push(uuid: Optional[str]):
|
|||||||
|
|
||||||
@deploy.command(name="status")
|
@deploy.command(name="status")
|
||||||
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
||||||
def deply_status(uuid: Optional[str]):
|
def deply_status(uuid: str | None):
|
||||||
"""Get the status of a deployment."""
|
"""Get the status of a deployment."""
|
||||||
deploy_cmd = DeployCommand()
|
deploy_cmd = DeployCommand()
|
||||||
deploy_cmd.get_crew_status(uuid=uuid)
|
deploy_cmd.get_crew_status(uuid=uuid)
|
||||||
@@ -279,7 +277,7 @@ def deply_status(uuid: Optional[str]):
|
|||||||
|
|
||||||
@deploy.command(name="logs")
|
@deploy.command(name="logs")
|
||||||
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
||||||
def deploy_logs(uuid: Optional[str]):
|
def deploy_logs(uuid: str | None):
|
||||||
"""Get the logs of a deployment."""
|
"""Get the logs of a deployment."""
|
||||||
deploy_cmd = DeployCommand()
|
deploy_cmd = DeployCommand()
|
||||||
deploy_cmd.get_crew_logs(uuid=uuid)
|
deploy_cmd.get_crew_logs(uuid=uuid)
|
||||||
@@ -287,7 +285,7 @@ def deploy_logs(uuid: Optional[str]):
|
|||||||
|
|
||||||
@deploy.command(name="remove")
|
@deploy.command(name="remove")
|
||||||
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
||||||
def deploy_remove(uuid: Optional[str]):
|
def deploy_remove(uuid: str | None):
|
||||||
"""Remove a deployment."""
|
"""Remove a deployment."""
|
||||||
deploy_cmd = DeployCommand()
|
deploy_cmd = DeployCommand()
|
||||||
deploy_cmd.remove_crew(uuid=uuid)
|
deploy_cmd.remove_crew(uuid=uuid)
|
||||||
@@ -327,7 +325,6 @@ def tool_publish(is_public: bool, force: bool):
|
|||||||
@crewai.group()
|
@crewai.group()
|
||||||
def flow():
|
def flow():
|
||||||
"""Flow related commands."""
|
"""Flow related commands."""
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@flow.command(name="kickoff")
|
@flow.command(name="kickoff")
|
||||||
@@ -359,7 +356,7 @@ def chat():
|
|||||||
and using the Chat LLM to generate responses.
|
and using the Chat LLM to generate responses.
|
||||||
"""
|
"""
|
||||||
click.secho(
|
click.secho(
|
||||||
"\nStarting a conversation with the Crew\n" "Type 'exit' or Ctrl+C to quit.\n",
|
"\nStarting a conversation with the Crew\nType 'exit' or Ctrl+C to quit.\n",
|
||||||
)
|
)
|
||||||
|
|
||||||
run_chat()
|
run_chat()
|
||||||
@@ -368,7 +365,6 @@ def chat():
|
|||||||
@crewai.group(invoke_without_command=True)
|
@crewai.group(invoke_without_command=True)
|
||||||
def org():
|
def org():
|
||||||
"""Organization management commands."""
|
"""Organization management commands."""
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@org.command("list")
|
@org.command("list")
|
||||||
@@ -396,7 +392,6 @@ def current():
|
|||||||
@crewai.group()
|
@crewai.group()
|
||||||
def enterprise():
|
def enterprise():
|
||||||
"""Enterprise Configuration commands."""
|
"""Enterprise Configuration commands."""
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@enterprise.command("configure")
|
@enterprise.command("configure")
|
||||||
@@ -410,7 +405,6 @@ def enterprise_configure(enterprise_url: str):
|
|||||||
@crewai.group()
|
@crewai.group()
|
||||||
def config():
|
def config():
|
||||||
"""CLI Configuration commands."""
|
"""CLI Configuration commands."""
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@config.command("list")
|
@config.command("list")
|
||||||
|
|||||||
@@ -1,15 +1,14 @@
|
|||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from crewai.cli.constants import (
|
from crewai.cli.constants import (
|
||||||
DEFAULT_CREWAI_ENTERPRISE_URL,
|
|
||||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER,
|
|
||||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
|
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
|
||||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
|
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
|
||||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
|
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
|
||||||
|
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER,
|
||||||
|
DEFAULT_CREWAI_ENTERPRISE_URL,
|
||||||
)
|
)
|
||||||
from crewai.cli.shared.token_manager import TokenManager
|
from crewai.cli.shared.token_manager import TokenManager
|
||||||
|
|
||||||
@@ -56,20 +55,20 @@ HIDDEN_SETTINGS_KEYS = [
|
|||||||
|
|
||||||
|
|
||||||
class Settings(BaseModel):
|
class Settings(BaseModel):
|
||||||
enterprise_base_url: Optional[str] = Field(
|
enterprise_base_url: str | None = Field(
|
||||||
default=DEFAULT_CLI_SETTINGS["enterprise_base_url"],
|
default=DEFAULT_CLI_SETTINGS["enterprise_base_url"],
|
||||||
description="Base URL of the CrewAI Enterprise instance",
|
description="Base URL of the CrewAI Enterprise instance",
|
||||||
)
|
)
|
||||||
tool_repository_username: Optional[str] = Field(
|
tool_repository_username: str | None = Field(
|
||||||
None, description="Username for interacting with the Tool Repository"
|
None, description="Username for interacting with the Tool Repository"
|
||||||
)
|
)
|
||||||
tool_repository_password: Optional[str] = Field(
|
tool_repository_password: str | None = Field(
|
||||||
None, description="Password for interacting with the Tool Repository"
|
None, description="Password for interacting with the Tool Repository"
|
||||||
)
|
)
|
||||||
org_name: Optional[str] = Field(
|
org_name: str | None = Field(
|
||||||
None, description="Name of the currently active organization"
|
None, description="Name of the currently active organization"
|
||||||
)
|
)
|
||||||
org_uuid: Optional[str] = Field(
|
org_uuid: str | None = Field(
|
||||||
None, description="UUID of the currently active organization"
|
None, description="UUID of the currently active organization"
|
||||||
)
|
)
|
||||||
config_path: Path = Field(default=DEFAULT_CONFIG_PATH, frozen=True, exclude=True)
|
config_path: Path = Field(default=DEFAULT_CONFIG_PATH, frozen=True, exclude=True)
|
||||||
@@ -79,7 +78,7 @@ class Settings(BaseModel):
|
|||||||
default=DEFAULT_CLI_SETTINGS["oauth2_provider"],
|
default=DEFAULT_CLI_SETTINGS["oauth2_provider"],
|
||||||
)
|
)
|
||||||
|
|
||||||
oauth2_audience: Optional[str] = Field(
|
oauth2_audience: str | None = Field(
|
||||||
description="OAuth2 audience value, typically used to identify the target API or resource.",
|
description="OAuth2 audience value, typically used to identify the target API or resource.",
|
||||||
default=DEFAULT_CLI_SETTINGS["oauth2_audience"],
|
default=DEFAULT_CLI_SETTINGS["oauth2_audience"],
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -16,48 +16,72 @@ from crewai.cli.utils import copy_template, load_env_vars, write_env_file
|
|||||||
def create_folder_structure(name, parent_folder=None):
|
def create_folder_structure(name, parent_folder=None):
|
||||||
import keyword
|
import keyword
|
||||||
import re
|
import re
|
||||||
|
|
||||||
name = name.rstrip('/')
|
name = name.rstrip("/")
|
||||||
|
|
||||||
if not name.strip():
|
if not name.strip():
|
||||||
raise ValueError("Project name cannot be empty or contain only whitespace")
|
raise ValueError("Project name cannot be empty or contain only whitespace")
|
||||||
|
|
||||||
folder_name = name.replace(" ", "_").replace("-", "_").lower()
|
folder_name = name.replace(" ", "_").replace("-", "_").lower()
|
||||||
folder_name = re.sub(r'[^a-zA-Z0-9_]', '', folder_name)
|
folder_name = re.sub(r"[^a-zA-Z0-9_]", "", folder_name)
|
||||||
|
|
||||||
# Check if the name starts with invalid characters or is primarily invalid
|
# Check if the name starts with invalid characters or is primarily invalid
|
||||||
if re.match(r'^[^a-zA-Z0-9_-]+', name):
|
if re.match(r"^[^a-zA-Z0-9_-]+", name):
|
||||||
raise ValueError(f"Project name '{name}' contains no valid characters for a Python module name")
|
raise ValueError(
|
||||||
|
f"Project name '{name}' contains no valid characters for a Python module name"
|
||||||
|
)
|
||||||
|
|
||||||
if not folder_name:
|
if not folder_name:
|
||||||
raise ValueError(f"Project name '{name}' contains no valid characters for a Python module name")
|
raise ValueError(
|
||||||
|
f"Project name '{name}' contains no valid characters for a Python module name"
|
||||||
|
)
|
||||||
|
|
||||||
if folder_name[0].isdigit():
|
if folder_name[0].isdigit():
|
||||||
raise ValueError(f"Project name '{name}' would generate folder name '{folder_name}' which cannot start with a digit (invalid Python module name)")
|
raise ValueError(
|
||||||
|
f"Project name '{name}' would generate folder name '{folder_name}' which cannot start with a digit (invalid Python module name)"
|
||||||
|
)
|
||||||
|
|
||||||
if keyword.iskeyword(folder_name):
|
if keyword.iskeyword(folder_name):
|
||||||
raise ValueError(f"Project name '{name}' would generate folder name '{folder_name}' which is a reserved Python keyword")
|
raise ValueError(
|
||||||
|
f"Project name '{name}' would generate folder name '{folder_name}' which is a reserved Python keyword"
|
||||||
|
)
|
||||||
|
|
||||||
if not folder_name.isidentifier():
|
if not folder_name.isidentifier():
|
||||||
raise ValueError(f"Project name '{name}' would generate invalid Python module name '{folder_name}'")
|
raise ValueError(
|
||||||
|
f"Project name '{name}' would generate invalid Python module name '{folder_name}'"
|
||||||
|
)
|
||||||
|
|
||||||
class_name = name.replace("_", " ").replace("-", " ").title().replace(" ", "")
|
class_name = name.replace("_", " ").replace("-", " ").title().replace(" ", "")
|
||||||
|
|
||||||
class_name = re.sub(r'[^a-zA-Z0-9_]', '', class_name)
|
class_name = re.sub(r"[^a-zA-Z0-9_]", "", class_name)
|
||||||
|
|
||||||
if not class_name:
|
if not class_name:
|
||||||
raise ValueError(f"Project name '{name}' contains no valid characters for a Python class name")
|
raise ValueError(
|
||||||
|
f"Project name '{name}' contains no valid characters for a Python class name"
|
||||||
|
)
|
||||||
|
|
||||||
if class_name[0].isdigit():
|
if class_name[0].isdigit():
|
||||||
raise ValueError(f"Project name '{name}' would generate class name '{class_name}' which cannot start with a digit")
|
raise ValueError(
|
||||||
|
f"Project name '{name}' would generate class name '{class_name}' which cannot start with a digit"
|
||||||
|
)
|
||||||
|
|
||||||
# Check if the original name (before title casing) is a keyword
|
# Check if the original name (before title casing) is a keyword
|
||||||
original_name_clean = re.sub(r'[^a-zA-Z0-9_]', '', name.replace("_", "").replace("-", "").lower())
|
original_name_clean = re.sub(
|
||||||
if keyword.iskeyword(original_name_clean) or keyword.iskeyword(class_name) or class_name in ('True', 'False', 'None'):
|
r"[^a-zA-Z0-9_]", "", name.replace("_", "").replace("-", "").lower()
|
||||||
raise ValueError(f"Project name '{name}' would generate class name '{class_name}' which is a reserved Python keyword")
|
)
|
||||||
|
if (
|
||||||
|
keyword.iskeyword(original_name_clean)
|
||||||
|
or keyword.iskeyword(class_name)
|
||||||
|
or class_name in ("True", "False", "None")
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"Project name '{name}' would generate class name '{class_name}' which is a reserved Python keyword"
|
||||||
|
)
|
||||||
|
|
||||||
if not class_name.isidentifier():
|
if not class_name.isidentifier():
|
||||||
raise ValueError(f"Project name '{name}' would generate invalid Python class name '{class_name}'")
|
raise ValueError(
|
||||||
|
f"Project name '{name}' would generate invalid Python class name '{class_name}'"
|
||||||
|
)
|
||||||
|
|
||||||
if parent_folder:
|
if parent_folder:
|
||||||
folder_path = Path(parent_folder) / folder_name
|
folder_path = Path(parent_folder) / folder_name
|
||||||
@@ -172,7 +196,7 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Check if the selected provider has predefined models
|
# Check if the selected provider has predefined models
|
||||||
if selected_provider in MODELS and MODELS[selected_provider]:
|
if MODELS.get(selected_provider):
|
||||||
while True:
|
while True:
|
||||||
selected_model = select_model(selected_provider, provider_models)
|
selected_model = select_model(selected_provider, provider_models)
|
||||||
if selected_model is None: # User typed 'q'
|
if selected_model is None: # User typed 'q'
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import sys
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
from typing import Any
|
||||||
|
|
||||||
import click
|
import click
|
||||||
import tomli
|
import tomli
|
||||||
@@ -116,7 +116,7 @@ def show_loading(event: threading.Event):
|
|||||||
print()
|
print()
|
||||||
|
|
||||||
|
|
||||||
def initialize_chat_llm(crew: Crew) -> Optional[LLM | BaseLLM]:
|
def initialize_chat_llm(crew: Crew) -> LLM | BaseLLM | None:
|
||||||
"""Initializes the chat LLM and handles exceptions."""
|
"""Initializes the chat LLM and handles exceptions."""
|
||||||
try:
|
try:
|
||||||
return create_llm(crew.chat_llm)
|
return create_llm(crew.chat_llm)
|
||||||
@@ -157,7 +157,7 @@ def build_system_message(crew_chat_inputs: ChatInputs) -> str:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_tool_function(crew: Crew, messages: List[Dict[str, str]]) -> Any:
|
def create_tool_function(crew: Crew, messages: list[dict[str, str]]) -> Any:
|
||||||
"""Creates a wrapper function for running the crew tool with messages."""
|
"""Creates a wrapper function for running the crew tool with messages."""
|
||||||
|
|
||||||
def run_crew_tool_with_messages(**kwargs):
|
def run_crew_tool_with_messages(**kwargs):
|
||||||
@@ -193,7 +193,7 @@ def chat_loop(chat_llm, messages, crew_tool_schema, available_functions):
|
|||||||
user_input, chat_llm, messages, crew_tool_schema, available_functions
|
user_input, chat_llm, messages, crew_tool_schema, available_functions
|
||||||
)
|
)
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt: # noqa: PERF203
|
||||||
click.echo("\nExiting chat. Goodbye!")
|
click.echo("\nExiting chat. Goodbye!")
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -221,9 +221,9 @@ def get_user_input() -> str:
|
|||||||
def handle_user_input(
|
def handle_user_input(
|
||||||
user_input: str,
|
user_input: str,
|
||||||
chat_llm: LLM,
|
chat_llm: LLM,
|
||||||
messages: List[Dict[str, str]],
|
messages: list[dict[str, str]],
|
||||||
crew_tool_schema: Dict[str, Any],
|
crew_tool_schema: dict[str, Any],
|
||||||
available_functions: Dict[str, Any],
|
available_functions: dict[str, Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
if user_input.strip().lower() == "exit":
|
if user_input.strip().lower() == "exit":
|
||||||
click.echo("Exiting chat. Goodbye!")
|
click.echo("Exiting chat. Goodbye!")
|
||||||
@@ -281,7 +281,7 @@ def generate_crew_tool_schema(crew_inputs: ChatInputs) -> dict:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs):
|
def run_crew_tool(crew: Crew, messages: list[dict[str, str]], **kwargs):
|
||||||
"""
|
"""
|
||||||
Runs the crew using crew.kickoff(inputs=kwargs) and returns the output.
|
Runs the crew using crew.kickoff(inputs=kwargs) and returns the output.
|
||||||
|
|
||||||
@@ -304,9 +304,8 @@ def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs):
|
|||||||
crew_output = crew.kickoff(inputs=kwargs)
|
crew_output = crew.kickoff(inputs=kwargs)
|
||||||
|
|
||||||
# Convert CrewOutput to a string to send back to the user
|
# Convert CrewOutput to a string to send back to the user
|
||||||
result = str(crew_output)
|
return str(crew_output)
|
||||||
|
|
||||||
return result
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Exit the chat and show the error message
|
# Exit the chat and show the error message
|
||||||
click.secho("An error occurred while running the crew:", fg="red")
|
click.secho("An error occurred while running the crew:", fg="red")
|
||||||
@@ -314,7 +313,7 @@ def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs):
|
|||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
def load_crew_and_name() -> Tuple[Crew, str]:
|
def load_crew_and_name() -> tuple[Crew, str]:
|
||||||
"""
|
"""
|
||||||
Loads the crew by importing the crew class from the user's project.
|
Loads the crew by importing the crew class from the user's project.
|
||||||
|
|
||||||
@@ -351,15 +350,17 @@ def load_crew_and_name() -> Tuple[Crew, str]:
|
|||||||
try:
|
try:
|
||||||
crew_module = __import__(crew_module_name, fromlist=[crew_class_name])
|
crew_module = __import__(crew_module_name, fromlist=[crew_class_name])
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
raise ImportError(f"Failed to import crew module {crew_module_name}: {e}")
|
raise ImportError(
|
||||||
|
f"Failed to import crew module {crew_module_name}: {e}"
|
||||||
|
) from e
|
||||||
|
|
||||||
# Get the crew class from the module
|
# Get the crew class from the module
|
||||||
try:
|
try:
|
||||||
crew_class = getattr(crew_module, crew_class_name)
|
crew_class = getattr(crew_module, crew_class_name)
|
||||||
except AttributeError:
|
except AttributeError as e:
|
||||||
raise AttributeError(
|
raise AttributeError(
|
||||||
f"Crew class {crew_class_name} not found in module {crew_module_name}"
|
f"Crew class {crew_class_name} not found in module {crew_module_name}"
|
||||||
)
|
) from e
|
||||||
|
|
||||||
# Instantiate the crew
|
# Instantiate the crew
|
||||||
crew_instance = crew_class().crew()
|
crew_instance = crew_class().crew()
|
||||||
@@ -395,7 +396,7 @@ def generate_crew_chat_inputs(crew: Crew, crew_name: str, chat_llm) -> ChatInput
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def fetch_required_inputs(crew: Crew) -> Set[str]:
|
def fetch_required_inputs(crew: Crew) -> set[str]:
|
||||||
"""
|
"""
|
||||||
Extracts placeholders from the crew's tasks and agents.
|
Extracts placeholders from the crew's tasks and agents.
|
||||||
|
|
||||||
@@ -405,8 +406,8 @@ def fetch_required_inputs(crew: Crew) -> Set[str]:
|
|||||||
Returns:
|
Returns:
|
||||||
Set[str]: A set of placeholder names.
|
Set[str]: A set of placeholder names.
|
||||||
"""
|
"""
|
||||||
placeholder_pattern = re.compile(r"\{(.+?)\}")
|
placeholder_pattern = re.compile(r"\{(.+?)}")
|
||||||
required_inputs: Set[str] = set()
|
required_inputs: set[str] = set()
|
||||||
|
|
||||||
# Scan tasks
|
# Scan tasks
|
||||||
for task in crew.tasks:
|
for task in crew.tasks:
|
||||||
@@ -435,7 +436,7 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) ->
|
|||||||
"""
|
"""
|
||||||
# Gather context from tasks and agents where the input is used
|
# Gather context from tasks and agents where the input is used
|
||||||
context_texts = []
|
context_texts = []
|
||||||
placeholder_pattern = re.compile(r"\{(.+?)\}")
|
placeholder_pattern = re.compile(r"\{(.+?)}")
|
||||||
|
|
||||||
for task in crew.tasks:
|
for task in crew.tasks:
|
||||||
if (
|
if (
|
||||||
@@ -479,9 +480,7 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) ->
|
|||||||
f"{context}"
|
f"{context}"
|
||||||
)
|
)
|
||||||
response = chat_llm.call(messages=[{"role": "user", "content": prompt}])
|
response = chat_llm.call(messages=[{"role": "user", "content": prompt}])
|
||||||
description = response.strip()
|
return response.strip()
|
||||||
|
|
||||||
return description
|
|
||||||
|
|
||||||
|
|
||||||
def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
|
def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
|
||||||
@@ -497,7 +496,7 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
|
|||||||
"""
|
"""
|
||||||
# Gather context from tasks and agents
|
# Gather context from tasks and agents
|
||||||
context_texts = []
|
context_texts = []
|
||||||
placeholder_pattern = re.compile(r"\{(.+?)\}")
|
placeholder_pattern = re.compile(r"\{(.+?)}")
|
||||||
|
|
||||||
for task in crew.tasks:
|
for task in crew.tasks:
|
||||||
# Replace placeholders with input names
|
# Replace placeholders with input names
|
||||||
@@ -531,6 +530,4 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
|
|||||||
f"{context}"
|
f"{context}"
|
||||||
)
|
)
|
||||||
response = chat_llm.call(messages=[{"role": "user", "content": prompt}])
|
response = chat_llm.call(messages=[{"role": "user", "content": prompt}])
|
||||||
crew_description = response.strip()
|
return response.strip()
|
||||||
|
|
||||||
return crew_description
|
|
||||||
|
|||||||
@@ -14,11 +14,15 @@ class Repository:
|
|||||||
|
|
||||||
self.fetch()
|
self.fetch()
|
||||||
|
|
||||||
def is_git_installed(self) -> bool:
|
@staticmethod
|
||||||
|
def is_git_installed() -> bool:
|
||||||
"""Check if Git is installed and available in the system."""
|
"""Check if Git is installed and available in the system."""
|
||||||
try:
|
try:
|
||||||
subprocess.run(
|
subprocess.run(
|
||||||
["git", "--version"], capture_output=True, check=True, text=True
|
["git", "--version"], # noqa: S607
|
||||||
|
capture_output=True,
|
||||||
|
check=True,
|
||||||
|
text=True,
|
||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
except (subprocess.CalledProcessError, FileNotFoundError):
|
except (subprocess.CalledProcessError, FileNotFoundError):
|
||||||
@@ -26,22 +30,26 @@ class Repository:
|
|||||||
|
|
||||||
def fetch(self) -> None:
|
def fetch(self) -> None:
|
||||||
"""Fetch latest updates from the remote."""
|
"""Fetch latest updates from the remote."""
|
||||||
subprocess.run(["git", "fetch"], cwd=self.path, check=True)
|
subprocess.run(["git", "fetch"], cwd=self.path, check=True) # noqa: S607
|
||||||
|
|
||||||
def status(self) -> str:
|
def status(self) -> str:
|
||||||
"""Get the git status in porcelain format."""
|
"""Get the git status in porcelain format."""
|
||||||
return subprocess.check_output(
|
return subprocess.check_output(
|
||||||
["git", "status", "--branch", "--porcelain"],
|
["git", "status", "--branch", "--porcelain"], # noqa: S607
|
||||||
cwd=self.path,
|
cwd=self.path,
|
||||||
encoding="utf-8",
|
encoding="utf-8",
|
||||||
).strip()
|
).strip()
|
||||||
|
|
||||||
@lru_cache(maxsize=None)
|
@lru_cache(maxsize=None) # noqa: B019
|
||||||
def is_git_repo(self) -> bool:
|
def is_git_repo(self) -> bool:
|
||||||
"""Check if the current directory is a git repository."""
|
"""Check if the current directory is a git repository.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- TODO: This method is cached to avoid redundant checks, but using lru_cache on methods can lead to memory leaks
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
subprocess.check_output(
|
subprocess.check_output(
|
||||||
["git", "rev-parse", "--is-inside-work-tree"],
|
["git", "rev-parse", "--is-inside-work-tree"], # noqa: S607
|
||||||
cwd=self.path,
|
cwd=self.path,
|
||||||
encoding="utf-8",
|
encoding="utf-8",
|
||||||
)
|
)
|
||||||
@@ -64,14 +72,13 @@ class Repository:
|
|||||||
"""Return True if the Git repository is fully synced with the remote, False otherwise."""
|
"""Return True if the Git repository is fully synced with the remote, False otherwise."""
|
||||||
if self.has_uncommitted_changes() or self.is_ahead_or_behind():
|
if self.has_uncommitted_changes() or self.is_ahead_or_behind():
|
||||||
return False
|
return False
|
||||||
else:
|
return True
|
||||||
return True
|
|
||||||
|
|
||||||
def origin_url(self) -> str | None:
|
def origin_url(self) -> str | None:
|
||||||
"""Get the Git repository's remote URL."""
|
"""Get the Git repository's remote URL."""
|
||||||
try:
|
try:
|
||||||
result = subprocess.run(
|
result = subprocess.run(
|
||||||
["git", "remote", "get-url", "origin"],
|
["git", "remote", "get-url", "origin"], # noqa: S607
|
||||||
cwd=self.path,
|
cwd=self.path,
|
||||||
capture_output=True,
|
capture_output=True,
|
||||||
text=True,
|
text=True,
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ def install_crew(proxy_options: list[str]) -> None:
|
|||||||
Install the crew by running the UV command to lock and install.
|
Install the crew by running the UV command to lock and install.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
command = ["uv", "sync"] + proxy_options
|
command = ["uv", "sync", *proxy_options]
|
||||||
subprocess.run(command, check=True, capture_output=False, text=True)
|
subprocess.run(command, check=True, capture_output=False, text=True) # noqa: S603
|
||||||
|
|
||||||
except subprocess.CalledProcessError as e:
|
except subprocess.CalledProcessError as e:
|
||||||
click.echo(f"An error occurred while running the crew: {e}", err=True)
|
click.echo(f"An error occurred while running the crew: {e}", err=True)
|
||||||
|
|||||||
@@ -1,11 +1,10 @@
|
|||||||
from typing import List, Optional
|
|
||||||
from urllib.parse import urljoin
|
from urllib.parse import urljoin
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from crewai.cli.config import Settings
|
from crewai.cli.config import Settings
|
||||||
from crewai.cli.version import get_crewai_version
|
|
||||||
from crewai.cli.constants import DEFAULT_CREWAI_ENTERPRISE_URL
|
from crewai.cli.constants import DEFAULT_CREWAI_ENTERPRISE_URL
|
||||||
|
from crewai.cli.version import get_crewai_version
|
||||||
|
|
||||||
|
|
||||||
class PlusAPI:
|
class PlusAPI:
|
||||||
@@ -56,9 +55,9 @@ class PlusAPI:
|
|||||||
handle: str,
|
handle: str,
|
||||||
is_public: bool,
|
is_public: bool,
|
||||||
version: str,
|
version: str,
|
||||||
description: Optional[str],
|
description: str | None,
|
||||||
encoded_file: str,
|
encoded_file: str,
|
||||||
available_exports: Optional[List[str]] = None,
|
available_exports: list[str] | None = None,
|
||||||
):
|
):
|
||||||
params = {
|
params = {
|
||||||
"handle": handle,
|
"handle": handle,
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
import os
|
|
||||||
import certifi
|
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import certifi
|
||||||
import click
|
import click
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
@@ -25,7 +25,7 @@ def select_choice(prompt_message, choices):
|
|||||||
|
|
||||||
provider_models = get_provider_data()
|
provider_models = get_provider_data()
|
||||||
if not provider_models:
|
if not provider_models:
|
||||||
return
|
return None
|
||||||
click.secho(prompt_message, fg="cyan")
|
click.secho(prompt_message, fg="cyan")
|
||||||
for idx, choice in enumerate(choices, start=1):
|
for idx, choice in enumerate(choices, start=1):
|
||||||
click.secho(f"{idx}. {choice}", fg="cyan")
|
click.secho(f"{idx}. {choice}", fg="cyan")
|
||||||
@@ -67,7 +67,7 @@ def select_provider(provider_models):
|
|||||||
all_providers = sorted(set(predefined_providers + list(provider_models.keys())))
|
all_providers = sorted(set(predefined_providers + list(provider_models.keys())))
|
||||||
|
|
||||||
provider = select_choice(
|
provider = select_choice(
|
||||||
"Select a provider to set up:", predefined_providers + ["other"]
|
"Select a provider to set up:", [*predefined_providers, "other"]
|
||||||
)
|
)
|
||||||
if provider is None: # User typed 'q'
|
if provider is None: # User typed 'q'
|
||||||
return None
|
return None
|
||||||
@@ -102,10 +102,9 @@ def select_model(provider, provider_models):
|
|||||||
click.secho(f"No models available for provider '{provider}'.", fg="red")
|
click.secho(f"No models available for provider '{provider}'.", fg="red")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
selected_model = select_choice(
|
return select_choice(
|
||||||
f"Select a model to use for {provider.capitalize()}:", available_models
|
f"Select a model to use for {provider.capitalize()}:", available_models
|
||||||
)
|
)
|
||||||
return selected_model
|
|
||||||
|
|
||||||
|
|
||||||
def load_provider_data(cache_file, cache_expiry):
|
def load_provider_data(cache_file, cache_expiry):
|
||||||
@@ -165,7 +164,7 @@ def fetch_provider_data(cache_file):
|
|||||||
Returns:
|
Returns:
|
||||||
- dict or None: The fetched provider data or None if the operation fails.
|
- dict or None: The fetched provider data or None if the operation fails.
|
||||||
"""
|
"""
|
||||||
ssl_config = os.environ['SSL_CERT_FILE'] = certifi.where()
|
ssl_config = os.environ["SSL_CERT_FILE"] = certifi.where()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = requests.get(JSON_URL, stream=True, timeout=60, verify=ssl_config)
|
response = requests.get(JSON_URL, stream=True, timeout=60, verify=ssl_config)
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
import subprocess
|
import subprocess
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
import click
|
import click
|
||||||
from packaging import version
|
from packaging import version
|
||||||
@@ -57,7 +56,7 @@ def execute_command(crew_type: CrewType) -> None:
|
|||||||
command = ["uv", "run", "kickoff" if crew_type == CrewType.FLOW else "run_crew"]
|
command = ["uv", "run", "kickoff" if crew_type == CrewType.FLOW else "run_crew"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
subprocess.run(command, capture_output=False, text=True, check=True)
|
subprocess.run(command, capture_output=False, text=True, check=True) # noqa: S603
|
||||||
|
|
||||||
except subprocess.CalledProcessError as e:
|
except subprocess.CalledProcessError as e:
|
||||||
handle_error(e, crew_type)
|
handle_error(e, crew_type)
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
|
||||||
from cryptography.fernet import Fernet
|
from cryptography.fernet import Fernet
|
||||||
|
|
||||||
|
|
||||||
@@ -49,7 +49,7 @@ class TokenManager:
|
|||||||
encrypted_data = self.fernet.encrypt(json.dumps(data).encode())
|
encrypted_data = self.fernet.encrypt(json.dumps(data).encode())
|
||||||
self.save_secure_file(self.file_path, encrypted_data)
|
self.save_secure_file(self.file_path, encrypted_data)
|
||||||
|
|
||||||
def get_token(self) -> Optional[str]:
|
def get_token(self) -> str | None:
|
||||||
"""
|
"""
|
||||||
Get the access token if it is valid and not expired.
|
Get the access token if it is valid and not expired.
|
||||||
|
|
||||||
@@ -113,7 +113,7 @@ class TokenManager:
|
|||||||
# Set appropriate permissions (read/write for owner only)
|
# Set appropriate permissions (read/write for owner only)
|
||||||
os.chmod(file_path, 0o600)
|
os.chmod(file_path, 0o600)
|
||||||
|
|
||||||
def read_secure_file(self, filename: str) -> Optional[bytes]:
|
def read_secure_file(self, filename: str) -> bytes | None:
|
||||||
"""
|
"""
|
||||||
Read the content of a secure file.
|
Read the content of a secure file.
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import sys
|
|||||||
from functools import reduce
|
from functools import reduce
|
||||||
from inspect import getmro, isclass, isfunction, ismethod
|
from inspect import getmro, isclass, isfunction, ismethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, get_type_hints
|
from typing import Any, get_type_hints
|
||||||
|
|
||||||
import click
|
import click
|
||||||
import tomli
|
import tomli
|
||||||
@@ -41,8 +41,7 @@ def copy_template(src, dst, name, class_name, folder_name):
|
|||||||
def read_toml(file_path: str = "pyproject.toml"):
|
def read_toml(file_path: str = "pyproject.toml"):
|
||||||
"""Read the content of a TOML file and return it as a dictionary."""
|
"""Read the content of a TOML file and return it as a dictionary."""
|
||||||
with open(file_path, "rb") as f:
|
with open(file_path, "rb") as f:
|
||||||
toml_dict = tomli.load(f)
|
return tomli.load(f)
|
||||||
return toml_dict
|
|
||||||
|
|
||||||
|
|
||||||
def parse_toml(content):
|
def parse_toml(content):
|
||||||
@@ -77,7 +76,7 @@ def get_project_description(
|
|||||||
|
|
||||||
|
|
||||||
def _get_project_attribute(
|
def _get_project_attribute(
|
||||||
pyproject_path: str, keys: List[str], require: bool
|
pyproject_path: str, keys: list[str], require: bool
|
||||||
) -> Any | None:
|
) -> Any | None:
|
||||||
"""Get an attribute from the pyproject.toml file."""
|
"""Get an attribute from the pyproject.toml file."""
|
||||||
attribute = None
|
attribute = None
|
||||||
@@ -96,16 +95,20 @@ def _get_project_attribute(
|
|||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
console.print(f"Error: {pyproject_path} not found.", style="bold red")
|
console.print(f"Error: {pyproject_path} not found.", style="bold red")
|
||||||
except KeyError:
|
except KeyError:
|
||||||
console.print(f"Error: {pyproject_path} is not a valid pyproject.toml file.", style="bold red")
|
|
||||||
except tomllib.TOMLDecodeError if sys.version_info >= (3, 11) else Exception as e: # type: ignore
|
|
||||||
console.print(
|
console.print(
|
||||||
f"Error: {pyproject_path} is not a valid TOML file."
|
f"Error: {pyproject_path} is not a valid pyproject.toml file.",
|
||||||
if sys.version_info >= (3, 11)
|
|
||||||
else f"Error reading the pyproject.toml file: {e}",
|
|
||||||
style="bold red",
|
style="bold red",
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
console.print(f"Error reading the pyproject.toml file: {e}", style="bold red")
|
# Handle TOML decode errors for Python 3.11+
|
||||||
|
if sys.version_info >= (3, 11) and isinstance(e, tomllib.TOMLDecodeError): # type: ignore
|
||||||
|
console.print(
|
||||||
|
f"Error: {pyproject_path} is not a valid TOML file.", style="bold red"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
console.print(
|
||||||
|
f"Error reading the pyproject.toml file: {e}", style="bold red"
|
||||||
|
)
|
||||||
|
|
||||||
if require and not attribute:
|
if require and not attribute:
|
||||||
console.print(
|
console.print(
|
||||||
@@ -117,7 +120,7 @@ def _get_project_attribute(
|
|||||||
return attribute
|
return attribute
|
||||||
|
|
||||||
|
|
||||||
def _get_nested_value(data: Dict[str, Any], keys: List[str]) -> Any:
|
def _get_nested_value(data: dict[str, Any], keys: list[str]) -> Any:
|
||||||
return reduce(dict.__getitem__, keys, data)
|
return reduce(dict.__getitem__, keys, data)
|
||||||
|
|
||||||
|
|
||||||
@@ -296,7 +299,10 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
|
|||||||
try:
|
try:
|
||||||
crew_instances.extend(fetch_crews(module_attr))
|
crew_instances.extend(fetch_crews(module_attr))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
console.print(f"Error processing attribute {attr_name}: {e}", style="bold red")
|
console.print(
|
||||||
|
f"Error processing attribute {attr_name}: {e}",
|
||||||
|
style="bold red",
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# If we found crew instances, break out of the loop
|
# If we found crew instances, break out of the loop
|
||||||
@@ -304,12 +310,15 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
|
|||||||
break
|
break
|
||||||
|
|
||||||
except Exception as exec_error:
|
except Exception as exec_error:
|
||||||
console.print(f"Error executing module: {exec_error}", style="bold red")
|
console.print(
|
||||||
|
f"Error executing module: {exec_error}",
|
||||||
|
style="bold red",
|
||||||
|
)
|
||||||
|
|
||||||
except (ImportError, AttributeError) as e:
|
except (ImportError, AttributeError) as e:
|
||||||
if require:
|
if require:
|
||||||
console.print(
|
console.print(
|
||||||
f"Error importing crew from {crew_path}: {str(e)}",
|
f"Error importing crew from {crew_path}: {e!s}",
|
||||||
style="bold red",
|
style="bold red",
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
@@ -325,9 +334,9 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
if require:
|
if require:
|
||||||
console.print(
|
console.print(
|
||||||
f"Unexpected error while loading crew: {str(e)}", style="bold red"
|
f"Unexpected error while loading crew: {e!s}", style="bold red"
|
||||||
)
|
)
|
||||||
raise SystemExit
|
raise SystemExit from e
|
||||||
return crew_instances
|
return crew_instances
|
||||||
|
|
||||||
|
|
||||||
@@ -348,8 +357,7 @@ def get_crew_instance(module_attr) -> Crew | None:
|
|||||||
|
|
||||||
if isinstance(module_attr, Crew):
|
if isinstance(module_attr, Crew):
|
||||||
return module_attr
|
return module_attr
|
||||||
else:
|
return None
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def fetch_crews(module_attr) -> list[Crew]:
|
def fetch_crews(module_attr) -> list[Crew]:
|
||||||
@@ -402,11 +410,11 @@ def extract_available_exports(dir_path: str = "src"):
|
|||||||
return available_exports
|
return available_exports
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
console.print(f"[red]Error: Could not extract tool classes: {str(e)}[/red]")
|
console.print(f"[red]Error: Could not extract tool classes: {e!s}[/red]")
|
||||||
console.print(
|
console.print(
|
||||||
"Please ensure your project contains valid tools (classes inheriting from BaseTool or functions with @tool decorator)."
|
"Please ensure your project contains valid tools (classes inheriting from BaseTool or functions with @tool decorator)."
|
||||||
)
|
)
|
||||||
raise SystemExit(1)
|
raise SystemExit(1) from e
|
||||||
|
|
||||||
|
|
||||||
def _load_tools_from_init(init_file: Path) -> list[dict[str, Any]]:
|
def _load_tools_from_init(init_file: Path) -> list[dict[str, Any]]:
|
||||||
@@ -440,8 +448,8 @@ def _load_tools_from_init(init_file: Path) -> list[dict[str, Any]]:
|
|||||||
]
|
]
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
console.print(f"[red]Warning: Could not load {init_file}: {str(e)}[/red]")
|
console.print(f"[red]Warning: Could not load {init_file}: {e!s}[/red]")
|
||||||
raise SystemExit(1)
|
raise SystemExit(1) from e
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
sys.modules.pop("temp_module", None)
|
sys.modules.pop("temp_module", None)
|
||||||
|
|||||||
@@ -1,17 +1,17 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from crewai.cli.authentication.main import Oauth2Settings
|
from crewai.cli.authentication.main import Oauth2Settings
|
||||||
from crewai.cli.authentication.providers.okta import OktaProvider
|
from crewai.cli.authentication.providers.okta import OktaProvider
|
||||||
|
|
||||||
|
|
||||||
class TestOktaProvider:
|
class TestOktaProvider:
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def setup_method(self):
|
def setup_method(self):
|
||||||
self.valid_settings = Oauth2Settings(
|
self.valid_settings = Oauth2Settings(
|
||||||
provider="okta",
|
provider="okta",
|
||||||
domain="test-domain.okta.com",
|
domain="test-domain.okta.com",
|
||||||
client_id="test-client-id",
|
client_id="test-client-id",
|
||||||
audience="test-audience"
|
audience="test-audience",
|
||||||
)
|
)
|
||||||
self.provider = OktaProvider(self.valid_settings)
|
self.provider = OktaProvider(self.valid_settings)
|
||||||
|
|
||||||
@@ -32,7 +32,7 @@ class TestOktaProvider:
|
|||||||
provider="okta",
|
provider="okta",
|
||||||
domain="my-company.okta.com",
|
domain="my-company.okta.com",
|
||||||
client_id="test-client",
|
client_id="test-client",
|
||||||
audience="test-audience"
|
audience="test-audience",
|
||||||
)
|
)
|
||||||
provider = OktaProvider(settings)
|
provider = OktaProvider(settings)
|
||||||
expected_url = "https://my-company.okta.com/oauth2/default/v1/device/authorize"
|
expected_url = "https://my-company.okta.com/oauth2/default/v1/device/authorize"
|
||||||
@@ -47,7 +47,7 @@ class TestOktaProvider:
|
|||||||
provider="okta",
|
provider="okta",
|
||||||
domain="another-domain.okta.com",
|
domain="another-domain.okta.com",
|
||||||
client_id="test-client",
|
client_id="test-client",
|
||||||
audience="test-audience"
|
audience="test-audience",
|
||||||
)
|
)
|
||||||
provider = OktaProvider(settings)
|
provider = OktaProvider(settings)
|
||||||
expected_url = "https://another-domain.okta.com/oauth2/default/v1/token"
|
expected_url = "https://another-domain.okta.com/oauth2/default/v1/token"
|
||||||
@@ -62,7 +62,7 @@ class TestOktaProvider:
|
|||||||
provider="okta",
|
provider="okta",
|
||||||
domain="dev.okta.com",
|
domain="dev.okta.com",
|
||||||
client_id="test-client",
|
client_id="test-client",
|
||||||
audience="test-audience"
|
audience="test-audience",
|
||||||
)
|
)
|
||||||
provider = OktaProvider(settings)
|
provider = OktaProvider(settings)
|
||||||
expected_url = "https://dev.okta.com/oauth2/default/v1/keys"
|
expected_url = "https://dev.okta.com/oauth2/default/v1/keys"
|
||||||
@@ -77,7 +77,7 @@ class TestOktaProvider:
|
|||||||
provider="okta",
|
provider="okta",
|
||||||
domain="prod.okta.com",
|
domain="prod.okta.com",
|
||||||
client_id="test-client",
|
client_id="test-client",
|
||||||
audience="test-audience"
|
audience="test-audience",
|
||||||
)
|
)
|
||||||
provider = OktaProvider(settings)
|
provider = OktaProvider(settings)
|
||||||
expected_issuer = "https://prod.okta.com/oauth2/default"
|
expected_issuer = "https://prod.okta.com/oauth2/default"
|
||||||
@@ -91,11 +91,11 @@ class TestOktaProvider:
|
|||||||
provider="okta",
|
provider="okta",
|
||||||
domain="test-domain.okta.com",
|
domain="test-domain.okta.com",
|
||||||
client_id="test-client-id",
|
client_id="test-client-id",
|
||||||
audience=None
|
audience=None,
|
||||||
)
|
)
|
||||||
provider = OktaProvider(settings)
|
provider = OktaProvider(settings)
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(ValueError, match="Audience is required"):
|
||||||
provider.get_audience()
|
provider.get_audience()
|
||||||
|
|
||||||
def test_get_client_id(self):
|
def test_get_client_id(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user