Lorenze Jay
2024-10-21 14:24:07 -07:00
parent 6b12ac9c0b
commit 2786086974
2 changed files with 37 additions and 35 deletions


@@ -16,7 +16,7 @@ class EntityMemory(Memory):
             if storage
             else RAGStorage(
                 type="entities",
-                allow_reset=False,
+                allow_reset=True,
                 embedder_config=embedder_config,
                 crew=crew,
             )
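The hunk above flips the default RAGStorage used by entity memory from allow_reset=False to allow_reset=True, so the entity store can now be cleared like the other memory stores. As a rough illustration of what such a flag usually guards (a hypothetical sketch, not crewAI's RAGStorage; SketchStorage and its methods are made up):

# Hypothetical sketch of an allow_reset guard; not the crewAI source.
class SketchStorage:
    def __init__(self, allow_reset: bool = True) -> None:
        self.allow_reset = allow_reset
        self._records: list[str] = []

    def save(self, value: str) -> None:
        self._records.append(value)

    def reset(self) -> None:
        # With allow_reset=False the storage refuses to be cleared.
        if not self.allow_reset:
            raise RuntimeError("storage was created with allow_reset=False")
        self._records.clear()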


@@ -9,6 +9,8 @@ from crewai.memory.storage.base_rag_storage import BaseRAGStorage
 from crewai.utilities.paths import db_storage_path
 from chromadb.api import ClientAPI
 from chromadb.api.types import validate_embedding_function
+from chromadb import Documents, EmbeddingFunction, Embeddings
+from typing import cast


 @contextlib.contextmanager
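The two new imports (Documents, EmbeddingFunction, Embeddings, cast) exist so the next hunk can define a custom chromadb embedding function inline. For context, a minimal custom EmbeddingFunction looks roughly like this (a sketch assuming chromadb >= 0.4; the toy length-based embedding and the collection name are placeholders):

import chromadb
from chromadb import Documents, EmbeddingFunction, Embeddings


class ToyEmbeddingFunction(EmbeddingFunction):
    # Placeholder embedder: returns a fixed-size vector per document.
    def __call__(self, input: Documents) -> Embeddings:
        return [[float(len(text)), 0.0, 0.0] for text in input]


client = chromadb.EphemeralClient()
collection = client.get_or_create_collection(
    name="demo", embedding_function=ToyEmbeddingFunction()
)
collection.add(ids=["1"], documents=["hello world"])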
@@ -47,64 +49,64 @@ class RAGStorage(BaseRAGStorage):
         self._initialize_app()

     def _set_embedder_config(self):
+        import chromadb.utils.embedding_functions as embedding_functions
+
         if self.embedder_config is None:
             self.embedder_config = self._create_default_embedding_function()

         if isinstance(self.embedder_config, dict):
             provider = self.embedder_config.get("provider")
             config = self.embedder_config.get("config", {})
             model_name = config.get("model")
             if provider == "openai":
-                import chromadb.utils.embedding_functions as embedding_functions
                 self.embedder_config = embedding_functions.OpenAIEmbeddingFunction(
                     api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"),
                     model_name=model_name,
                 )
             elif provider == "azure":
-                from chromadb.utils.embedding_functions.openai_embedding_function import (
-                    OpenAIEmbeddingFunction,
-                )
-                self.embedder_config = OpenAIEmbeddingFunction(
+                self.embedder_config = embedding_functions.OpenAIEmbeddingFunction(
                     api_key=config.get("api_key"),
                     api_base=config.get("api_base"),
-                    api_type=config.get("api_type"),
+                    api_type=config.get("api_type", "azure"),
                     api_version=config.get("api_version"),
                     model_name=model_name,
                 )
             elif provider == "ollama":
-                from chromadb.utils.embedding_functions.ollama_embedding_function import (
-                    OllamaEmbeddingFunction,
-                )
+                print("using this ollama")
+                from openai import OpenAI

-                self.embedder_config = OllamaEmbeddingFunction(
-                    model_name=config.get("model"),
-                    url=config.get("url") or "http://localhost:11434",
-                )
+                class OllamaEmbeddingFunction(EmbeddingFunction):
+                    def __call__(self, input: Documents) -> Embeddings:
+                        client = OpenAI(
+                            base_url="http://localhost:11434/v1",
+                            api_key=config.get("api_key", "ollama"),
+                        )
+                        try:
+                            response = client.embeddings.create(
+                                input=input, model=model_name
+                            )
+                            embeddings = [item.embedding for item in response.data]
+                            return cast(Embeddings, embeddings)
+                        except Exception as e:
+                            raise e
+
+                self.embedder_config = OllamaEmbeddingFunction()
             elif provider == "vertexai":
-                from chromadb.utils.embedding_functions.google_embedding_function import (
-                    GoogleVertexEmbeddingFunction,
-                )
-                self.embedder_config = GoogleVertexEmbeddingFunction(
-                    model_name=model_name,
-                    api_key=config.get("api_key"),
+                self.embedder_config = (
+                    embedding_functions.GoogleVertexEmbeddingFunction(
+                        model_name=model_name,
+                        api_key=config.get("api_key"),
+                    )
                 )
             elif provider == "google":
-                from chromadb.utils.embedding_functions.google_embedding_function import (
-                    GoogleGenerativeAiEmbeddingFunction,
-                )
-                self.embedder_config = GoogleGenerativeAiEmbeddingFunction(
-                    model_name=model_name,
-                    api_key=config.get("api_key"),
+                self.embedder_config = (
+                    embedding_functions.GoogleGenerativeAiEmbeddingFunction(
+                        model_name=model_name,
+                        api_key=config.get("api_key"),
+                    )
                 )
             elif provider == "cohere":
-                from chromadb.utils.embedding_functions.cohere_embedding_function import (
-                    CohereEmbeddingFunction,
-                )
-                self.embedder_config = CohereEmbeddingFunction(
+                self.embedder_config = embedding_functions.CohereEmbeddingFunction(
                     model_name=model_name,
                     api_key=config.get("api_key"),
                 )
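Taken together, the rewritten _set_embedder_config only dispatches on embedder_config["provider"] and reads a handful of keys from its "config" dict. Roughly, the shapes it accepts look like this (a sketch; the values are placeholders inferred from the keys the code reads, not defaults shipped by crewAI):

# Placeholder values; only the keys are taken from the diff above.
openai_config = {
    "provider": "openai",
    "config": {"model": "text-embedding-3-small", "api_key": "sk-..."},
}
azure_config = {
    "provider": "azure",
    "config": {
        "model": "text-embedding-3-small",
        "api_key": "<azure-key>",
        "api_base": "https://<resource>.openai.azure.com",
        "api_type": "azure",  # now defaults to "azure" when omitted
        "api_version": "<api-version>",
    },
}
ollama_config = {
    "provider": "ollama",
    "config": {"model": "nomic-embed-text"},  # api_key falls back to "ollama"
}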