diff --git a/docs/concepts/memory.mdx b/docs/concepts/memory.mdx
index b07096442..de1fb3510 100644
--- a/docs/concepts/memory.mdx
+++ b/docs/concepts/memory.mdx
@@ -105,9 +105,48 @@ my_crew = Crew(
     process=Process.sequential,
     memory=True,
     verbose=True,
-    embedder=embedding_functions.OpenAIEmbeddingFunction(
-        api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
-    )
+    embedder={
+        "provider": "openai",
+        "config": {
+            "model": 'text-embedding-3-small'
+        }
+    }
+)
+```
+Alternatively, you can directly pass the OpenAIEmbeddingFunction to the embedder parameter.
+
+Example:
+```python Code
+from crewai import Crew, Agent, Task, Process
+from chromadb.utils.embedding_functions.openai_embedding_function import OpenAIEmbeddingFunction
+
+my_crew = Crew(
+    agents=[...],
+    tasks=[...],
+    process=Process.sequential,
+    memory=True,
+    verbose=True,
+    embedder=OpenAIEmbeddingFunction(api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"),
+)
+```
+
+### Using Ollama embeddings
+
+```python Code
+from crewai import Crew, Agent, Task, Process
+
+my_crew = Crew(
+    agents=[...],
+    tasks=[...],
+    process=Process.sequential,
+    memory=True,
+    verbose=True,
+    embedder={
+        "provider": "ollama",
+        "config": {
+            "model": "mxbai-embed-large"
+        }
+    }
 )
 ```
 
@@ -122,10 +161,13 @@ my_crew = Crew(
     process=Process.sequential,
     memory=True,
     verbose=True,
-    embedder=embedding_functions.OpenAIEmbeddingFunction(
-        api_key=os.getenv("OPENAI_API_KEY"),
-        model_name="text-embedding-ada-002"
-    )
+    embedder={
+        "provider": "google",
+        "config": {
+            "api_key": "",
+            "model_name": ""
+        }
+    }
 )
 ```
 
@@ -181,10 +223,32 @@ my_crew = Crew(
     process=Process.sequential,
     memory=True,
     verbose=True,
-    embedder=embedding_functions.CohereEmbeddingFunction(
-        api_key=YOUR_API_KEY,
-        model_name=""
-    )
+    embedder={
+        "provider": "cohere",
+        "config": {
+            "api_key": "YOUR_API_KEY",
+            "model_name": ""
+        }
+    }
+)
+```
+### Using HuggingFace embeddings
+
+```python Code
+from crewai import Crew, Agent, Task, Process
+
+my_crew = Crew(
+    agents=[...],
+    tasks=[...],
+    process=Process.sequential,
+    memory=True,
+    verbose=True,
+    embedder={
+        "provider": "huggingface",
+        "config": {
+            "api_url": "",
+        }
+    }
 )
 ```
 
diff --git a/src/crewai/memory/entity/entity_memory.py b/src/crewai/memory/entity/entity_memory.py
index d4f8d9aae..4de0594c7 100644
--- a/src/crewai/memory/entity/entity_memory.py
+++ b/src/crewai/memory/entity/entity_memory.py
@@ -16,7 +16,7 @@ class EntityMemory(Memory):
             if storage
             else RAGStorage(
                 type="entities",
-                allow_reset=False,
+                allow_reset=True,
                 embedder_config=embedder_config,
                 crew=crew,
             )
diff --git a/src/crewai/memory/storage/rag_storage.py b/src/crewai/memory/storage/rag_storage.py
index 8d45d9f5a..db98c0036 100644
--- a/src/crewai/memory/storage/rag_storage.py
+++ b/src/crewai/memory/storage/rag_storage.py
@@ -8,6 +8,9 @@ from typing import Any, Dict, List, Optional
 from crewai.memory.storage.base_rag_storage import BaseRAGStorage
 from crewai.utilities.paths import db_storage_path
 from chromadb.api import ClientAPI
+from chromadb.api.types import validate_embedding_function
+from chromadb import Documents, EmbeddingFunction, Embeddings
+from typing import cast
 
 
 @contextlib.contextmanager
@@ -41,16 +44,93 @@ class RAGStorage(BaseRAGStorage):
 
         self.agents = agents
         self.type = type
-        self.embedder_config = embedder_config or self._create_embedding_function()
+        self.allow_reset = allow_reset
         self._initialize_app()
 
+    def _set_embedder_config(self):
+        import chromadb.utils.embedding_functions as embedding_functions
+
+        if self.embedder_config is None:
+            self.embedder_config = self._create_default_embedding_function()
+
+        if isinstance(self.embedder_config, dict):
+            provider = self.embedder_config.get("provider")
+            config = self.embedder_config.get("config", {})
+            model_name = config.get("model")
+            if provider == "openai":
+                self.embedder_config = embedding_functions.OpenAIEmbeddingFunction(
+                    api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"),
+                    model_name=model_name,
+                )
+            elif provider == "azure":
+                self.embedder_config = embedding_functions.OpenAIEmbeddingFunction(
+                    api_key=config.get("api_key"),
+                    api_base=config.get("api_base"),
+                    api_type=config.get("api_type", "azure"),
+                    api_version=config.get("api_version"),
+                    model_name=model_name,
+                )
+            elif provider == "ollama":
+                from openai import OpenAI
+
+                class OllamaEmbeddingFunction(EmbeddingFunction):
+                    def __call__(self, input: Documents) -> Embeddings:
+                        client = OpenAI(
+                            base_url="http://localhost:11434/v1",
+                            api_key=config.get("api_key", "ollama"),
+                        )
+                        try:
+                            response = client.embeddings.create(
+                                input=input, model=model_name
+                            )
+                            embeddings = [item.embedding for item in response.data]
+                            return cast(Embeddings, embeddings)
+                        except Exception as e:
+                            raise e
+
+                self.embedder_config = OllamaEmbeddingFunction()
+            elif provider == "vertexai":
+                self.embedder_config = (
+                    embedding_functions.GoogleVertexEmbeddingFunction(
+                        model_name=model_name,
+                        api_key=config.get("api_key"),
+                    )
+                )
+            elif provider == "google":
+                self.embedder_config = (
+                    embedding_functions.GoogleGenerativeAiEmbeddingFunction(
+                        model_name=model_name,
+                        api_key=config.get("api_key"),
+                    )
+                )
+            elif provider == "cohere":
+                self.embedder_config = embedding_functions.CohereEmbeddingFunction(
+                    model_name=model_name,
+                    api_key=config.get("api_key"),
+                )
+            elif provider == "huggingface":
+                self.embedder_config = embedding_functions.HuggingFaceEmbeddingServer(
+                    url=config.get("api_url"),
+                )
+            else:
+                raise Exception(
+                    f"Unsupported embedding provider: {provider}, supported providers: [openai, azure, ollama, vertexai, google, cohere, huggingface]"
+                )
+        else:
+            validate_embedding_function(self.embedder_config)  # type: ignore # used for validating embedder_config if defined a embedding function/class
+            self.embedder_config = self.embedder_config
+
     def _initialize_app(self):
         import chromadb
+        from chromadb.config import Settings
 
+        self._set_embedder_config()
         chroma_client = chromadb.PersistentClient(
-            path=f"{db_storage_path()}/{self.type}/{self.agents}"
+            path=f"{db_storage_path()}/{self.type}/{self.agents}",
+            settings=Settings(allow_reset=self.allow_reset),
         )
+
         self.app = chroma_client
 
         try:
@@ -122,11 +202,15 @@ class RAGStorage(BaseRAGStorage):
             if self.app:
                 self.app.reset()
         except Exception as e:
-            raise Exception(
-                f"An error occurred while resetting the {self.type} memory: {e}"
-            )
+            if "attempt to write a readonly database" in str(e):
+                # Ignore this specific error
+                pass
+            else:
+                raise Exception(
+                    f"An error occurred while resetting the {self.type} memory: {e}"
+                )
 
-    def _create_embedding_function(self):
+    def _create_default_embedding_function(self):
         import chromadb.utils.embedding_functions as embedding_functions
 
         return embedding_functions.OpenAIEmbeddingFunction(
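
The `_set_embedder_config` dispatch added above also handles an `azure` provider that the documentation hunks do not illustrate. Below is a minimal sketch of what such a configuration might look like, assuming only the keys that branch actually reads (`api_key`, `api_base`, `api_type`, `api_version`, and `model`); the endpoint, API version, and model values are placeholders, not values taken from this patch.

```python Code
from crewai import Crew, Agent, Task, Process

# Sketch only: keys mirror what RAGStorage._set_embedder_config reads for
# provider == "azure"; all concrete values below are placeholders.
my_crew = Crew(
    agents=[...],
    tasks=[...],
    process=Process.sequential,
    memory=True,
    verbose=True,
    embedder={
        "provider": "azure",
        "config": {
            "api_key": "YOUR_AZURE_OPENAI_KEY",                   # placeholder
            "api_base": "https://<your-resource>.openai.azure.com",  # placeholder endpoint
            "api_type": "azure",                                   # default used by the dispatch if omitted
            "api_version": "2023-05-15",                           # placeholder API version
            "model": "text-embedding-3-small",                     # note: the dispatch reads "model", not "model_name"
        },
    },
)
```

Note that the `azure` branch falls back to `api_type="azure"` when the key is not supplied, and the model name is taken from the `"model"` entry of the config dict, since the dispatch calls `config.get("model")`.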