fix: use huggingface_hub InferenceClient for HuggingFace embeddings

Fixes #4145

The HuggingFace embedder was failing with 'could not convert string to float: error'
because chromadb's HuggingFaceEmbeddingFunction uses the deprecated
api-inference.huggingface.co endpoint, which returns error messages instead of embeddings.

This fix creates a custom HuggingFaceEmbeddingFunction that uses huggingface_hub's
InferenceClient with provider='hf-inference' instead of the deprecated endpoint.

Changes:
- Add custom embedding_callable.py using huggingface_hub.InferenceClient
- Update HuggingFaceProvider to use the new embedding callable
- Handle different embedding response formats (1D, 2D, 3D arrays)
- Add comprehensive error handling with actionable error messages
- Add 16 test cases covering initialization, embedding generation, error handling,
  and ChromaDB integration

Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
Devin AI
2025-12-22 22:44:13 +00:00
parent be70a04153
commit 9460e5e182
4 changed files with 455 additions and 3 deletions

View File

@@ -0,0 +1,290 @@
"""Tests for HuggingFace embedding function."""
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
from crewai.rag.embeddings.providers.huggingface.embedding_callable import (
HuggingFaceEmbeddingFunction,
)
class TestHuggingFaceEmbeddingFunction:
    """Unit tests for ``HuggingFaceEmbeddingFunction``.

    Every test that constructs the embedding function patches
    ``huggingface_hub.InferenceClient`` so that no real client is created;
    the ``feature_extraction`` method of the mocked client is configured
    per test to return a canned embedding or raise a canned error.
    """

    # Model name the tests expect the embedding function to fall back to
    # when ``model_name`` is not supplied.
    DEFAULT_MODEL = "sentence-transformers/all-MiniLM-L6-v2"

    @patch("huggingface_hub.InferenceClient")
    def test_initialization_with_api_key(self, mock_client_class):
        """The API key is forwarded to the client as ``token``."""
        ef = HuggingFaceEmbeddingFunction(
            api_key="test-api-key",
            model_name=self.DEFAULT_MODEL,
        )

        mock_client_class.assert_called_once_with(
            provider="hf-inference",
            token="test-api-key",
        )
        assert ef._model_name == self.DEFAULT_MODEL

    @patch("huggingface_hub.InferenceClient")
    def test_initialization_without_api_key(self, mock_client_class):
        """Without an API key the client is created with ``token=None``."""
        ef = HuggingFaceEmbeddingFunction(
            model_name=self.DEFAULT_MODEL,
        )

        mock_client_class.assert_called_once_with(
            provider="hf-inference",
            token=None,
        )
        assert ef._model_name == self.DEFAULT_MODEL

    @patch("huggingface_hub.InferenceClient")
    def test_initialization_with_default_model(self, mock_client_class):
        """Omitting ``model_name`` selects the default model."""
        ef = HuggingFaceEmbeddingFunction()

        assert ef._model_name == self.DEFAULT_MODEL

    @patch("huggingface_hub.InferenceClient")
    def test_call_with_single_document(self, mock_client_class):
        """A single document yields a single embedding."""
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        # Mock the feature_extraction response (1D embedding)
        mock_embedding = [0.1, 0.2, 0.3, 0.4, 0.5]
        mock_client.feature_extraction.return_value = mock_embedding

        ef = HuggingFaceEmbeddingFunction(api_key="test-key")
        result = ef(["Hello, world!"])

        mock_client.feature_extraction.assert_called_once_with(
            text="Hello, world!",
            model=self.DEFAULT_MODEL,
        )
        assert len(result) == 1
        assert result[0] == pytest.approx(mock_embedding, rel=1e-5)

    @patch("huggingface_hub.InferenceClient")
    def test_call_with_multiple_documents(self, mock_client_class):
        """Each document is embedded by its own request, in input order."""
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        # One canned response per expected feature_extraction call
        mock_embedding1 = [0.1, 0.2, 0.3]
        mock_embedding2 = [0.4, 0.5, 0.6]
        mock_client.feature_extraction.side_effect = [
            mock_embedding1,
            mock_embedding2,
        ]

        ef = HuggingFaceEmbeddingFunction(api_key="test-key")
        result = ef(["Hello", "World"])

        assert mock_client.feature_extraction.call_count == 2
        # Pin which text went into which request, not just the call count,
        # so a swapped or duplicated document cannot slip through.
        sent = [
            recorded.kwargs
            for recorded in mock_client.feature_extraction.call_args_list
        ]
        assert sent == [
            {"text": "Hello", "model": self.DEFAULT_MODEL},
            {"text": "World", "model": self.DEFAULT_MODEL},
        ]
        assert len(result) == 2
        assert result[0] == pytest.approx(mock_embedding1, rel=1e-5)
        assert result[1] == pytest.approx(mock_embedding2, rel=1e-5)

    @patch("huggingface_hub.InferenceClient")
    def test_call_with_string_input(self, mock_client_class):
        """A bare string is treated as one document, not as characters."""
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_embedding = [0.1, 0.2, 0.3]
        mock_client.feature_extraction.return_value = mock_embedding

        ef = HuggingFaceEmbeddingFunction(api_key="test-key")
        result = ef("Hello")  # type: ignore[arg-type]

        # The whole string must be sent as a single document.
        mock_client.feature_extraction.assert_called_once_with(
            text="Hello",
            model=self.DEFAULT_MODEL,
        )
        assert len(result) == 1

    @patch("huggingface_hub.InferenceClient")
    def test_process_2d_embedding_result(self, mock_client_class):
        """2D token-level embeddings are mean-pooled into one vector."""
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        # Mock 2D token-level embeddings (3 tokens, 4 dimensions each)
        mock_token_embeddings = [
            [0.1, 0.2, 0.3, 0.4],
            [0.2, 0.3, 0.4, 0.5],
            [0.3, 0.4, 0.5, 0.6],
        ]
        mock_client.feature_extraction.return_value = mock_token_embeddings

        ef = HuggingFaceEmbeddingFunction(api_key="test-key")
        result = ef(["Hello"])

        # Expected: mean pooling across tokens
        expected = np.mean(mock_token_embeddings, axis=0).tolist()
        assert len(result) == 1
        assert result[0] == pytest.approx(expected, rel=1e-5)

    @patch("huggingface_hub.InferenceClient")
    def test_process_3d_embedding_result(self, mock_client_class):
        """3D batch embeddings use the first batch, then mean pooling."""
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        # Mock 3D embeddings (1 batch, 3 tokens, 4 dimensions)
        mock_batch_embeddings = [
            [
                [0.1, 0.2, 0.3, 0.4],
                [0.2, 0.3, 0.4, 0.5],
                [0.3, 0.4, 0.5, 0.6],
            ]
        ]
        mock_client.feature_extraction.return_value = mock_batch_embeddings

        ef = HuggingFaceEmbeddingFunction(api_key="test-key")
        result = ef(["Hello"])

        # Expected: take first batch, then mean pooling
        expected = np.mean(mock_batch_embeddings[0], axis=0).tolist()
        assert len(result) == 1
        assert result[0] == pytest.approx(expected, rel=1e-5)

    @patch("huggingface_hub.InferenceClient")
    def test_error_handling_deprecated_endpoint(self, mock_client_class):
        """A deprecated-endpoint failure surfaces as an endpoint error."""
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.feature_extraction.side_effect = Exception(
            "https://api-inference.huggingface.co is no longer supported"
        )

        ef = HuggingFaceEmbeddingFunction(api_key="test-key")

        with pytest.raises(ValueError, match="HuggingFace API endpoint error"):
            ef(["Hello"])

    @patch("huggingface_hub.InferenceClient")
    def test_error_handling_unauthorized(self, mock_client_class):
        """A 401 failure surfaces as an authentication error."""
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.feature_extraction.side_effect = Exception("401 Unauthorized")

        ef = HuggingFaceEmbeddingFunction(api_key="invalid-key")

        with pytest.raises(ValueError, match="HuggingFace API authentication error"):
            ef(["Hello"])

    @patch("huggingface_hub.InferenceClient")
    def test_error_handling_model_not_found(self, mock_client_class):
        """A 404 failure surfaces as a model-not-found error."""
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.feature_extraction.side_effect = Exception("404 Not Found")

        ef = HuggingFaceEmbeddingFunction(
            api_key="test-key", model_name="nonexistent/model"
        )

        with pytest.raises(ValueError, match="HuggingFace model not found"):
            ef(["Hello"])

    @patch("huggingface_hub.InferenceClient")
    def test_error_handling_generic_error(self, mock_client_class):
        """Any other failure surfaces as a generic API error."""
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.feature_extraction.side_effect = Exception("Some unexpected error")

        ef = HuggingFaceEmbeddingFunction(api_key="test-key")

        with pytest.raises(ValueError, match="HuggingFace API error"):
            ef(["Hello"])

    def test_name_method(self):
        """``name()`` is callable on the class and returns the provider id."""
        # No instance (and therefore no client) is created, so nothing
        # needs to be patched here.
        assert HuggingFaceEmbeddingFunction.name() == "huggingface"

    @patch("huggingface_hub.InferenceClient")
    def test_get_config(self, mock_client_class):
        """``get_config`` reports the constructor arguments."""
        ef = HuggingFaceEmbeddingFunction(
            api_key="test-key",
            model_name="custom/model",
        )

        config = ef.get_config()

        assert config["model_name"] == "custom/model"
        assert config["api_key"] == "test-key"
