ensure original embedding config works

This commit is contained in:
Lorenze Jay
2024-10-20 18:12:57 -07:00
parent 40f81aecf5
commit 3fc83c624b
2 changed files with 85 additions and 8 deletions

View File

@@ -32,10 +32,10 @@ def reset_memories_command(long, short, entity, kickoff_outputs, all) -> None:
click.echo("Long term memory has been reset.") click.echo("Long term memory has been reset.")
if short: if short:
ShortTermMemory().reset() ShortTermMemory(allow_reset=True).reset()
click.echo("Short term memory has been reset.") click.echo("Short term memory has been reset.")
if entity: if entity:
EntityMemory().reset() EntityMemory(allow_reset=True).reset()
click.echo("Entity memory has been reset.") click.echo("Entity memory has been reset.")
if kickoff_outputs: if kickoff_outputs:
TaskOutputStorageHandler().reset() TaskOutputStorageHandler().reset()

View File

@@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional
from crewai.memory.storage.base_rag_storage import BaseRAGStorage from crewai.memory.storage.base_rag_storage import BaseRAGStorage
from crewai.utilities.paths import db_storage_path from crewai.utilities.paths import db_storage_path
from chromadb.api import ClientAPI from chromadb.api import ClientAPI
from chromadb.api.types import validate_embedding_function
@contextlib.contextmanager @contextlib.contextmanager
@@ -41,16 +42,87 @@ class RAGStorage(BaseRAGStorage):
self.agents = agents self.agents = agents
self.type = type self.type = type
self.embedder_config = embedder_config or self._create_embedding_function()
self.allow_reset = allow_reset self.allow_reset = allow_reset
self._initialize_app() self._initialize_app()
def set_embedder_config(self):
    """Resolve ``self.embedder_config`` into a usable embedding function.

    ``self.embedder_config`` may hold one of three things on entry:

    * ``None`` -- replaced by the result of
      ``self._create_default_embedding_function()`` and then validated.
    * a ``dict`` with a ``"provider"`` key and an optional ``"config"``
      mapping (``"model"``, ``"api_key"``, ...) -- replaced by the matching
      chromadb embedding function. A provider that is not recognised falls
      back to the default embedding function.
    * any other object -- treated as a user-supplied embedding function and
      passed through ``validate_embedding_function``; it is left in place.

    The provider-specific chromadb modules are imported inside each branch
    so only the selected provider's module is loaded.
    """
    if self.embedder_config is None:
        self.embedder_config = self._create_default_embedding_function()

    if not isinstance(self.embedder_config, dict):
        # Used for validating embedder_config when it is already an
        # embedding function/class rather than a provider dict.
        validate_embedding_function(self.embedder_config)  # type: ignore
        return

    provider = self.embedder_config.get("provider")
    # ``or {}`` also covers an explicit ``"config": None`` entry, which
    # ``.get("config", {})`` would hand back as None and crash on below.
    config = self.embedder_config.get("config") or {}
    model_name = config.get("model")

    if provider == "openai":
        import chromadb.utils.embedding_functions as embedding_functions

        self.embedder_config = embedding_functions.OpenAIEmbeddingFunction(
            api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"),
            model_name=model_name,
        )
    elif provider == "azure":
        from chromadb.utils.embedding_functions.openai_embedding_function import (
            OpenAIEmbeddingFunction,
        )

        self.embedder_config = OpenAIEmbeddingFunction(
            api_key=config.get("api_key"),
            api_base=config.get("api_base"),
            api_type=config.get("api_type"),
            api_version=config.get("api_version"),
            model_name=model_name,
        )
    elif provider == "ollama":
        from chromadb.utils.embedding_functions.ollama_embedding_function import (
            OllamaEmbeddingFunction,
        )

        self.embedder_config = OllamaEmbeddingFunction(
            model_name=model_name,
            url=config.get("url") or "http://localhost:11434",
        )
    elif provider == "vertexai":
        from chromadb.utils.embedding_functions.google_embedding_function import (
            GoogleVertexEmbeddingFunction,
        )

        self.embedder_config = GoogleVertexEmbeddingFunction(
            model_name=model_name,
            api_key=config.get("api_key"),
        )
    elif provider == "google":
        from chromadb.utils.embedding_functions.google_embedding_function import (
            GoogleGenerativeAiEmbeddingFunction,
        )

        self.embedder_config = GoogleGenerativeAiEmbeddingFunction(
            model_name=model_name,
            api_key=config.get("api_key"),
        )
    elif provider == "cohere":
        from chromadb.utils.embedding_functions.cohere_embedding_function import (
            CohereEmbeddingFunction,
        )

        self.embedder_config = CohereEmbeddingFunction(
            model_name=model_name,
            api_key=config.get("api_key"),
        )
    else:
        # Unknown or missing provider: keep the historical fallback to the
        # default embedding function instead of raising.
        self.embedder_config = self._create_default_embedding_function()
def _initialize_app(self): def _initialize_app(self):
import chromadb import chromadb
self.set_embedder_config()
chroma_client = chromadb.PersistentClient( chroma_client = chromadb.PersistentClient(
path=f"{db_storage_path()}/{self.type}/{self.agents}" path=f"{db_storage_path()}/{self.type}/{self.agents}",
settings=chromadb.Settings(allow_reset=self.allow_reset),
) )
self.app = chroma_client self.app = chroma_client
try: try:
@@ -122,11 +194,16 @@ class RAGStorage(BaseRAGStorage):
if self.app: if self.app:
self.app.reset() self.app.reset()
except Exception as e: except Exception as e:
raise Exception( if "attempt to write a readonly database" in str(e):
f"An error occurred while resetting the {self.type} memory: {e}" print("ignoring error")
) # Ignore this specific error
pass
else:
raise Exception(
f"An error occurred while resetting the {self.type} memory: {e}"
)
def _create_embedding_function(self): def _create_default_embedding_function(self):
import chromadb.utils.embedding_functions as embedding_functions import chromadb.utils.embedding_functions as embedding_functions
return embedding_functions.OpenAIEmbeddingFunction( return embedding_functions.OpenAIEmbeddingFunction(