This commit is contained in:
Lorenze Jay
2024-10-21 14:24:07 -07:00
parent 6b12ac9c0b
commit 2786086974
2 changed files with 37 additions and 35 deletions

View File

@@ -16,7 +16,7 @@ class EntityMemory(Memory):
if storage if storage
else RAGStorage( else RAGStorage(
type="entities", type="entities",
allow_reset=False, allow_reset=True,
embedder_config=embedder_config, embedder_config=embedder_config,
crew=crew, crew=crew,
) )

View File

@@ -9,6 +9,8 @@ from crewai.memory.storage.base_rag_storage import BaseRAGStorage
from crewai.utilities.paths import db_storage_path from crewai.utilities.paths import db_storage_path
from chromadb.api import ClientAPI from chromadb.api import ClientAPI
from chromadb.api.types import validate_embedding_function from chromadb.api.types import validate_embedding_function
from chromadb import Documents, EmbeddingFunction, Embeddings
from typing import cast
@contextlib.contextmanager @contextlib.contextmanager
@@ -47,64 +49,64 @@ class RAGStorage(BaseRAGStorage):
self._initialize_app() self._initialize_app()
def _set_embedder_config(self): def _set_embedder_config(self):
import chromadb.utils.embedding_functions as embedding_functions
if self.embedder_config is None: if self.embedder_config is None:
self.embedder_config = self._create_default_embedding_function() self.embedder_config = self._create_default_embedding_function()
if isinstance(self.embedder_config, dict): if isinstance(self.embedder_config, dict):
provider = self.embedder_config.get("provider") provider = self.embedder_config.get("provider")
config = self.embedder_config.get("config", {}) config = self.embedder_config.get("config", {})
model_name = config.get("model") model_name = config.get("model")
if provider == "openai": if provider == "openai":
import chromadb.utils.embedding_functions as embedding_functions
self.embedder_config = embedding_functions.OpenAIEmbeddingFunction( self.embedder_config = embedding_functions.OpenAIEmbeddingFunction(
api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"), api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"),
model_name=model_name, model_name=model_name,
) )
elif provider == "azure": elif provider == "azure":
from chromadb.utils.embedding_functions.openai_embedding_function import ( self.embedder_config = embedding_functions.OpenAIEmbeddingFunction(
OpenAIEmbeddingFunction,
)
self.embedder_config = OpenAIEmbeddingFunction(
api_key=config.get("api_key"), api_key=config.get("api_key"),
api_base=config.get("api_base"), api_base=config.get("api_base"),
api_type=config.get("api_type"), api_type=config.get("api_type", "azure"),
api_version=config.get("api_version"), api_version=config.get("api_version"),
model_name=model_name, model_name=model_name,
) )
elif provider == "ollama": elif provider == "ollama":
from chromadb.utils.embedding_functions.ollama_embedding_function import ( print("using this ollama")
OllamaEmbeddingFunction, from openai import OpenAI
)
self.embedder_config = OllamaEmbeddingFunction( class OllamaEmbeddingFunction(EmbeddingFunction):
model_name=config.get("model"), def __call__(self, input: Documents) -> Embeddings:
url=config.get("url") or "http://localhost:11434", client = OpenAI(
) base_url="http://localhost:11434/v1",
api_key=config.get("api_key", "ollama"),
)
try:
response = client.embeddings.create(
input=input, model=model_name
)
embeddings = [item.embedding for item in response.data]
return cast(Embeddings, embeddings)
except Exception as e:
raise e
self.embedder_config = OllamaEmbeddingFunction()
elif provider == "vertexai": elif provider == "vertexai":
from chromadb.utils.embedding_functions.google_embedding_function import ( self.embedder_config = (
GoogleVertexEmbeddingFunction, embedding_functions.GoogleVertexEmbeddingFunction(
) model_name=model_name,
api_key=config.get("api_key"),
self.embedder_config = GoogleVertexEmbeddingFunction( )
model_name=model_name,
api_key=config.get("api_key"),
) )
elif provider == "google": elif provider == "google":
from chromadb.utils.embedding_functions.google_embedding_function import ( self.embedder_config = (
GoogleGenerativeAiEmbeddingFunction, embedding_functions.GoogleGenerativeAiEmbeddingFunction(
) model_name=model_name,
api_key=config.get("api_key"),
self.embedder_config = GoogleGenerativeAiEmbeddingFunction( )
model_name=model_name,
api_key=config.get("api_key"),
) )
elif provider == "cohere": elif provider == "cohere":
from chromadb.utils.embedding_functions.cohere_embedding_function import ( self.embedder_config = embedding_functions.CohereEmbeddingFunction(
CohereEmbeddingFunction,
)
self.embedder_config = CohereEmbeddingFunction(
model_name=model_name, model_name=model_name,
api_key=config.get("api_key"), api_key=config.get("api_key"),
) )