From 7b276e67978b82ae7a160ee0762072da94ead8a2 Mon Sep 17 00:00:00 2001 From: "Brandon Hancock (bhancock_ai)" <109994880+bhancockio@users.noreply.github.com> Date: Thu, 5 Dec 2024 12:17:23 -0500 Subject: [PATCH] Incorporate Stale PRs that have feedback (#1693) * incorporate #1683 * add in --version flag to cli. closes #1679. * Fix env issue * Add in suggestions from @caike to make sure ragstorage doesnt exceed os file limit. Also, included additional checks to support windows. * remove poetry.lock as pointed out by @sanders41 in #1574. * Incorporate feedback from crewai reviewer * Incorporate @lorenzejay feedback --- src/crewai/agent.py | 21 ++++++------------ src/crewai/cli/cli.py | 6 +++++- src/crewai/cli/constants.py | 3 +++ src/crewai/memory/storage/rag_storage.py | 27 ++++++++++++++++++++---- src/crewai/utilities/constants.py | 1 + tests/cli/cli_test.py | 7 ++++++ 6 files changed, 45 insertions(+), 20 deletions(-) diff --git a/src/crewai/agent.py b/src/crewai/agent.py index abe678db1..8c79c6eb8 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -8,7 +8,7 @@ from pydantic import Field, InstanceOf, PrivateAttr, model_validator from crewai.agents import CacheHandler from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.agents.crew_agent_executor import CrewAgentExecutor -from crewai.cli.constants import ENV_VARS +from crewai.cli.constants import ENV_VARS, LITELLM_PARAMS from crewai.knowledge.knowledge import Knowledge from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context @@ -181,20 +181,11 @@ class Agent(BaseAgent): if key_name and key_name not in unaccepted_attributes: env_value = os.environ.get(key_name) if env_value: - # Map key names containing "API_KEY" to "api_key" - key_name = ( - "api_key" if "API_KEY" in key_name else key_name - ) - # Map key names containing "API_BASE" to "api_base" - key_name = ( - "api_base" if "API_BASE" in key_name else key_name - ) - # Map key names containing "API_VERSION" to "api_version" - key_name = ( - "api_version" - if "API_VERSION" in key_name - else key_name - ) + key_name = key_name.lower() + for pattern in LITELLM_PARAMS: + if pattern in key_name: + key_name = pattern + break llm_params[key_name] = env_value # Check for default values if the environment variable is not set elif env_var.get("default", False): diff --git a/src/crewai/cli/cli.py b/src/crewai/cli/cli.py index 6e8560133..43dc90eed 100644 --- a/src/crewai/cli/cli.py +++ b/src/crewai/cli/cli.py @@ -25,6 +25,7 @@ from .update_crew import update_crew @click.group() +@click.version_option(pkg_resources.get_distribution("crewai").version) def crewai(): """Top-level command group for crewai.""" @@ -50,7 +51,10 @@ def create(type, name, provider, skip_provider=False): ) def version(tools): """Show the installed version of crewai.""" - crewai_version = pkg_resources.get_distribution("crewai").version + try: + crewai_version = pkg_resources.get_distribution("crewai").version + except Exception: + crewai_version = "unknown version" click.echo(f"crewai version: {crewai_version}") if tools: diff --git a/src/crewai/cli/constants.py b/src/crewai/cli/constants.py index e13349155..13279f8d3 100644 --- a/src/crewai/cli/constants.py +++ b/src/crewai/cli/constants.py @@ -159,3 +159,6 @@ MODELS = { } JSON_URL = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json" + + +LITELLM_PARAMS = ["api_key", "api_base", "api_version"] diff --git a/src/crewai/memory/storage/rag_storage.py b/src/crewai/memory/storage/rag_storage.py index ded340a19..bf40aee96 100644 --- a/src/crewai/memory/storage/rag_storage.py +++ b/src/crewai/memory/storage/rag_storage.py @@ -4,12 +4,14 @@ import logging import os import shutil import uuid - from typing import Any, Dict, List, Optional + from chromadb.api import ClientAPI + from crewai.memory.storage.base_rag_storage import BaseRAGStorage -from crewai.utilities.paths import db_storage_path from crewai.utilities import EmbeddingConfigurator +from crewai.utilities.constants import MAX_FILE_NAME_LENGTH +from crewai.utilities.paths import db_storage_path @contextlib.contextmanager @@ -37,12 +39,15 @@ class RAGStorage(BaseRAGStorage): app: ClientAPI | None = None - def __init__(self, type, allow_reset=True, embedder_config=None, crew=None, path=None): + def __init__( + self, type, allow_reset=True, embedder_config=None, crew=None, path=None + ): super().__init__(type, allow_reset, embedder_config, crew) agents = crew.agents if crew else [] agents = [self._sanitize_role(agent.role) for agent in agents] agents = "_".join(agents) self.agents = agents + self.storage_file_name = self._build_storage_file_name(type, agents) self.type = type @@ -60,7 +65,7 @@ class RAGStorage(BaseRAGStorage): self._set_embedder_config() chroma_client = chromadb.PersistentClient( - path=self.path if self.path else f"{db_storage_path()}/{self.type}/{self.agents}", + path=self.path if self.path else self.storage_file_name, settings=Settings(allow_reset=self.allow_reset), ) @@ -81,6 +86,20 @@ class RAGStorage(BaseRAGStorage): """ return role.replace("\n", "").replace(" ", "_").replace("/", "_") + def _build_storage_file_name(self, type: str, file_name: str) -> str: + """ + Ensures file name does not exceed max allowed by OS + """ + base_path = f"{db_storage_path()}/{type}" + + if len(file_name) > MAX_FILE_NAME_LENGTH: + logging.warning( + f"Trimming file name from {len(file_name)} to {MAX_FILE_NAME_LENGTH} characters." + ) + file_name = file_name[:MAX_FILE_NAME_LENGTH] + + return f"{base_path}/{file_name}" + def save(self, value: Any, metadata: Dict[str, Any]) -> None: if not hasattr(self, "app") or not hasattr(self, "collection"): self._initialize_app() diff --git a/src/crewai/utilities/constants.py b/src/crewai/utilities/constants.py index 3578b7ece..096bb7c8c 100644 --- a/src/crewai/utilities/constants.py +++ b/src/crewai/utilities/constants.py @@ -3,3 +3,4 @@ TRAINED_AGENTS_DATA_FILE = "trained_agents_data.pkl" DEFAULT_SCORE_THRESHOLD = 0.35 KNOWLEDGE_DIRECTORY = "knowledge" MAX_LLM_RETRY = 3 +MAX_FILE_NAME_LENGTH = 255 diff --git a/tests/cli/cli_test.py b/tests/cli/cli_test.py index 05e1cf03a..15ed81637 100644 --- a/tests/cli/cli_test.py +++ b/tests/cli/cli_test.py @@ -131,6 +131,13 @@ def test_reset_no_memory_flags(runner): ) +def test_version_flag(runner): + result = runner.invoke(version) + + assert result.exit_code == 0 + assert "crewai version:" in result.output + + def test_version_command(runner): result = runner.invoke(version)