mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-16 04:18:35 +00:00
fix: HuggingFace embedder configuration validation error (#3995)
- Update HuggingFaceProvider to use HuggingFaceEmbeddingFunction instead of HuggingFaceEmbeddingServer for HuggingFace Inference API support
- Add api_key, model_name, and api_key_env_var fields to match documented config
- Accept api_url for compatibility but exclude from model_dump (not used by HuggingFace Inference API)
- Add validation aliases for model (maps to model_name) and environment variables
- Update HuggingFaceProviderConfig TypedDict with new fields
- Add comprehensive tests for HuggingFace provider configuration
- Regenerate uv.lock (was corrupted)

Fixes #3995

Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -1,21 +1,66 @@
|
|||||||
"""HuggingFace embeddings provider."""
|
"""HuggingFace embeddings provider."""
|
||||||
|
|
||||||
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
|
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
|
||||||
HuggingFaceEmbeddingServer,
|
HuggingFaceEmbeddingFunction,
|
||||||
)
|
)
|
||||||
from pydantic import AliasChoices, Field
|
from pydantic import AliasChoices, Field
|
||||||
|
|
||||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||||
|
|
||||||
|
|
||||||
class HuggingFaceProvider(BaseEmbeddingsProvider[HuggingFaceEmbeddingServer]):
|
class HuggingFaceProvider(BaseEmbeddingsProvider[HuggingFaceEmbeddingFunction]):
|
||||||
"""HuggingFace embeddings provider."""
|
"""HuggingFace embeddings provider using the Inference API.
|
||||||
|
|
||||||
embedding_callable: type[HuggingFaceEmbeddingServer] = Field(
|
This provider uses the HuggingFace Inference API for text embeddings.
|
||||||
default=HuggingFaceEmbeddingServer,
|
It supports configuration via direct parameters or environment variables.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
embedder={
|
||||||
|
"provider": "huggingface",
|
||||||
|
"config": {
|
||||||
|
"api_key": "your-hf-token",
|
||||||
|
"model": "sentence-transformers/all-MiniLM-L6-v2"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
embedding_callable: type[HuggingFaceEmbeddingFunction] = Field(
|
||||||
|
default=HuggingFaceEmbeddingFunction,
|
||||||
description="HuggingFace embedding function class",
|
description="HuggingFace embedding function class",
|
||||||
)
|
)
|
||||||
url: str = Field(
|
api_key: str | None = Field(
|
||||||
description="HuggingFace API URL",
|
default=None,
|
||||||
validation_alias=AliasChoices("EMBEDDINGS_HUGGINGFACE_URL", "HUGGINGFACE_URL"),
|
description="HuggingFace API key for authentication",
|
||||||
|
validation_alias=AliasChoices(
|
||||||
|
"EMBEDDINGS_HUGGINGFACE_API_KEY",
|
||||||
|
"HUGGINGFACE_API_KEY",
|
||||||
|
"HF_TOKEN",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
model_name: str = Field(
|
||||||
|
default="sentence-transformers/all-MiniLM-L6-v2",
|
||||||
|
description="Model name to use for embeddings",
|
||||||
|
validation_alias=AliasChoices(
|
||||||
|
"EMBEDDINGS_HUGGINGFACE_MODEL",
|
||||||
|
"HUGGINGFACE_MODEL",
|
||||||
|
"model",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
api_key_env_var: str = Field(
|
||||||
|
default="CHROMA_HUGGINGFACE_API_KEY",
|
||||||
|
description="Environment variable name containing the API key",
|
||||||
|
validation_alias=AliasChoices(
|
||||||
|
"EMBEDDINGS_HUGGINGFACE_API_KEY_ENV_VAR",
|
||||||
|
"HUGGINGFACE_API_KEY_ENV_VAR",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
api_url: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="API URL (accepted for compatibility but not used by HuggingFace Inference API)",
|
||||||
|
validation_alias=AliasChoices(
|
||||||
|
"EMBEDDINGS_HUGGINGFACE_URL",
|
||||||
|
"HUGGINGFACE_URL",
|
||||||
|
"url",
|
||||||
|
),
|
||||||
|
exclude=True,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -6,8 +6,24 @@ from typing_extensions import Required, TypedDict
|
|||||||
|
|
||||||
|
|
||||||
class HuggingFaceProviderConfig(TypedDict, total=False):
|
class HuggingFaceProviderConfig(TypedDict, total=False):
|
||||||
"""Configuration for HuggingFace provider."""
|
"""Configuration for HuggingFace provider.
|
||||||
|
|
||||||
|
Supports HuggingFace Inference API for text embeddings.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
api_key: HuggingFace API key for authentication.
|
||||||
|
model: Model name to use for embeddings (e.g., "sentence-transformers/all-MiniLM-L6-v2").
|
||||||
|
model_name: Alias for model.
|
||||||
|
api_key_env_var: Environment variable name containing the API key.
|
||||||
|
api_url: Optional API URL (accepted but not used, for compatibility).
|
||||||
|
url: Alias for api_url (accepted but not used, for compatibility).
|
||||||
|
"""
|
||||||
|
|
||||||
|
api_key: str
|
||||||
|
model: str
|
||||||
|
model_name: str
|
||||||
|
api_key_env_var: str
|
||||||
|
api_url: str
|
||||||
url: str
|
url: str
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -176,6 +176,98 @@ class TestEmbeddingFactory:
|
|||||||
"crewai.rag.embeddings.providers.ibm.watsonx.WatsonXProvider"
|
"crewai.rag.embeddings.providers.ibm.watsonx.WatsonXProvider"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
|
||||||
|
def test_build_embedder_huggingface(self, mock_import):
|
||||||
|
"""Test building HuggingFace embedder with api_key and model."""
|
||||||
|
mock_provider_class = MagicMock()
|
||||||
|
mock_provider_instance = MagicMock()
|
||||||
|
mock_embedding_function = MagicMock()
|
||||||
|
|
||||||
|
mock_import.return_value = mock_provider_class
|
||||||
|
mock_provider_class.return_value = mock_provider_instance
|
||||||
|
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"provider": "huggingface",
|
||||||
|
"config": {
|
||||||
|
"api_key": "hf-test-key",
|
||||||
|
"model": "sentence-transformers/all-MiniLM-L6-v2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
build_embedder(config)
|
||||||
|
|
||||||
|
mock_import.assert_called_once_with(
|
||||||
|
"crewai.rag.embeddings.providers.huggingface.huggingface_provider.HuggingFaceProvider"
|
||||||
|
)
|
||||||
|
mock_provider_class.assert_called_once()
|
||||||
|
|
||||||
|
call_kwargs = mock_provider_class.call_args.kwargs
|
||||||
|
assert call_kwargs["api_key"] == "hf-test-key"
|
||||||
|
assert call_kwargs["model"] == "sentence-transformers/all-MiniLM-L6-v2"
|
||||||
|
|
||||||
|
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
|
||||||
|
def test_build_embedder_huggingface_with_api_url(self, mock_import):
|
||||||
|
"""Test building HuggingFace embedder with api_url (for compatibility)."""
|
||||||
|
mock_provider_class = MagicMock()
|
||||||
|
mock_provider_instance = MagicMock()
|
||||||
|
mock_embedding_function = MagicMock()
|
||||||
|
|
||||||
|
mock_import.return_value = mock_provider_class
|
||||||
|
mock_provider_class.return_value = mock_provider_instance
|
||||||
|
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"provider": "huggingface",
|
||||||
|
"config": {
|
||||||
|
"api_key": "hf-test-key",
|
||||||
|
"model": "Qwen/Qwen3-Embedding-0.6B",
|
||||||
|
"api_url": "https://api-inference.huggingface.co",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
build_embedder(config)
|
||||||
|
|
||||||
|
mock_import.assert_called_once_with(
|
||||||
|
"crewai.rag.embeddings.providers.huggingface.huggingface_provider.HuggingFaceProvider"
|
||||||
|
)
|
||||||
|
mock_provider_class.assert_called_once()
|
||||||
|
|
||||||
|
call_kwargs = mock_provider_class.call_args.kwargs
|
||||||
|
assert call_kwargs["api_key"] == "hf-test-key"
|
||||||
|
assert call_kwargs["model"] == "Qwen/Qwen3-Embedding-0.6B"
|
||||||
|
assert call_kwargs["api_url"] == "https://api-inference.huggingface.co"
|
||||||
|
|
||||||
|
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
|
||||||
|
def test_build_embedder_huggingface_with_model_name(self, mock_import):
|
||||||
|
"""Test building HuggingFace embedder with model_name alias."""
|
||||||
|
mock_provider_class = MagicMock()
|
||||||
|
mock_provider_instance = MagicMock()
|
||||||
|
mock_embedding_function = MagicMock()
|
||||||
|
|
||||||
|
mock_import.return_value = mock_provider_class
|
||||||
|
mock_provider_class.return_value = mock_provider_instance
|
||||||
|
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"provider": "huggingface",
|
||||||
|
"config": {
|
||||||
|
"api_key": "hf-test-key",
|
||||||
|
"model_name": "sentence-transformers/all-MiniLM-L6-v2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
build_embedder(config)
|
||||||
|
|
||||||
|
mock_import.assert_called_once_with(
|
||||||
|
"crewai.rag.embeddings.providers.huggingface.huggingface_provider.HuggingFaceProvider"
|
||||||
|
)
|
||||||
|
mock_provider_class.assert_called_once()
|
||||||
|
|
||||||
|
call_kwargs = mock_provider_class.call_args.kwargs
|
||||||
|
assert call_kwargs["api_key"] == "hf-test-key"
|
||||||
|
assert call_kwargs["model_name"] == "sentence-transformers/all-MiniLM-L6-v2"
|
||||||
|
|
||||||
def test_build_embedder_unknown_provider(self):
|
def test_build_embedder_unknown_provider(self):
|
||||||
"""Test error handling for unknown provider."""
|
"""Test error handling for unknown provider."""
|
||||||
config = {"provider": "unknown-provider", "config": {}}
|
config = {"provider": "unknown-provider", "config": {}}
|
||||||
|
|||||||
143
lib/crewai/tests/rag/embeddings/test_huggingface_provider.py
Normal file
143
lib/crewai/tests/rag/embeddings/test_huggingface_provider.py
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
"""Tests for HuggingFace embedding provider."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
|
||||||
|
HuggingFaceEmbeddingFunction,
|
||||||
|
)
|
||||||
|
|
||||||
|
from crewai.rag.embeddings.factory import build_embedder
|
||||||
|
from crewai.rag.embeddings.providers.huggingface.huggingface_provider import (
|
||||||
|
HuggingFaceProvider,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestHuggingFaceProvider:
|
||||||
|
"""Test HuggingFace embedding provider."""
|
||||||
|
|
||||||
|
def test_provider_with_api_key_and_model(self):
|
||||||
|
"""Test provider initialization with api_key and model.
|
||||||
|
|
||||||
|
This tests the fix for GitHub issue #3995 where users couldn't
|
||||||
|
configure HuggingFace embedder with api_key and model.
|
||||||
|
"""
|
||||||
|
provider = HuggingFaceProvider(
|
||||||
|
api_key="test-hf-token",
|
||||||
|
model_name="sentence-transformers/all-MiniLM-L6-v2",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert provider.api_key == "test-hf-token"
|
||||||
|
assert provider.model_name == "sentence-transformers/all-MiniLM-L6-v2"
|
||||||
|
assert provider.embedding_callable == HuggingFaceEmbeddingFunction
|
||||||
|
|
||||||
|
def test_provider_with_model_alias(self):
|
||||||
|
"""Test provider initialization with 'model' alias for model_name."""
|
||||||
|
provider = HuggingFaceProvider(
|
||||||
|
api_key="test-hf-token",
|
||||||
|
model="Qwen/Qwen3-Embedding-0.6B",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert provider.api_key == "test-hf-token"
|
||||||
|
assert provider.model_name == "Qwen/Qwen3-Embedding-0.6B"
|
||||||
|
|
||||||
|
def test_provider_with_api_url_compatibility(self):
|
||||||
|
"""Test provider accepts api_url for compatibility but excludes it from model_dump.
|
||||||
|
|
||||||
|
The api_url parameter is accepted for compatibility with the documented
|
||||||
|
configuration format but is not passed to HuggingFaceEmbeddingFunction
|
||||||
|
since it uses a fixed API endpoint.
|
||||||
|
"""
|
||||||
|
provider = HuggingFaceProvider(
|
||||||
|
api_key="test-hf-token",
|
||||||
|
model="sentence-transformers/all-MiniLM-L6-v2",
|
||||||
|
api_url="https://api-inference.huggingface.co",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert provider.api_key == "test-hf-token"
|
||||||
|
assert provider.model_name == "sentence-transformers/all-MiniLM-L6-v2"
|
||||||
|
assert provider.api_url == "https://api-inference.huggingface.co"
|
||||||
|
|
||||||
|
# api_url should be excluded from model_dump
|
||||||
|
dumped = provider.model_dump(exclude={"embedding_callable"})
|
||||||
|
assert "api_url" not in dumped
|
||||||
|
|
||||||
|
def test_provider_default_model(self):
|
||||||
|
"""Test provider uses default model when not specified."""
|
||||||
|
provider = HuggingFaceProvider(api_key="test-hf-token")
|
||||||
|
|
||||||
|
assert provider.model_name == "sentence-transformers/all-MiniLM-L6-v2"
|
||||||
|
|
||||||
|
def test_provider_default_api_key_env_var(self):
|
||||||
|
"""Test provider uses default api_key_env_var."""
|
||||||
|
provider = HuggingFaceProvider(api_key="test-hf-token")
|
||||||
|
|
||||||
|
assert provider.api_key_env_var == "CHROMA_HUGGINGFACE_API_KEY"
|
||||||
|
|
||||||
|
|
||||||
|
class TestHuggingFaceProviderIntegration:
|
||||||
|
"""Integration tests for HuggingFace provider with build_embedder."""
|
||||||
|
|
||||||
|
def test_build_embedder_with_documented_config(self):
|
||||||
|
"""Test build_embedder with the documented configuration format.
|
||||||
|
|
||||||
|
This tests the exact configuration format shown in the documentation
|
||||||
|
that was failing before the fix for GitHub issue #3995.
|
||||||
|
"""
|
||||||
|
config = {
|
||||||
|
"provider": "huggingface",
|
||||||
|
"config": {
|
||||||
|
"api_key": "test-hf-token",
|
||||||
|
"model": "sentence-transformers/all-MiniLM-L6-v2",
|
||||||
|
"api_url": "https://api-inference.huggingface.co",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# This should not raise a validation error
|
||||||
|
embedder = build_embedder(config)
|
||||||
|
|
||||||
|
assert isinstance(embedder, HuggingFaceEmbeddingFunction)
|
||||||
|
assert embedder.model_name == "sentence-transformers/all-MiniLM-L6-v2"
|
||||||
|
|
||||||
|
def test_build_embedder_with_minimal_config(self):
|
||||||
|
"""Test build_embedder with minimal configuration."""
|
||||||
|
config = {
|
||||||
|
"provider": "huggingface",
|
||||||
|
"config": {
|
||||||
|
"api_key": "test-hf-token",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
embedder = build_embedder(config)
|
||||||
|
|
||||||
|
assert isinstance(embedder, HuggingFaceEmbeddingFunction)
|
||||||
|
# Default model should be used
|
||||||
|
assert embedder.model_name == "sentence-transformers/all-MiniLM-L6-v2"
|
||||||
|
|
||||||
|
def test_build_embedder_with_model_name_config(self):
|
||||||
|
"""Test build_embedder with model_name instead of model."""
|
||||||
|
config = {
|
||||||
|
"provider": "huggingface",
|
||||||
|
"config": {
|
||||||
|
"api_key": "test-hf-token",
|
||||||
|
"model_name": "sentence-transformers/paraphrase-MiniLM-L6-v2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
embedder = build_embedder(config)
|
||||||
|
|
||||||
|
assert isinstance(embedder, HuggingFaceEmbeddingFunction)
|
||||||
|
assert embedder.model_name == "sentence-transformers/paraphrase-MiniLM-L6-v2"
|
||||||
|
|
||||||
|
def test_build_embedder_with_custom_model(self):
|
||||||
|
"""Test build_embedder with a custom model name."""
|
||||||
|
config = {
|
||||||
|
"provider": "huggingface",
|
||||||
|
"config": {
|
||||||
|
"api_key": "test-hf-token",
|
||||||
|
"model": "Qwen/Qwen3-Embedding-0.6B",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
embedder = build_embedder(config)
|
||||||
|
|
||||||
|
assert isinstance(embedder, HuggingFaceEmbeddingFunction)
|
||||||
|
assert embedder.model_name == "Qwen/Qwen3-Embedding-0.6B"
|
||||||
Reference in New Issue
Block a user