mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-27 00:58:13 +00:00
fix: update embedding functions to inherit from chromadb callable
This commit is contained in:
@@ -140,3 +140,10 @@ class EmbeddingFunction(Protocol[D]):
|
|||||||
return validate_embeddings(normalized)
|
return validate_embeddings(normalized)
|
||||||
|
|
||||||
cls.__call__ = wrapped_call # type: ignore[method-assign]
|
cls.__call__ = wrapped_call # type: ignore[method-assign]
|
||||||
|
|
||||||
|
def embed_query(self, input: D) -> Embeddings:
|
||||||
|
"""
|
||||||
|
Get the embeddings for a query input.
|
||||||
|
This method is optional, and if not implemented, the default behavior is to call __call__.
|
||||||
|
"""
|
||||||
|
return self.__call__(input=input)
|
||||||
|
|||||||
@@ -2,10 +2,9 @@
|
|||||||
|
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
|
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
|
||||||
from typing_extensions import Unpack
|
from typing_extensions import Unpack
|
||||||
|
|
||||||
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
|
||||||
from crewai.rag.core.types import Documents, Embeddings
|
|
||||||
from crewai.rag.embeddings.providers.ibm.types import WatsonXProviderConfig
|
from crewai.rag.embeddings.providers.ibm.types import WatsonXProviderConfig
|
||||||
|
|
||||||
|
|
||||||
@@ -18,8 +17,14 @@ class WatsonXEmbeddingFunction(EmbeddingFunction[Documents]):
|
|||||||
Args:
|
Args:
|
||||||
**kwargs: Configuration parameters for WatsonX Embeddings and Credentials.
|
**kwargs: Configuration parameters for WatsonX Embeddings and Credentials.
|
||||||
"""
|
"""
|
||||||
|
super().__init__(**kwargs)
|
||||||
self._config = kwargs
|
self._config = kwargs
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def name() -> str:
|
||||||
|
"""Return the name of the embedding function for ChromaDB compatibility."""
|
||||||
|
return "watsonx"
|
||||||
|
|
||||||
def __call__(self, input: Documents) -> Embeddings:
|
def __call__(self, input: Documents) -> Embeddings:
|
||||||
"""Generate embeddings for input documents.
|
"""Generate embeddings for input documents.
|
||||||
|
|
||||||
|
|||||||
@@ -2,10 +2,9 @@
|
|||||||
|
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
|
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
|
||||||
from typing_extensions import Unpack
|
from typing_extensions import Unpack
|
||||||
|
|
||||||
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
|
||||||
from crewai.rag.core.types import Documents, Embeddings
|
|
||||||
from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderConfig
|
from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderConfig
|
||||||
|
|
||||||
|
|
||||||
@@ -33,6 +32,11 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]):
|
|||||||
timeout=kwargs.get("timeout"),
|
timeout=kwargs.get("timeout"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def name() -> str:
|
||||||
|
"""Return the name of the embedding function for ChromaDB compatibility."""
|
||||||
|
return "voyageai"
|
||||||
|
|
||||||
def __call__(self, input: Documents) -> Embeddings:
|
def __call__(self, input: Documents) -> Embeddings:
|
||||||
"""Generate embeddings for input documents.
|
"""Generate embeddings for input documents.
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,10 @@ from crewai.rag.embeddings.providers.google.types import (
|
|||||||
VertexAIProviderSpec,
|
VertexAIProviderSpec,
|
||||||
)
|
)
|
||||||
from crewai.rag.embeddings.providers.huggingface.types import HuggingFaceProviderSpec
|
from crewai.rag.embeddings.providers.huggingface.types import HuggingFaceProviderSpec
|
||||||
from crewai.rag.embeddings.providers.ibm.types import WatsonXProviderSpec
|
from crewai.rag.embeddings.providers.ibm.types import (
|
||||||
|
WatsonProviderSpec,
|
||||||
|
WatsonXProviderSpec,
|
||||||
|
)
|
||||||
from crewai.rag.embeddings.providers.instructor.types import InstructorProviderSpec
|
from crewai.rag.embeddings.providers.instructor.types import InstructorProviderSpec
|
||||||
from crewai.rag.embeddings.providers.jina.types import JinaProviderSpec
|
from crewai.rag.embeddings.providers.jina.types import JinaProviderSpec
|
||||||
from crewai.rag.embeddings.providers.microsoft.types import AzureProviderSpec
|
from crewai.rag.embeddings.providers.microsoft.types import AzureProviderSpec
|
||||||
@@ -44,6 +47,7 @@ ProviderSpec = (
|
|||||||
| Text2VecProviderSpec
|
| Text2VecProviderSpec
|
||||||
| VertexAIProviderSpec
|
| VertexAIProviderSpec
|
||||||
| VoyageAIProviderSpec
|
| VoyageAIProviderSpec
|
||||||
|
| WatsonProviderSpec # Deprecated, use WatsonXProviderSpec
|
||||||
| WatsonXProviderSpec
|
| WatsonXProviderSpec
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user