fix: update embedding functions to inherit from chromadb callable

This commit is contained in:
Greyson LaLonde
2025-09-26 12:25:19 -04:00
committed by GitHub
parent 12fa7e2ff1
commit 73e932bfee
4 changed files with 25 additions and 5 deletions

View File

@@ -140,3 +140,10 @@ class EmbeddingFunction(Protocol[D]):
return validate_embeddings(normalized)
cls.__call__ = wrapped_call # type: ignore[method-assign]
def embed_query(self, input: D) -> Embeddings:
"""
Get the embeddings for a query input.
This method is optional, and if not implemented, the default behavior is to call __call__.
"""
return self.__call__(input=input)

View File

@@ -2,10 +2,9 @@
from typing import cast
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
from typing_extensions import Unpack
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
from crewai.rag.core.types import Documents, Embeddings
from crewai.rag.embeddings.providers.ibm.types import WatsonXProviderConfig
@@ -18,8 +17,14 @@ class WatsonXEmbeddingFunction(EmbeddingFunction[Documents]):
Args:
**kwargs: Configuration parameters for WatsonX Embeddings and Credentials.
"""
super().__init__(**kwargs)
self._config = kwargs
@staticmethod
def name() -> str:
"""Return the name of the embedding function for ChromaDB compatibility."""
return "watsonx"
def __call__(self, input: Documents) -> Embeddings:
"""Generate embeddings for input documents.

View File

@@ -2,10 +2,9 @@
from typing import cast
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
from typing_extensions import Unpack
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
from crewai.rag.core.types import Documents, Embeddings
from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderConfig
@@ -33,6 +32,11 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]):
timeout=kwargs.get("timeout"),
)
@staticmethod
def name() -> str:
"""Return the name of the embedding function for ChromaDB compatibility."""
return "voyageai"
def __call__(self, input: Documents) -> Embeddings:
"""Generate embeddings for input documents.

View File

@@ -11,7 +11,10 @@ from crewai.rag.embeddings.providers.google.types import (
VertexAIProviderSpec,
)
from crewai.rag.embeddings.providers.huggingface.types import HuggingFaceProviderSpec
from crewai.rag.embeddings.providers.ibm.types import WatsonXProviderSpec
from crewai.rag.embeddings.providers.ibm.types import (
WatsonProviderSpec,
WatsonXProviderSpec,
)
from crewai.rag.embeddings.providers.instructor.types import InstructorProviderSpec
from crewai.rag.embeddings.providers.jina.types import JinaProviderSpec
from crewai.rag.embeddings.providers.microsoft.types import AzureProviderSpec
@@ -44,6 +47,7 @@ ProviderSpec = (
| Text2VecProviderSpec
| VertexAIProviderSpec
| VoyageAIProviderSpec
| WatsonProviderSpec # Deprecated, use WatsonXProviderSpec
| WatsonXProviderSpec
)