mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-14 18:48:29 +00:00
fix: update embedding functions to inherit from chromadb callable
This commit is contained in:
@@ -140,3 +140,10 @@ class EmbeddingFunction(Protocol[D]):
|
||||
return validate_embeddings(normalized)
|
||||
|
||||
cls.__call__ = wrapped_call # type: ignore[method-assign]
|
||||
|
||||
def embed_query(self, input: D) -> Embeddings:
|
||||
"""
|
||||
Get the embeddings for a query input.
|
||||
This method is optional, and if not implemented, the default behavior is to call __call__.
|
||||
"""
|
||||
return self.__call__(input=input)
|
||||
|
||||
@@ -2,10 +2,9 @@
|
||||
|
||||
from typing import cast
|
||||
|
||||
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
||||
from crewai.rag.core.types import Documents, Embeddings
|
||||
from crewai.rag.embeddings.providers.ibm.types import WatsonXProviderConfig
|
||||
|
||||
|
||||
@@ -18,8 +17,14 @@ class WatsonXEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
Args:
|
||||
**kwargs: Configuration parameters for WatsonX Embeddings and Credentials.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._config = kwargs
|
||||
|
||||
@staticmethod
|
||||
def name() -> str:
|
||||
"""Return the name of the embedding function for ChromaDB compatibility."""
|
||||
return "watsonx"
|
||||
|
||||
def __call__(self, input: Documents) -> Embeddings:
|
||||
"""Generate embeddings for input documents.
|
||||
|
||||
|
||||
@@ -2,10 +2,9 @@
|
||||
|
||||
from typing import cast
|
||||
|
||||
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
||||
from crewai.rag.core.types import Documents, Embeddings
|
||||
from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderConfig
|
||||
|
||||
|
||||
@@ -33,6 +32,11 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
timeout=kwargs.get("timeout"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def name() -> str:
|
||||
"""Return the name of the embedding function for ChromaDB compatibility."""
|
||||
return "voyageai"
|
||||
|
||||
def __call__(self, input: Documents) -> Embeddings:
|
||||
"""Generate embeddings for input documents.
|
||||
|
||||
|
||||
@@ -11,7 +11,10 @@ from crewai.rag.embeddings.providers.google.types import (
|
||||
VertexAIProviderSpec,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.huggingface.types import HuggingFaceProviderSpec
|
||||
from crewai.rag.embeddings.providers.ibm.types import WatsonXProviderSpec
|
||||
from crewai.rag.embeddings.providers.ibm.types import (
|
||||
WatsonProviderSpec,
|
||||
WatsonXProviderSpec,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.instructor.types import InstructorProviderSpec
|
||||
from crewai.rag.embeddings.providers.jina.types import JinaProviderSpec
|
||||
from crewai.rag.embeddings.providers.microsoft.types import AzureProviderSpec
|
||||
@@ -44,6 +47,7 @@ ProviderSpec = (
|
||||
| Text2VecProviderSpec
|
||||
| VertexAIProviderSpec
|
||||
| VoyageAIProviderSpec
|
||||
| WatsonProviderSpec # Deprecated, use WatsonXProviderSpec
|
||||
| WatsonXProviderSpec
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user