diff --git a/src/crewai/rag/core/base_embeddings_callable.py b/src/crewai/rag/core/base_embeddings_callable.py index 090a3d026..85fe88584 100644 --- a/src/crewai/rag/core/base_embeddings_callable.py +++ b/src/crewai/rag/core/base_embeddings_callable.py @@ -140,3 +140,10 @@ class EmbeddingFunction(Protocol[D]): return validate_embeddings(normalized) cls.__call__ = wrapped_call # type: ignore[method-assign] + + def embed_query(self, input: D) -> Embeddings: + """ + Get the embeddings for a query input. + This method is optional, and if not implemented, the default behavior is to call __call__. + """ + return self.__call__(input=input) diff --git a/src/crewai/rag/embeddings/providers/ibm/embedding_callable.py b/src/crewai/rag/embeddings/providers/ibm/embedding_callable.py index dfc487750..56198987d 100644 --- a/src/crewai/rag/embeddings/providers/ibm/embedding_callable.py +++ b/src/crewai/rag/embeddings/providers/ibm/embedding_callable.py @@ -2,10 +2,9 @@ from typing import cast +from chromadb.api.types import Documents, EmbeddingFunction, Embeddings from typing_extensions import Unpack -from crewai.rag.core.base_embeddings_callable import EmbeddingFunction -from crewai.rag.core.types import Documents, Embeddings from crewai.rag.embeddings.providers.ibm.types import WatsonXProviderConfig @@ -18,8 +17,14 @@ class WatsonXEmbeddingFunction(EmbeddingFunction[Documents]): Args: **kwargs: Configuration parameters for WatsonX Embeddings and Credentials. """ + super().__init__(**kwargs) self._config = kwargs + @staticmethod + def name() -> str: + """Return the name of the embedding function for ChromaDB compatibility.""" + return "watsonx" + def __call__(self, input: Documents) -> Embeddings: """Generate embeddings for input documents. diff --git a/src/crewai/rag/embeddings/providers/voyageai/embedding_callable.py b/src/crewai/rag/embeddings/providers/voyageai/embedding_callable.py index 71610b839..f7d7f7103 100644 --- a/src/crewai/rag/embeddings/providers/voyageai/embedding_callable.py +++ b/src/crewai/rag/embeddings/providers/voyageai/embedding_callable.py @@ -2,10 +2,9 @@ from typing import cast +from chromadb.api.types import Documents, EmbeddingFunction, Embeddings from typing_extensions import Unpack -from crewai.rag.core.base_embeddings_callable import EmbeddingFunction -from crewai.rag.core.types import Documents, Embeddings from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderConfig @@ -33,6 +32,11 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]): timeout=kwargs.get("timeout"), ) + @staticmethod + def name() -> str: + """Return the name of the embedding function for ChromaDB compatibility.""" + return "voyageai" + def __call__(self, input: Documents) -> Embeddings: """Generate embeddings for input documents. diff --git a/src/crewai/rag/embeddings/types.py b/src/crewai/rag/embeddings/types.py index 7400acd58..f727cd220 100644 --- a/src/crewai/rag/embeddings/types.py +++ b/src/crewai/rag/embeddings/types.py @@ -11,7 +11,10 @@ from crewai.rag.embeddings.providers.google.types import ( VertexAIProviderSpec, ) from crewai.rag.embeddings.providers.huggingface.types import HuggingFaceProviderSpec -from crewai.rag.embeddings.providers.ibm.types import WatsonXProviderSpec +from crewai.rag.embeddings.providers.ibm.types import ( + WatsonProviderSpec, + WatsonXProviderSpec, +) from crewai.rag.embeddings.providers.instructor.types import InstructorProviderSpec from crewai.rag.embeddings.providers.jina.types import JinaProviderSpec from crewai.rag.embeddings.providers.microsoft.types import AzureProviderSpec @@ -44,6 +47,7 @@ ProviderSpec = ( | Text2VecProviderSpec | VertexAIProviderSpec | VoyageAIProviderSpec + | WatsonProviderSpec # Deprecated, use WatsonXProviderSpec | WatsonXProviderSpec )