Add in suggestions from @caike to make sure RAGStorage doesn't exceed the OS file name limit. Also, included additional checks to support Windows.

This commit is contained in:
Brandon Hancock
2024-12-03 10:46:54 -05:00
parent 87a7f38deb
commit 3a2e9bf365
2 changed files with 29 additions and 4 deletions

View File

@@ -181,7 +181,7 @@ class Agent(BaseAgent):
if key_name and key_name not in unaccepted_attributes: if key_name and key_name not in unaccepted_attributes:
env_value = os.environ.get(key_name) env_value = os.environ.get(key_name)
if env_value: if env_value:
param_name = env_value.lower() param_name = key_name.lower()
# Map key names containing "API_KEY" to "api_key" # Map key names containing "API_KEY" to "api_key"
if "api_key" in param_name: if "api_key" in param_name:
param_name = "api_key" param_name = "api_key"

View File

@@ -4,12 +4,13 @@ import logging
import os import os
import shutil import shutil
import uuid import uuid
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from chromadb.api import ClientAPI from chromadb.api import ClientAPI
from crewai.memory.storage.base_rag_storage import BaseRAGStorage from crewai.memory.storage.base_rag_storage import BaseRAGStorage
from crewai.utilities.paths import db_storage_path
from crewai.utilities import EmbeddingConfigurator from crewai.utilities import EmbeddingConfigurator
from crewai.utilities.paths import db_storage_path
@contextlib.contextmanager @contextlib.contextmanager
@@ -43,6 +44,7 @@ class RAGStorage(BaseRAGStorage):
agents = [self._sanitize_role(agent.role) for agent in agents] agents = [self._sanitize_role(agent.role) for agent in agents]
agents = "_".join(agents) agents = "_".join(agents)
self.agents = agents self.agents = agents
self.storage_file_name = self._build_storage_file_name(type, agents)
self.type = type self.type = type
@@ -59,7 +61,7 @@ class RAGStorage(BaseRAGStorage):
self._set_embedder_config() self._set_embedder_config()
chroma_client = chromadb.PersistentClient( chroma_client = chromadb.PersistentClient(
path=f"{db_storage_path()}/{self.type}/{self.agents}", path=self.storage_file_name,
settings=Settings(allow_reset=self.allow_reset), settings=Settings(allow_reset=self.allow_reset),
) )
@@ -80,6 +82,29 @@ class RAGStorage(BaseRAGStorage):
""" """
return role.replace("\n", "").replace(" ", "_").replace("/", "_") return role.replace("\n", "").replace(" ", "_").replace("/", "_")
def _build_storage_file_name(self, type: str, file_name: str) -> str:
    """
    Build the storage path, ensuring the file name does not exceed the
    maximum length allowed by the OS.

    Args:
        type: Storage type; used as the sub-directory under
            ``db_storage_path()``.
        file_name: Desired name of the final path component.

    Returns:
        ``"<db_storage_path()>/<type>/<file_name>"`` with ``file_name``
        trimmed to the file-name limit when it is too long.
    """
    base_path = f"{db_storage_path()}/{type}"
    fallback_max_length = 255

    # pathconf() requires an existing path, but base_path is usually not
    # created yet on first use. Probe the nearest existing ancestor, which
    # is where the directory will eventually be created.
    probe_path = base_path
    while probe_path and not os.path.exists(probe_path):
        parent = os.path.dirname(probe_path)
        if parent == probe_path:
            break
        probe_path = parent

    try:
        # Platform-dependent max length (in bytes) for a file name.
        max_length = os.pathconf(probe_path or os.curdir, "PC_NAME_MAX")
        limit_in_bytes = True
    except (OSError, ValueError, AttributeError) as e:
        # Expected on platforms without os.pathconf (e.g. Windows), so this
        # is not reported as an error.
        logging.debug("Error accessing path configuration: %s", e)
        max_length = fallback_max_length
        limit_in_bytes = False

    # A non-positive value means no determinate limit was reported; slicing
    # with it would silently mis-trim the name.
    if max_length <= 0:
        max_length = fallback_max_length
        limit_in_bytes = False

    # The pathconf limit counts bytes, so measure the encoded name; the
    # fallback keeps the original character-based measurement.
    encoded = (
        file_name.encode("utf-8", errors="surrogatepass")
        if limit_in_bytes
        else None
    )
    current_length = len(file_name) if encoded is None else len(encoded)

    # Trim if necessary
    if current_length > max_length:
        logging.warning(
            "Trimming file name from %d to %d %s.",
            current_length,
            max_length,
            "characters" if encoded is None else "bytes",
        )
        if encoded is None:
            file_name = file_name[:max_length]
        else:
            # errors="ignore" drops a multi-byte character cut in half.
            file_name = encoded[:max_length].decode("utf-8", errors="ignore")

    return f"{base_path}/{file_name}"
def save(self, value: Any, metadata: Dict[str, Any]) -> None: def save(self, value: Any, metadata: Dict[str, Any]) -> None:
if not hasattr(self, "app") or not hasattr(self, "collection"): if not hasattr(self, "app") or not hasattr(self, "collection"):
self._initialize_app() self._initialize_app()