diff --git a/src/crewai/memory/storage/rag_storage.py b/src/crewai/memory/storage/rag_storage.py
index db98c0036..67e8ef515 100644
--- a/src/crewai/memory/storage/rag_storage.py
+++ b/src/crewai/memory/storage/rag_storage.py
@@ -4,13 +4,13 @@ import logging
 import os
 import shutil
 import uuid
-from typing import Any, Dict, List, Optional
-from crewai.memory.storage.base_rag_storage import BaseRAGStorage
-from crewai.utilities.paths import db_storage_path
+from typing import Any, Dict, List, Optional, cast
+
+from chromadb import Documents, EmbeddingFunction, Embeddings
 from chromadb.api import ClientAPI
 from chromadb.api.types import validate_embedding_function
-from chromadb import Documents, EmbeddingFunction, Embeddings
-from typing import cast
+from crewai.memory.storage.base_rag_storage import BaseRAGStorage
+from crewai.utilities.paths import db_storage_path
 
 
 @contextlib.contextmanager
@@ -21,9 +21,11 @@ def suppress_logging(
     logger = logging.getLogger(logger_name)
     original_level = logger.getEffectiveLevel()
     logger.setLevel(level)
-    with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
-        io.StringIO()
-    ), contextlib.suppress(UserWarning):
+    with (
+        contextlib.redirect_stdout(io.StringIO()),
+        contextlib.redirect_stderr(io.StringIO()),
+        contextlib.suppress(UserWarning),
+    ):
         yield
     logger.setLevel(original_level)
 
@@ -113,12 +115,58 @@ class RAGStorage(BaseRAGStorage):
             self.embedder_config = embedding_functions.HuggingFaceEmbeddingServer(
                 url=config.get("api_url"),
             )
+        elif provider == "watson":
+
+            # https://ibm.github.io/watsonx-ai-python-sdk/fm_embeddings.html
+            try:
+                import ibm_watsonx_ai.foundation_models as watson_models
+                from ibm_watsonx_ai import Credentials
+                from ibm_watsonx_ai.metanames import (
+                    EmbedTextParamsMetaNames as EmbedParams,
+                )
+            except ImportError as e:
+                raise ImportError(
+                    "IBM Watson dependencies are not installed. Please install them to use Watson embedding."
+                ) from e
+
+            class WatsonEmbeddingFunction(EmbeddingFunction):
+                def __call__(self, input: Documents) -> watson_models.Embeddings:
+                    if isinstance(input, str):
+                        input = [input]
+
+                    embed_params = {
+                        EmbedParams.TRUNCATE_INPUT_TOKENS: 3,
+                        EmbedParams.RETURN_OPTIONS: {"input_text": True},
+                    }
+
+                    embedding = watson_models.Embeddings(
+                        model_id=config.get("model"),
+                        params=embed_params,
+                        credentials=Credentials(
+                            api_key=config.get("api_key"), url=config.get("api_url")
+                        ),
+                        project_id=config.get("project_id"),
+                    )
+
+                    try:
+                        print("Embedding input:", input)
+                        embeddings = embedding.embed_documents(input)
+                        print("Embedding output:", embeddings)
+                        casted = cast(Embeddings, embeddings)
+                        print("Casted:", casted)
+                        return casted
+
+                    except Exception as e:
+                        print("Error during Watson embedding:", e)
+                        raise e
+
+            self.embedder_config = WatsonEmbeddingFunction()
         else:
             raise Exception(
-                f"Unsupported embedding provider: {provider}, supported providers: [openai, azure, ollama, vertexai, google, cohere, huggingface]"
+                f"Unsupported embedding provider: {provider}, supported providers: [openai, azure, ollama, vertexai, google, cohere, huggingface, watson]"
             )
         else:
-            validate_embedding_function(self.embedder_config)  # type: ignore # used for validating embedder_config if defined a embedding function/class
+            validate_embedding_function(self.embedder_config)
             self.embedder_config = self.embedder_config
 
     def _initialize_app(self):
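For context outside the patch itself: the new `watson` provider turns the user-supplied embedder config (`model`, `api_key`, `api_url`, `project_id`) into a chromadb `EmbeddingFunction`, which is the same plug-in point the other providers in this file use. The sketch below is a hypothetical, self-contained illustration of that plug-in point, not code from this PR: it swaps the watsonx.ai call for a deterministic stand-in embedder so it runs without credentials, and the `StandInEmbeddingFunction` name, collection name, and documents are made up. In real use, `__call__` would invoke `ibm_watsonx_ai.foundation_models.Embeddings.embed_documents()` exactly as the diff does.

```python
# Minimal sketch (not part of the patch): how a chromadb EmbeddingFunction
# subclass such as WatsonEmbeddingFunction plugs into a Chroma collection.
# A deterministic stand-in embedder is used so the snippet runs without
# watsonx.ai credentials.
import hashlib

import chromadb
from chromadb import Documents, EmbeddingFunction, Embeddings
from chromadb.api.types import validate_embedding_function


class StandInEmbeddingFunction(EmbeddingFunction):
    """Hypothetical stand-in with the same interface the patch implements."""

    def __call__(self, input: Documents) -> Embeddings:
        # chromadb may pass a single string; normalize to a list, as the patch does.
        if isinstance(input, str):
            input = [input]
        # Derive a fixed-size pseudo-embedding from a hash of each document.
        return [
            [b / 255.0 for b in hashlib.sha256(text.encode()).digest()[:8]]
            for text in input
        ]


ef = StandInEmbeddingFunction()
# Same check RAGStorage runs when a caller supplies a ready-made embedding function.
validate_embedding_function(ef)

client = chromadb.Client()
collection = client.create_collection(name="memories", embedding_function=ef)
collection.add(documents=["agent observed X", "agent observed Y"], ids=["1", "2"])
print(collection.query(query_texts=["what did the agent observe?"], n_results=1))
```

The `validate_embedding_function` call mirrors the check at the end of the hunk above, which RAGStorage applies when `embedder_config` is a callable rather than a provider name.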