mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-02 15:52:34 +00:00
fix: use HuggingFaceEmbeddingFunction for embeddings, update keys and add tests (#4005)
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Check Documentation Broken Links / Check broken links (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Check Documentation Broken Links / Check broken links (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
This commit is contained in:
@@ -1,21 +1,35 @@
|
||||
"""HuggingFace embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
|
||||
HuggingFaceEmbeddingServer,
|
||||
HuggingFaceEmbeddingFunction,
|
||||
)
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class HuggingFaceProvider(BaseEmbeddingsProvider[HuggingFaceEmbeddingServer]):
|
||||
"""HuggingFace embeddings provider."""
|
||||
class HuggingFaceProvider(BaseEmbeddingsProvider[HuggingFaceEmbeddingFunction]):
|
||||
"""HuggingFace embeddings provider for the HuggingFace Inference API."""
|
||||
|
||||
embedding_callable: type[HuggingFaceEmbeddingServer] = Field(
|
||||
default=HuggingFaceEmbeddingServer,
|
||||
embedding_callable: type[HuggingFaceEmbeddingFunction] = Field(
|
||||
default=HuggingFaceEmbeddingFunction,
|
||||
description="HuggingFace embedding function class",
|
||||
)
|
||||
url: str = Field(
|
||||
description="HuggingFace API URL",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_HUGGINGFACE_URL", "HUGGINGFACE_URL"),
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="HuggingFace API key",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_HUGGINGFACE_API_KEY",
|
||||
"HUGGINGFACE_API_KEY",
|
||||
"HF_TOKEN",
|
||||
),
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="sentence-transformers/all-MiniLM-L6-v2",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_HUGGINGFACE_MODEL_NAME",
|
||||
"HUGGINGFACE_MODEL_NAME",
|
||||
"model",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Type definitions for HuggingFace embedding providers."""
|
||||
|
||||
from typing import Literal
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
@@ -8,7 +8,11 @@ from typing_extensions import Required, TypedDict
|
||||
class HuggingFaceProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for HuggingFace provider."""
|
||||
|
||||
url: str
|
||||
api_key: str
|
||||
model: Annotated[
|
||||
str, "sentence-transformers/all-MiniLM-L6-v2"
|
||||
] # alias for model_name for backward compat
|
||||
model_name: Annotated[str, "sentence-transformers/all-MiniLM-L6-v2"]
|
||||
|
||||
|
||||
class HuggingFaceProviderSpec(TypedDict, total=False):
|
||||
|
||||
Reference in New Issue
Block a user