From 263544524d489b8ecdda4ad0c934cf77ab2b059d Mon Sep 17 00:00:00 2001 From: Rip&Tear <84775494+theCyberTech@users.noreply.github.com> Date: Wed, 23 Oct 2024 17:20:05 +0800 Subject: [PATCH] ruff updates --- src/crewai/cli/provider.py | 109 +++++++++++++++---------------------- 1 file changed, 44 insertions(+), 65 deletions(-) diff --git a/src/crewai/cli/provider.py b/src/crewai/cli/provider.py index 8d68398ab..6fa3ecb0b 100644 --- a/src/crewai/cli/provider.py +++ b/src/crewai/cli/provider.py @@ -3,54 +3,53 @@ import time import requests from collections import defaultdict import click -from pathlib import Path +from pathlib import Path from crewai.cli.constants import PROVIDERS, MODELS, JSON_URL + def select_choice(prompt_message, choices): """ Presents a list of choices to the user and prompts them to select one. - + Args: - prompt_message (str): The message to display to the user before presenting the choices. - choices (list): A list of options to present to the user. - + Returns: - str: The selected choice from the list, or None if the user chooses to quit. """ + + provider_models = get_provider_data() + if not provider_models: + return click.secho(prompt_message, fg="cyan") for idx, choice in enumerate(choices, start=1): click.secho(f"{idx}. {choice}", fg="cyan") click.secho("q. Quit", fg="cyan") - + while True: - choice = click.prompt( - "Enter the number of your choice or 'q' to quit", type=str - ) - - if choice.lower() == "q": + choice = click.prompt("Enter the number of your choice or 'q' to quit", type=str) + + if choice.lower() == 'q': return None - + try: selected_index = int(choice) - 1 if 0 <= selected_index < len(choices): return choices[selected_index] except ValueError: pass - - click.secho( - "Invalid selection. Please select a number between 1 and 6 or 'q' to quit.", - fg="red", - ) - + + click.secho("Invalid selection. Please select a number between 1 and 6 or 'q' to quit.", fg="red") def select_provider(provider_models): """ Presents a list of providers to the user and prompts them to select one. - + Args: - provider_models (dict): A dictionary of provider models. - + Returns: - str: The selected provider - None: If user explicitly quits @@ -58,28 +57,25 @@ def select_provider(provider_models): predefined_providers = [p.lower() for p in PROVIDERS] all_providers = sorted(set(predefined_providers + list(provider_models.keys()))) - provider = select_choice( - "Select a provider to set up:", predefined_providers + ["other"] - ) + provider = select_choice("Select a provider to set up:", predefined_providers + ['other']) if provider is None: # User typed 'q' return None - - if provider == "other": + + if provider == 'other': provider = select_choice("Select a provider from the full list:", all_providers) if provider is None: # User typed 'q' return None - + return provider.lower() if provider else False - def select_model(provider, provider_models): """ Presents a list of models for a given provider to the user and prompts them to select one. - + Args: - provider (str): The provider for which to select a model. - provider_models (dict): A dictionary of provider models. - + Returns: - str: The selected model, or None if the operation is aborted or an invalid selection is made. """ @@ -94,49 +90,37 @@ def select_model(provider, provider_models): click.secho(f"No models available for provider '{provider}'.", fg="red") return None - selected_model = select_choice( - f"Select a model to use for {provider.capitalize()}:", available_models - ) + selected_model = select_choice(f"Select a model to use for {provider.capitalize()}:", available_models) return selected_model - def load_provider_data(cache_file, cache_expiry): """ Loads provider data from a cache file if it exists and is not expired. If the cache is expired or corrupted, it fetches the data from the web. - + Args: - cache_file (Path): The path to the cache file. - cache_expiry (int): The cache expiry time in seconds. - + Returns: - dict or None: The loaded provider data or None if the operation fails. """ current_time = time.time() - if ( - cache_file.exists() - and (current_time - cache_file.stat().st_mtime) < cache_expiry - ): + if cache_file.exists() and (current_time - cache_file.stat().st_mtime) < cache_expiry: data = read_cache_file(cache_file) if data: return data - click.secho( - "Cache is corrupted. Fetching provider data from the web...", fg="yellow" - ) + click.secho("Cache is corrupted. Fetching provider data from the web...", fg="yellow") else: - click.secho( - "Cache expired or not found. Fetching provider data from the web...", - fg="cyan", - ) + click.secho("Cache expired or not found. Fetching provider data from the web...", fg="cyan") return fetch_provider_data(cache_file) - def read_cache_file(cache_file): """ Reads and returns the JSON content from a cache file. Returns None if the file contains invalid JSON. - + Args: - cache_file (Path): The path to the cache file. - + Returns: - dict or None: The JSON content of the cache file or None if the JSON is invalid. """ @@ -146,14 +130,13 @@ def read_cache_file(cache_file): except json.JSONDecodeError: return None - def fetch_provider_data(cache_file): """ Fetches provider data from a specified URL and caches it to a file. - + Args: - cache_file (Path): The path to the cache file. - + Returns: - dict or None: The fetched provider data or None if the operation fails. """ @@ -170,42 +153,38 @@ def fetch_provider_data(cache_file): click.secho("Error parsing provider data. Invalid JSON format.", fg="red") return None - def download_data(response): """ Downloads data from a given HTTP response and returns the JSON content. - + Args: - response (requests.Response): The HTTP response object. - + Returns: - dict: The JSON content of the response. """ - total_size = int(response.headers.get("content-length", 0)) + total_size = int(response.headers.get('content-length', 0)) block_size = 8192 data_chunks = [] - with click.progressbar( - length=total_size, label="Downloading", show_pos=True - ) as progress_bar: + with click.progressbar(length=total_size, label='Downloading', show_pos=True) as progress_bar: for chunk in response.iter_content(block_size): if chunk: data_chunks.append(chunk) progress_bar.update(len(chunk)) - data_content = b"".join(data_chunks) - return json.loads(data_content.decode("utf-8")) - + data_content = b''.join(data_chunks) + return json.loads(data_content.decode('utf-8')) def get_provider_data(): """ Retrieves provider data from a cache file, filters out models based on provider criteria, and returns a dictionary of providers mapped to their models. - + Returns: - dict or None: A dictionary of providers mapped to their models or None if the operation fails. """ - cache_dir = Path.home() / ".crewai" + cache_dir = Path.home() / '.crewai' cache_dir.mkdir(exist_ok=True) - cache_file = cache_dir / "provider_cache.json" - cache_expiry = 24 * 3600 + cache_file = cache_dir / 'provider_cache.json' + cache_expiry = 24 * 3600 data = load_provider_data(cache_file, cache_expiry) if not data: @@ -214,7 +193,7 @@ def get_provider_data(): provider_models = defaultdict(list) for model_name, properties in data.items(): provider = properties.get("litellm_provider", "").strip().lower() - if "http" in provider or provider == "other": + if 'http' in provider or provider == 'other': continue if provider: provider_models[provider].append(model_name)