mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00
Fix: Update import sorting and implement abstract methods
Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -4,7 +4,7 @@ import io
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from typing import Any, Dict, List, Optional, Union, cast, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
|
||||
|
||||
# Type checking imports that don't cause runtime imports
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@@ -4,15 +4,18 @@ import logging
|
||||
import os
|
||||
import shutil
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
# Type checking imports that don't cause runtime imports
|
||||
if TYPE_CHECKING:
|
||||
import chromadb
|
||||
from chromadb.api import ClientAPI
|
||||
from chromadb.config import Settings
|
||||
|
||||
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
|
||||
from crewai.utilities import EmbeddingConfigurator
|
||||
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
|
||||
from crewai.utilities.chromadb import sanitize_collection_name
|
||||
from crewai.utilities.logger import Logger
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
|
||||
|
||||
@@ -39,77 +42,45 @@ class RAGStorage(BaseRAGStorage):
|
||||
search efficiency.
|
||||
"""
|
||||
|
||||
collection: Optional[Any] = None
|
||||
collection_name: Optional[str] = "memory"
|
||||
app: Optional[Any] = None
|
||||
|
||||
def __init__(
|
||||
self, type, allow_reset=True, embedder_config=None, crew=None, path=None
|
||||
self,
|
||||
type: str = "memory",
|
||||
allow_reset: bool = True,
|
||||
embedder_config: Optional[Dict[str, Any]] = None,
|
||||
crew: Any = None,
|
||||
collection_name: Optional[str] = None,
|
||||
):
|
||||
super().__init__(type, allow_reset, embedder_config, crew)
|
||||
agents = crew.agents if crew else []
|
||||
agents = [self._sanitize_role(agent.role) for agent in agents]
|
||||
agents = "_".join(agents)
|
||||
self.agents = agents
|
||||
self.storage_file_name = self._build_storage_file_name(type, agents)
|
||||
self.collection_name = collection_name or type
|
||||
self._set_embedder_config(embedder_config)
|
||||
|
||||
self.type = type
|
||||
def save(
|
||||
self,
|
||||
value: Any,
|
||||
metadata: Dict[str, Any],
|
||||
) -> None:
|
||||
with suppress_logging():
|
||||
if not self.collection:
|
||||
self._initialize_app()
|
||||
|
||||
self.allow_reset = allow_reset
|
||||
self.path = path
|
||||
self._initialize_app()
|
||||
if isinstance(value, list):
|
||||
documents = value
|
||||
metadatas = [metadata] * len(value) if metadata else None
|
||||
ids = [str(uuid.uuid4()) for _ in range(len(documents))]
|
||||
else:
|
||||
documents = [value]
|
||||
metadatas = [metadata] if metadata else None
|
||||
ids = [str(uuid.uuid4())]
|
||||
|
||||
def _set_embedder_config(self):
|
||||
configurator = EmbeddingConfigurator()
|
||||
self.embedder_config = configurator.configure_embedder(self.embedder_config)
|
||||
|
||||
def _initialize_app(self):
|
||||
# Import chromadb here to avoid importing at module level
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
|
||||
self._set_embedder_config()
|
||||
chroma_client = chromadb.PersistentClient(
|
||||
path=self.path if self.path else self.storage_file_name,
|
||||
settings=Settings(allow_reset=self.allow_reset),
|
||||
)
|
||||
|
||||
self.app = chroma_client
|
||||
|
||||
try:
|
||||
self.collection = self.app.get_collection(
|
||||
name=self.type, embedding_function=self.embedder_config
|
||||
self.collection.add(
|
||||
documents=documents,
|
||||
metadatas=metadatas,
|
||||
ids=ids,
|
||||
)
|
||||
except Exception:
|
||||
self.collection = self.app.create_collection(
|
||||
name=self.type, embedding_function=self.embedder_config
|
||||
)
|
||||
|
||||
def _sanitize_role(self, role: str) -> str:
|
||||
"""
|
||||
Sanitizes agent roles to ensure valid directory names.
|
||||
"""
|
||||
return role.replace("\n", "").replace(" ", "_").replace("/", "_")
|
||||
|
||||
def _build_storage_file_name(self, type: str, file_name: str) -> str:
|
||||
"""
|
||||
Ensures file name does not exceed max allowed by OS
|
||||
"""
|
||||
base_path = f"{db_storage_path()}/{type}"
|
||||
|
||||
if len(file_name) > MAX_FILE_NAME_LENGTH:
|
||||
logging.warning(
|
||||
f"Trimming file name from {len(file_name)} to {MAX_FILE_NAME_LENGTH} characters."
|
||||
)
|
||||
file_name = file_name[:MAX_FILE_NAME_LENGTH]
|
||||
|
||||
return f"{base_path}/{file_name}"
|
||||
|
||||
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
|
||||
if not hasattr(self, "app") or not hasattr(self, "collection"):
|
||||
self._initialize_app()
|
||||
try:
|
||||
self._generate_embedding(value, metadata)
|
||||
except Exception as e:
|
||||
logging.error(f"Error during {self.type} save: {str(e)}")
|
||||
|
||||
def search(
|
||||
self,
|
||||
@@ -118,54 +89,96 @@ class RAGStorage(BaseRAGStorage):
|
||||
filter: Optional[dict] = None,
|
||||
score_threshold: float = 0.35,
|
||||
) -> List[Any]:
|
||||
if not hasattr(self, "app"):
|
||||
self._initialize_app()
|
||||
with suppress_logging():
|
||||
if not hasattr(self, "collection") or not self.collection:
|
||||
self._initialize_app()
|
||||
|
||||
try:
|
||||
with suppress_logging():
|
||||
response = self.collection.query(query_texts=query, n_results=limit)
|
||||
if isinstance(query, str):
|
||||
query = [query]
|
||||
|
||||
fetched = self.collection.query(
|
||||
query_texts=query,
|
||||
n_results=limit,
|
||||
where=filter,
|
||||
)
|
||||
results = []
|
||||
for i in range(len(response["ids"][0])):
|
||||
for i in range(len(fetched["ids"][0])): # type: ignore
|
||||
result = {
|
||||
"id": response["ids"][0][i],
|
||||
"metadata": response["metadatas"][0][i],
|
||||
"context": response["documents"][0][i],
|
||||
"score": response["distances"][0][i],
|
||||
"id": fetched["ids"][0][i], # type: ignore
|
||||
"metadata": fetched["metadatas"][0][i], # type: ignore
|
||||
"context": fetched["documents"][0][i], # type: ignore
|
||||
"score": fetched["distances"][0][i], # type: ignore
|
||||
}
|
||||
if result["score"] >= score_threshold:
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
except Exception as e:
|
||||
logging.error(f"Error during {self.type} search: {str(e)}")
|
||||
return []
|
||||
|
||||
def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> None: # type: ignore
|
||||
if not hasattr(self, "app") or not hasattr(self, "collection"):
|
||||
def _initialize_app(self):
|
||||
# Import chromadb here to avoid importing at module level
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
|
||||
base_path = os.path.join(db_storage_path(), "memory")
|
||||
chroma_client = chromadb.PersistentClient(
|
||||
path=base_path,
|
||||
settings=Settings(allow_reset=self.allow_reset),
|
||||
)
|
||||
|
||||
self.app = chroma_client
|
||||
|
||||
try:
|
||||
collection_name = (
|
||||
f"memory_{self.collection_name}"
|
||||
if self.collection_name
|
||||
else "memory"
|
||||
)
|
||||
if self.app:
|
||||
self.collection = self.app.get_or_create_collection(
|
||||
name=sanitize_collection_name(collection_name),
|
||||
embedding_function=self.embedder,
|
||||
)
|
||||
else:
|
||||
raise Exception("Vector Database Client not initialized")
|
||||
except Exception:
|
||||
raise Exception("Failed to create or get collection")
|
||||
|
||||
def initialize_rag_storage(self):
|
||||
self._initialize_app()
|
||||
|
||||
def reset(self) -> None:
|
||||
# Import chromadb here to avoid importing at module level
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
|
||||
base_path = os.path.join(db_storage_path(), "memory")
|
||||
if not self.app:
|
||||
self.app = chromadb.PersistentClient(
|
||||
path=base_path,
|
||||
settings=Settings(allow_reset=True),
|
||||
)
|
||||
|
||||
self.app.reset()
|
||||
shutil.rmtree(base_path)
|
||||
self.app = None
|
||||
self.collection = None
|
||||
|
||||
def _generate_embedding(
|
||||
self, text: str, metadata: Optional[Dict[str, Any]] = None
|
||||
) -> Any:
|
||||
if not hasattr(self, "collection") or not self.collection:
|
||||
self._initialize_app()
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
self.collection.add(
|
||||
documents=[text],
|
||||
metadatas=[metadata or {}],
|
||||
ids=[str(uuid.uuid4())],
|
||||
ids=[id],
|
||||
)
|
||||
return id
|
||||
|
||||
def reset(self) -> None:
|
||||
try:
|
||||
if self.app:
|
||||
self.app.reset()
|
||||
shutil.rmtree(f"{db_storage_path()}/{self.type}")
|
||||
self.app = None
|
||||
self.collection = None
|
||||
except Exception as e:
|
||||
if "attempt to write a readonly database" in str(e):
|
||||
# Ignore this specific error
|
||||
pass
|
||||
else:
|
||||
raise Exception(
|
||||
f"An error occurred while resetting the {self.type} memory: {e}"
|
||||
)
|
||||
def _sanitize_role(self, role: str) -> str:
|
||||
"""Sanitize role name for use in file names."""
|
||||
return role.lower().replace(" ", "_").replace("\n", "").replace("/", "_")
|
||||
|
||||
def _create_default_embedding_function(self):
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
@@ -175,3 +188,20 @@ class RAGStorage(BaseRAGStorage):
|
||||
return OpenAIEmbeddingFunction(
|
||||
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
|
||||
)
|
||||
|
||||
def _set_embedder_config(self, embedder_config: Optional[Dict[str, Any]] = None) -> None:
|
||||
"""Set the embedding configuration for the RAG storage.
|
||||
|
||||
Args:
|
||||
embedder_config (Optional[Dict[str, Any]]): Configuration dictionary for the embedder.
|
||||
If None or empty, defaults to the default embedding function.
|
||||
"""
|
||||
self.embedder = (
|
||||
EmbeddingConfigurator().configure_embedder(embedder_config)
|
||||
if embedder_config
|
||||
else self._create_default_embedding_function()
|
||||
)
|
||||
|
||||
def _build_storage_file_name(self, role_name: str) -> str:
|
||||
"""Build storage file name from role name."""
|
||||
return f"{self._sanitize_role(role_name)}_memory"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import Any, Dict, Optional, cast, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, cast
|
||||
|
||||
# Type checking imports that don't cause runtime imports
|
||||
if TYPE_CHECKING:
|
||||
@@ -189,7 +189,7 @@ class EmbeddingConfigurator:
|
||||
) from e
|
||||
|
||||
# Import chromadb types here to avoid importing at module level
|
||||
from chromadb import Documents, Embeddings, EmbeddingFunction
|
||||
from chromadb import Documents, EmbeddingFunction, Embeddings
|
||||
|
||||
class WatsonEmbeddingFunction(EmbeddingFunction):
|
||||
def __call__(self, input: Documents) -> Embeddings:
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import importlib
|
||||
import sys
|
||||
from unittest.mock import patch, MagicMock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestEmbeddingConfiguratorImports:
|
||||
"""Test that ChromaDB is not imported at module level."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user