fix: use HuggingFaceEmbeddingFunction for embeddings, update keys and add tests (#4005)
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Check Documentation Broken Links / Check broken links (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled

This commit is contained in:
Greyson LaLonde
2025-12-04 18:05:50 -05:00
committed by GitHub
parent 34e09162ba
commit 7fff2b654c
6 changed files with 61 additions and 16 deletions

View File

@@ -515,8 +515,7 @@ crew = Crew(
         "provider": "huggingface",
         "config": {
             "api_key": "your-hf-token", # Optional for public models
-            "model": "sentence-transformers/all-MiniLM-L6-v2",
-            "api_url": "https://api-inference.huggingface.co" # or your custom endpoint
+            "model": "sentence-transformers/all-MiniLM-L6-v2"
         }
     }
 )

View File

@@ -515,8 +515,7 @@ crew = Crew(
         "provider": "huggingface",
         "config": {
             "api_key": "your-hf-token", # Optional for public models
-            "model": "sentence-transformers/all-MiniLM-L6-v2",
-            "api_url": "https://api-inference.huggingface.co" # or your custom endpoint
+            "model": "sentence-transformers/all-MiniLM-L6-v2"
         }
     }
 )

View File

@@ -515,8 +515,7 @@ crew = Crew(
         "provider": "huggingface",
         "config": {
             "api_key": "your-hf-token", # Opcional para modelos públicos
-            "model": "sentence-transformers/all-MiniLM-L6-v2",
-            "api_url": "https://api-inference.huggingface.co" # ou seu endpoint customizado
+            "model": "sentence-transformers/all-MiniLM-L6-v2"
         }
     }
 )

View File

@@ -1,21 +1,35 @@
 """HuggingFace embeddings provider."""
 from chromadb.utils.embedding_functions.huggingface_embedding_function import (
-    HuggingFaceEmbeddingServer,
+    HuggingFaceEmbeddingFunction,
 )
 from pydantic import AliasChoices, Field
 from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
-class HuggingFaceProvider(BaseEmbeddingsProvider[HuggingFaceEmbeddingServer]):
-    """HuggingFace embeddings provider."""
-    embedding_callable: type[HuggingFaceEmbeddingServer] = Field(
-        default=HuggingFaceEmbeddingServer,
+class HuggingFaceProvider(BaseEmbeddingsProvider[HuggingFaceEmbeddingFunction]):
+    """HuggingFace embeddings provider for the HuggingFace Inference API."""
+    embedding_callable: type[HuggingFaceEmbeddingFunction] = Field(
+        default=HuggingFaceEmbeddingFunction,
         description="HuggingFace embedding function class",
     )
-    url: str = Field(
-        description="HuggingFace API URL",
-        validation_alias=AliasChoices("EMBEDDINGS_HUGGINGFACE_URL", "HUGGINGFACE_URL"),
+    api_key: str | None = Field(
+        default=None,
+        description="HuggingFace API key",
+        validation_alias=AliasChoices(
+            "EMBEDDINGS_HUGGINGFACE_API_KEY",
+            "HUGGINGFACE_API_KEY",
+            "HF_TOKEN",
+        ),
+    )
+    model_name: str = Field(
+        default="sentence-transformers/all-MiniLM-L6-v2",
+        description="Model name to use for embeddings",
+        validation_alias=AliasChoices(
+            "EMBEDDINGS_HUGGINGFACE_MODEL_NAME",
+            "HUGGINGFACE_MODEL_NAME",
+            "model",
+        ),
     )

View File

@@ -1,6 +1,6 @@
 """Type definitions for HuggingFace embedding providers."""
-from typing import Literal
+from typing import Annotated, Literal
 from typing_extensions import Required, TypedDict
@@ -8,7 +8,11 @@ from typing_extensions import Required, TypedDict
 class HuggingFaceProviderConfig(TypedDict, total=False):
     """Configuration for HuggingFace provider."""
-    url: str
+    api_key: str
+    model: Annotated[
+        str, "sentence-transformers/all-MiniLM-L6-v2"
+    ] # alias for model_name for backward compat
+    model_name: Annotated[str, "sentence-transformers/all-MiniLM-L6-v2"]
 class HuggingFaceProviderSpec(TypedDict, total=False):

View File

@@ -99,6 +99,36 @@ class TestEmbeddingFactory:
             "crewai.rag.embeddings.providers.ollama.ollama_provider.OllamaProvider"
         )
+    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
+    def test_build_embedder_huggingface(self, mock_import):
+        """Test building HuggingFace embedder."""
+        mock_provider_class = MagicMock()
+        mock_provider_instance = MagicMock()
+        mock_embedding_function = MagicMock()
+        mock_import.return_value = mock_provider_class
+        mock_provider_class.return_value = mock_provider_instance
+        mock_provider_instance.embedding_callable.return_value = mock_embedding_function
+        config = {
+            "provider": "huggingface",
+            "config": {
+                "api_key": "hf-test-key",
+                "model": "sentence-transformers/all-MiniLM-L6-v2",
+            },
+        }
+        build_embedder(config)
+        mock_import.assert_called_once_with(
+            "crewai.rag.embeddings.providers.huggingface.huggingface_provider.HuggingFaceProvider"
+        )
+        mock_provider_class.assert_called_once()
+        call_kwargs = mock_provider_class.call_args.kwargs
+        assert call_kwargs["api_key"] == "hf-test-key"
+        assert call_kwargs["model"] == "sentence-transformers/all-MiniLM-L6-v2"
     @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
     def test_build_embedder_cohere(self, mock_import):
         """Test building Cohere embedder."""