diff --git a/src/crewai/memory/entity/entity_memory.py b/src/crewai/memory/entity/entity_memory.py index d4f8d9aae..4de0594c7 100644 --- a/src/crewai/memory/entity/entity_memory.py +++ b/src/crewai/memory/entity/entity_memory.py @@ -16,7 +16,7 @@ class EntityMemory(Memory): if storage else RAGStorage( type="entities", - allow_reset=False, + allow_reset=True, embedder_config=embedder_config, crew=crew, ) diff --git a/src/crewai/memory/storage/rag_storage.py b/src/crewai/memory/storage/rag_storage.py index ca67492c4..56b7373d9 100644 --- a/src/crewai/memory/storage/rag_storage.py +++ b/src/crewai/memory/storage/rag_storage.py @@ -9,6 +9,8 @@ from crewai.memory.storage.base_rag_storage import BaseRAGStorage from crewai.utilities.paths import db_storage_path from chromadb.api import ClientAPI from chromadb.api.types import validate_embedding_function +from chromadb import Documents, EmbeddingFunction, Embeddings +from typing import cast @contextlib.contextmanager @@ -47,64 +49,64 @@ class RAGStorage(BaseRAGStorage): self._initialize_app() def _set_embedder_config(self): + import chromadb.utils.embedding_functions as embedding_functions + if self.embedder_config is None: self.embedder_config = self._create_default_embedding_function() + if isinstance(self.embedder_config, dict): provider = self.embedder_config.get("provider") config = self.embedder_config.get("config", {}) model_name = config.get("model") if provider == "openai": - import chromadb.utils.embedding_functions as embedding_functions - self.embedder_config = embedding_functions.OpenAIEmbeddingFunction( api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"), model_name=model_name, ) elif provider == "azure": - from chromadb.utils.embedding_functions.openai_embedding_function import ( - OpenAIEmbeddingFunction, - ) - - self.embedder_config = OpenAIEmbeddingFunction( + self.embedder_config = embedding_functions.OpenAIEmbeddingFunction( api_key=config.get("api_key"), api_base=config.get("api_base"), - api_type=config.get("api_type"), + api_type=config.get("api_type", "azure"), api_version=config.get("api_version"), model_name=model_name, ) elif provider == "ollama": - from chromadb.utils.embedding_functions.ollama_embedding_function import ( - OllamaEmbeddingFunction, - ) + print("using this ollama") + from openai import OpenAI - self.embedder_config = OllamaEmbeddingFunction( - model_name=config.get("model"), - url=config.get("url") or "http://localhost:11434", - ) + class OllamaEmbeddingFunction(EmbeddingFunction): + def __call__(self, input: Documents) -> Embeddings: + client = OpenAI( + base_url="http://localhost:11434/v1", + api_key=config.get("api_key", "ollama"), + ) + try: + response = client.embeddings.create( + input=input, model=model_name + ) + embeddings = [item.embedding for item in response.data] + return cast(Embeddings, embeddings) + except Exception as e: + raise e + + self.embedder_config = OllamaEmbeddingFunction() elif provider == "vertexai": - from chromadb.utils.embedding_functions.google_embedding_function import ( - GoogleVertexEmbeddingFunction, - ) - - self.embedder_config = GoogleVertexEmbeddingFunction( - model_name=model_name, - api_key=config.get("api_key"), + self.embedder_config = ( + embedding_functions.GoogleVertexEmbeddingFunction( + model_name=model_name, + api_key=config.get("api_key"), + ) ) elif provider == "google": - from chromadb.utils.embedding_functions.google_embedding_function import ( - GoogleGenerativeAiEmbeddingFunction, - ) - - self.embedder_config = GoogleGenerativeAiEmbeddingFunction( - model_name=model_name, - api_key=config.get("api_key"), + self.embedder_config = ( + embedding_functions.GoogleGenerativeAiEmbeddingFunction( + model_name=model_name, + api_key=config.get("api_key"), + ) ) elif provider == "cohere": - from chromadb.utils.embedding_functions.cohere_embedding_function import ( - CohereEmbeddingFunction, - ) - - self.embedder_config = CohereEmbeddingFunction( + self.embedder_config = embedding_functions.CohereEmbeddingFunction( model_name=model_name, api_key=config.get("api_key"), )