mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-28 01:28:14 +00:00
ensure original embedding config works
This commit is contained in:
@@ -32,10 +32,10 @@ def reset_memories_command(long, short, entity, kickoff_outputs, all) -> None:
|
|||||||
click.echo("Long term memory has been reset.")
|
click.echo("Long term memory has been reset.")
|
||||||
|
|
||||||
if short:
|
if short:
|
||||||
ShortTermMemory().reset()
|
ShortTermMemory(allow_reset=True).reset()
|
||||||
click.echo("Short term memory has been reset.")
|
click.echo("Short term memory has been reset.")
|
||||||
if entity:
|
if entity:
|
||||||
EntityMemory().reset()
|
EntityMemory(allow_reset=True).reset()
|
||||||
click.echo("Entity memory has been reset.")
|
click.echo("Entity memory has been reset.")
|
||||||
if kickoff_outputs:
|
if kickoff_outputs:
|
||||||
TaskOutputStorageHandler().reset()
|
TaskOutputStorageHandler().reset()
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional
|
|||||||
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
|
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
|
||||||
from crewai.utilities.paths import db_storage_path
|
from crewai.utilities.paths import db_storage_path
|
||||||
from chromadb.api import ClientAPI
|
from chromadb.api import ClientAPI
|
||||||
|
from chromadb.api.types import validate_embedding_function
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
@@ -41,16 +42,87 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
self.agents = agents
|
self.agents = agents
|
||||||
|
|
||||||
self.type = type
|
self.type = type
|
||||||
self.embedder_config = embedder_config or self._create_embedding_function()
|
|
||||||
self.allow_reset = allow_reset
|
self.allow_reset = allow_reset
|
||||||
self._initialize_app()
|
self._initialize_app()
|
||||||
|
|
||||||
|
def set_embedder_config(self):
|
||||||
|
if self.embedder_config is None:
|
||||||
|
self.embedder_config = self._create_default_embedding_function()
|
||||||
|
if isinstance(self.embedder_config, dict):
|
||||||
|
provider = self.embedder_config.get("provider")
|
||||||
|
config = self.embedder_config.get("config", {})
|
||||||
|
model_name = config.get("model")
|
||||||
|
if provider == "openai":
|
||||||
|
import chromadb.utils.embedding_functions as embedding_functions
|
||||||
|
|
||||||
|
self.embedder_config = embedding_functions.OpenAIEmbeddingFunction(
|
||||||
|
api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"),
|
||||||
|
model_name=model_name,
|
||||||
|
)
|
||||||
|
elif provider == "azure":
|
||||||
|
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||||
|
OpenAIEmbeddingFunction,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.embedder_config = OpenAIEmbeddingFunction(
|
||||||
|
api_key=config.get("api_key"),
|
||||||
|
api_base=config.get("api_base"),
|
||||||
|
api_type=config.get("api_type"),
|
||||||
|
api_version=config.get("api_version"),
|
||||||
|
model_name=model_name,
|
||||||
|
)
|
||||||
|
elif provider == "ollama":
|
||||||
|
from chromadb.utils.embedding_functions.ollama_embedding_function import (
|
||||||
|
OllamaEmbeddingFunction,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.embedder_config = OllamaEmbeddingFunction(
|
||||||
|
model_name=config.get("model"),
|
||||||
|
url=config.get("url") or "http://localhost:11434",
|
||||||
|
)
|
||||||
|
elif provider == "vertexai":
|
||||||
|
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||||
|
GoogleVertexEmbeddingFunction,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.embedder_config = GoogleVertexEmbeddingFunction(
|
||||||
|
model_name=model_name,
|
||||||
|
api_key=config.get("api_key"),
|
||||||
|
)
|
||||||
|
elif provider == "google":
|
||||||
|
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||||
|
GoogleGenerativeAiEmbeddingFunction,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.embedder_config = GoogleGenerativeAiEmbeddingFunction(
|
||||||
|
model_name=model_name,
|
||||||
|
api_key=config.get("api_key"),
|
||||||
|
)
|
||||||
|
elif provider == "cohere":
|
||||||
|
from chromadb.utils.embedding_functions.cohere_embedding_function import (
|
||||||
|
CohereEmbeddingFunction,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.embedder_config = CohereEmbeddingFunction(
|
||||||
|
model_name=model_name,
|
||||||
|
api_key=config.get("api_key"),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.embedder_config = self._create_default_embedding_function()
|
||||||
|
else:
|
||||||
|
validate_embedding_function(self.embedder_config) # type: ignore # used for validating embedder_config if defined a embedding function/class
|
||||||
|
self.embedder_config = self.embedder_config
|
||||||
|
|
||||||
def _initialize_app(self):
|
def _initialize_app(self):
|
||||||
import chromadb
|
import chromadb
|
||||||
|
|
||||||
|
self.set_embedder_config()
|
||||||
chroma_client = chromadb.PersistentClient(
|
chroma_client = chromadb.PersistentClient(
|
||||||
path=f"{db_storage_path()}/{self.type}/{self.agents}"
|
path=f"{db_storage_path()}/{self.type}/{self.agents}",
|
||||||
|
settings=chromadb.Settings(allow_reset=self.allow_reset),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.app = chroma_client
|
self.app = chroma_client
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -122,11 +194,16 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
if self.app:
|
if self.app:
|
||||||
self.app.reset()
|
self.app.reset()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(
|
if "attempt to write a readonly database" in str(e):
|
||||||
f"An error occurred while resetting the {self.type} memory: {e}"
|
print("ignoring error")
|
||||||
)
|
# Ignore this specific error
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise Exception(
|
||||||
|
f"An error occurred while resetting the {self.type} memory: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
def _create_embedding_function(self):
|
def _create_default_embedding_function(self):
|
||||||
import chromadb.utils.embedding_functions as embedding_functions
|
import chromadb.utils.embedding_functions as embedding_functions
|
||||||
|
|
||||||
return embedding_functions.OpenAIEmbeddingFunction(
|
return embedding_functions.OpenAIEmbeddingFunction(
|
||||||
|
|||||||
Reference in New Issue
Block a user