class TestHuggingFaceEmbeddingFunctionIntegration:
    """Integration-style tests for the embedding function's output shape.

    The HuggingFace client is still mocked; these tests check that the
    value returned by the embedding function has the structure that the
    RAGStorage validation call and ChromaDB rely on.
    """

    @patch("huggingface_hub.InferenceClient")
    def test_embedding_function_works_with_rag_storage_validation(
        self, mock_client_class
    ):
        """Test that the embedding function works with RAGStorage validation.

        This test simulates the validation that happens in RAGStorage.__init__
        where it calls embedding_function(["test"]) to verify the embedder works.
        """
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        # Mock a valid embedding response
        mock_embedding = [0.1] * 384  # 384 dimensions like all-MiniLM-L6-v2
        mock_client.feature_extraction.return_value = mock_embedding

        ef = HuggingFaceEmbeddingFunction(api_key="test-key")
        # This is what RAGStorage does to validate the embedder
        result = ef(["test"])

        assert len(result) == 1
        assert len(result[0]) == 384
        # Use a strict type check: MagicMock objects also expose
        # ``__float__``, so a ``hasattr`` probe would accept leaked mocks.
        assert all(isinstance(x, (int, float, np.number)) for x in result[0])

    @patch("huggingface_hub.InferenceClient")
    def test_embedding_function_returns_correct_format_for_chromadb(
        self, mock_client_class
    ):
        """Test that embeddings are in the correct format for ChromaDB.

        ChromaDB expects embeddings as a sequence of embedding vectors where each
        inner element is a 1D embedding vector with numeric values.
        """
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_embedding = [0.1, 0.2, 0.3, 0.4, 0.5]
        mock_client.feature_extraction.return_value = mock_embedding

        ef = HuggingFaceEmbeddingFunction(api_key="test-key")
        result = ef(["Hello", "World"])

        # ChromaDB expects a sequence of embedding vectors
        assert isinstance(result, list)
        # One embedding per input document; without this check an empty
        # result would make the loop below pass vacuously.
        assert len(result) == 2
        for embedding in result:
            # Each embedding should be a sequence of numeric values
            assert len(embedding) == 5
            for value in embedding:
                # Strict numeric check (Python or numpy scalar); see the
                # note on mocks exposing ``__float__`` in the test above.
                assert isinstance(value, (int, float, np.number))