diff --git a/lib/crewai/src/crewai/rag/embeddings/providers/huggingface/__init__.py b/lib/crewai/src/crewai/rag/embeddings/providers/huggingface/__init__.py
index 36cf86f17..6c74540f2 100644
--- a/lib/crewai/src/crewai/rag/embeddings/providers/huggingface/__init__.py
+++ b/lib/crewai/src/crewai/rag/embeddings/providers/huggingface/__init__.py
@@ -1,5 +1,8 @@
 """HuggingFace embedding providers."""
 
+from crewai.rag.embeddings.providers.huggingface.embedding_callable import (
+    HuggingFaceEmbeddingFunction,
+)
 from crewai.rag.embeddings.providers.huggingface.huggingface_provider import (
     HuggingFaceProvider,
 )
@@ -10,6 +13,7 @@ from crewai.rag.embeddings.providers.huggingface.types import (
 
 
 __all__ = [
+    "HuggingFaceEmbeddingFunction",
     "HuggingFaceProvider",
     "HuggingFaceProviderConfig",
     "HuggingFaceProviderSpec",
diff --git a/lib/crewai/src/crewai/rag/embeddings/providers/huggingface/embedding_callable.py b/lib/crewai/src/crewai/rag/embeddings/providers/huggingface/embedding_callable.py
new file mode 100644
index 000000000..48b35e285
--- /dev/null
+++ b/lib/crewai/src/crewai/rag/embeddings/providers/huggingface/embedding_callable.py
@@ -0,0 +1,173 @@
+"""HuggingFace embedding function implementation using huggingface_hub."""
+
+from typing import Any
+
+from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
+import numpy as np
+from typing_extensions import Unpack
+
+from crewai.rag.embeddings.providers.huggingface.types import HuggingFaceProviderConfig
+
+
+class HuggingFaceEmbeddingFunction(EmbeddingFunction[Documents]):
+    """Embedding function for HuggingFace models using the Inference API.
+
+    This implementation uses huggingface_hub's InferenceClient instead of the
+    deprecated api-inference.huggingface.co endpoint that chromadb uses.
+    """
+
+    def __init__(self, **kwargs: Unpack[HuggingFaceProviderConfig]) -> None:
+        """Initialize HuggingFace embedding function.
+
+        Args:
+            **kwargs: Configuration parameters for HuggingFace.
+                - api_key: HuggingFace API key (optional for public models)
+                - model_name: Model name to use for embeddings
+        """
+        try:
+            from huggingface_hub import InferenceClient
+        except ImportError as e:
+            raise ImportError(
+                "huggingface_hub is required for HuggingFace embeddings. "
+                "Install it with: uv add huggingface_hub"
+            ) from e
+
+        self._config = kwargs
+        self._model_name = kwargs.get(
+            "model_name", "sentence-transformers/all-MiniLM-L6-v2"
+        )
+        api_key = kwargs.get("api_key")
+
+        self._client = InferenceClient(
+            provider="hf-inference",
+            token=api_key,
+        )
+
+    @staticmethod
+    def name() -> str:
+        """Return the name of the embedding function for ChromaDB compatibility."""
+        return "huggingface"
+
+    def __call__(self, input: Documents) -> Embeddings:
+        """Generate embeddings for input documents.
+
+        Args:
+            input: List of documents to embed. A bare string is treated as a
+                single document.
+
+        Returns:
+            List of embedding vectors, one per document, in input order.
+
+        Raises:
+            ValueError: If the API returns an error or unexpected response format.
+        """
+        if isinstance(input, str):
+            input = [input]
+
+        return [self._get_embedding_for_text(text) for text in input]
+
+    def _get_embedding_for_text(self, text: str) -> list[float]:
+        """Get embedding for a single text.
+
+        Args:
+            text: The text to embed.
+
+        Returns:
+            The embedding vector.
+
+        Raises:
+            ValueError: If the API call fails, or if the response cannot be
+                reduced to a single embedding vector.
+        """
+        # Keep only the network call inside the try block so that response
+        # format errors are not re-labelled as API errors below.
+        try:
+            result = self._client.feature_extraction(
+                text=text,
+                model=self._model_name,
+            )
+        except Exception as e:
+            error_msg = str(e)
+            lowered = error_msg.lower()
+            # Provide more helpful error messages for common issues
+            if "deprecated" in lowered or "no longer supported" in lowered:
+                raise ValueError(
+                    f"HuggingFace API endpoint error: {error_msg}. "
+                    "Please ensure you have the latest version of huggingface_hub installed."
+                ) from e
+            if "unauthorized" in lowered or "401" in error_msg:
+                raise ValueError(
+                    f"HuggingFace API authentication error: {error_msg}. "
+                    "Please check your API key configuration."
+                ) from e
+            if "not found" in lowered or "404" in error_msg:
+                raise ValueError(
+                    f"HuggingFace model not found: {error_msg}. "
+                    f"Please verify the model name '{self._model_name}' is correct "
+                    "and supports feature extraction."
+                ) from e
+            raise ValueError(f"HuggingFace API error: {error_msg}") from e
+
+        # Handle different response formats
+        return self._process_embedding_result(result)
+
+    def _process_embedding_result(self, result: Any) -> list[float]:
+        """Process the embedding result from the API.
+
+        The HuggingFace API can return different formats depending on the model:
+        - 1D array: Direct embedding vector
+        - 2D array: Token-level embeddings (mean-pooled over the first axis)
+        - 3D array: Batch of token-level embeddings (first item, mean-pooled)
+
+        Args:
+            result: The raw result from the API.
+
+        Returns:
+            A 1D list of floats representing the embedding.
+
+        Raises:
+            ValueError: If the result is non-numeric, empty, or has an
+                unsupported number of dimensions.
+        """
+        # Convert to numpy array for easier processing
+        arr = np.asarray(result)
+
+        # Error payloads (dicts, strings) become non-numeric arrays; reject them
+        # here so callers only ever see ValueError for a bad response.
+        if not np.issubdtype(arr.dtype, np.number):
+            raise ValueError(
+                f"Unexpected embedding result type: {type(result).__name__}. "
+                "Expected a numeric array."
+            )
+        # Pooling an empty array would yield NaNs (or an empty vector).
+        if arr.size == 0:
+            raise ValueError("HuggingFace API returned an empty embedding.")
+
+        # Handle different dimensionalities
+        if arr.ndim == 1:
+            # Already a 1D embedding vector
+            return arr.astype(np.float32).tolist()
+        if arr.ndim == 2:
+            # Token-level embeddings - apply mean pooling
+            pooled = np.mean(arr, axis=0)
+            return pooled.astype(np.float32).tolist()
+        if arr.ndim == 3:
+            # Batch of token-level embeddings - take first and apply mean pooling
+            pooled = np.mean(arr[0], axis=0)
+            return pooled.astype(np.float32).tolist()
+        raise ValueError(
+            f"Unexpected embedding result shape: {arr.shape}. "
+            "Expected 1D, 2D, or 3D array."
+        )
+
+    def get_config(self) -> dict[str, Any]:
+        """Return the configuration for serialization.
+
+        NOTE(review): the returned dict carries the raw API key. If callers
+        persist it (e.g. alongside collection metadata), the secret is stored
+        in plain text - confirm this is intended.
+        """
+        return {
+            "model_name": self._model_name,
+            "api_key": self._config.get("api_key"),
+        }
diff --git a/lib/crewai/src/crewai/rag/embeddings/providers/huggingface/huggingface_provider.py b/lib/crewai/src/crewai/rag/embeddings/providers/huggingface/huggingface_provider.py
index 8dc32b1f1..b62351871 100644
--- a/lib/crewai/src/crewai/rag/embeddings/providers/huggingface/huggingface_provider.py
+++ b/lib/crewai/src/crewai/rag/embeddings/providers/huggingface/huggingface_provider.py
@@ -1,11 +1,11 @@
 """HuggingFace embeddings provider."""
 
-from chromadb.utils.embedding_functions.huggingface_embedding_function import (
-    HuggingFaceEmbeddingFunction,
-)
 from pydantic import AliasChoices, Field
 
 from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
+from crewai.rag.embeddings.providers.huggingface.embedding_callable import (
+    HuggingFaceEmbeddingFunction,
+)
 
 
 class HuggingFaceProvider(BaseEmbeddingsProvider[HuggingFaceEmbeddingFunction]):
diff --git a/lib/crewai/tests/rag/embeddings/test_huggingface_embedder.py b/lib/crewai/tests/rag/embeddings/test_huggingface_embedder.py
new file mode 100644
index 000000000..34232429e
--- /dev/null
+++ b/lib/crewai/tests/rag/embeddings/test_huggingface_embedder.py
@@ -0,0 +1,319 @@
+"""Tests for HuggingFace embedding function."""
+
+from unittest.mock import MagicMock, patch
+
+import numpy as np
+import pytest
+
+from crewai.rag.embeddings.providers.huggingface.embedding_callable import (
+    HuggingFaceEmbeddingFunction,
+)
+
+
+class TestHuggingFaceEmbeddingFunction:
+    """Test HuggingFace embedding function."""
+
+    @patch("huggingface_hub.InferenceClient")
+    def test_initialization_with_api_key(self, mock_client_class):
+        """Test initialization with API key."""
+        mock_client = MagicMock()
+        mock_client_class.return_value = mock_client
+
+        ef = HuggingFaceEmbeddingFunction(
+            api_key="test-api-key",
+            model_name="sentence-transformers/all-MiniLM-L6-v2",
+        )
+
+        mock_client_class.assert_called_once_with(
+            provider="hf-inference",
+            token="test-api-key",
+        )
+        assert ef._model_name == "sentence-transformers/all-MiniLM-L6-v2"
+
+    @patch("huggingface_hub.InferenceClient")
+    def test_initialization_without_api_key(self, mock_client_class):
+        """Test initialization without API key (for public models)."""
+        mock_client = MagicMock()
+        mock_client_class.return_value = mock_client
+
+        ef = HuggingFaceEmbeddingFunction(
+            model_name="sentence-transformers/all-MiniLM-L6-v2",
+        )
+
+        mock_client_class.assert_called_once_with(
+            provider="hf-inference",
+            token=None,
+        )
+        assert ef._model_name == "sentence-transformers/all-MiniLM-L6-v2"
+
+    @patch("huggingface_hub.InferenceClient")
+    def test_initialization_with_default_model(self, mock_client_class):
+        """Test initialization with default model name."""
+        mock_client = MagicMock()
+        mock_client_class.return_value = mock_client
+
+        ef = HuggingFaceEmbeddingFunction()
+
+        assert ef._model_name == "sentence-transformers/all-MiniLM-L6-v2"
+
+    @patch("huggingface_hub.InferenceClient")
+    def test_call_with_single_document(self, mock_client_class):
+        """Test embedding generation for a single document."""
+        mock_client = MagicMock()
+        mock_client_class.return_value = mock_client
+
+        # Mock the feature_extraction response (1D embedding)
+        mock_embedding = [0.1, 0.2, 0.3, 0.4, 0.5]
+        mock_client.feature_extraction.return_value = mock_embedding
+
+        ef = HuggingFaceEmbeddingFunction(api_key="test-key")
+        result = ef(["Hello, world!"])
+
+        mock_client.feature_extraction.assert_called_once_with(
+            text="Hello, world!",
+            model="sentence-transformers/all-MiniLM-L6-v2",
+        )
+        assert len(result) == 1
+        assert result[0] == pytest.approx(mock_embedding, rel=1e-5)
+
+    @patch("huggingface_hub.InferenceClient")
+    def test_call_with_multiple_documents(self, mock_client_class):
+        """Test embedding generation for multiple documents."""
+        mock_client = MagicMock()
+        mock_client_class.return_value = mock_client
+
+        # Mock the feature_extraction response
+        mock_embedding1 = [0.1, 0.2, 0.3]
+        mock_embedding2 = [0.4, 0.5, 0.6]
+        mock_client.feature_extraction.side_effect = [mock_embedding1, mock_embedding2]
+
+        ef = HuggingFaceEmbeddingFunction(api_key="test-key")
+        result = ef(["Hello", "World"])
+
+        assert mock_client.feature_extraction.call_count == 2
+        assert len(result) == 2
+        assert result[0] == pytest.approx(mock_embedding1, rel=1e-5)
+        assert result[1] == pytest.approx(mock_embedding2, rel=1e-5)
+
+    @patch("huggingface_hub.InferenceClient")
+    def test_call_with_string_input(self, mock_client_class):
+        """Test that string input is converted to list."""
+        mock_client = MagicMock()
+        mock_client_class.return_value = mock_client
+
+        mock_embedding = [0.1, 0.2, 0.3]
+        mock_client.feature_extraction.return_value = mock_embedding
+
+        ef = HuggingFaceEmbeddingFunction(api_key="test-key")
+        result = ef("Hello")  # type: ignore[arg-type]
+
+        mock_client.feature_extraction.assert_called_once()
+        assert len(result) == 1
+
+    @patch("huggingface_hub.InferenceClient")
+    def test_process_2d_embedding_result(self, mock_client_class):
+        """Test processing of 2D token-level embeddings (mean pooling)."""
+        mock_client = MagicMock()
+        mock_client_class.return_value = mock_client
+
+        # Mock 2D token-level embeddings (3 tokens, 4 dimensions each)
+        mock_token_embeddings = [
+            [0.1, 0.2, 0.3, 0.4],
+            [0.2, 0.3, 0.4, 0.5],
+            [0.3, 0.4, 0.5, 0.6],
+        ]
+        mock_client.feature_extraction.return_value = mock_token_embeddings
+
+        ef = HuggingFaceEmbeddingFunction(api_key="test-key")
+        result = ef(["Hello"])
+
+        # Expected: mean pooling across tokens
+        expected = np.mean(mock_token_embeddings, axis=0).tolist()
+        assert len(result) == 1
+        assert result[0] == pytest.approx(expected, rel=1e-5)
+
+    @patch("huggingface_hub.InferenceClient")
+    def test_process_3d_embedding_result(self, mock_client_class):
+        """Test processing of 3D batch token-level embeddings."""
+        mock_client = MagicMock()
+        mock_client_class.return_value = mock_client
+
+        # Mock 3D embeddings (1 batch, 3 tokens, 4 dimensions)
+        mock_batch_embeddings = [
+            [
+                [0.1, 0.2, 0.3, 0.4],
+                [0.2, 0.3, 0.4, 0.5],
+                [0.3, 0.4, 0.5, 0.6],
+            ]
+        ]
+        mock_client.feature_extraction.return_value = mock_batch_embeddings
+
+        ef = HuggingFaceEmbeddingFunction(api_key="test-key")
+        result = ef(["Hello"])
+
+        # Expected: take first batch, then mean pooling
+        expected = np.mean(mock_batch_embeddings[0], axis=0).tolist()
+        assert len(result) == 1
+        assert result[0] == pytest.approx(expected, rel=1e-5)
+
+    @patch("huggingface_hub.InferenceClient")
+    def test_error_handling_deprecated_endpoint(self, mock_client_class):
+        """Test error handling for deprecated endpoint error."""
+        mock_client = MagicMock()
+        mock_client_class.return_value = mock_client
+
+        mock_client.feature_extraction.side_effect = Exception(
+            "https://api-inference.huggingface.co is no longer supported"
+        )
+
+        ef = HuggingFaceEmbeddingFunction(api_key="test-key")
+
+        with pytest.raises(ValueError, match="HuggingFace API endpoint error"):
+            ef(["Hello"])
+
+    @patch("huggingface_hub.InferenceClient")
+    def test_error_handling_unauthorized(self, mock_client_class):
+        """Test error handling for authentication error."""
+        mock_client = MagicMock()
+        mock_client_class.return_value = mock_client
+
+        mock_client.feature_extraction.side_effect = Exception("401 Unauthorized")
+
+        ef = HuggingFaceEmbeddingFunction(api_key="invalid-key")
+
+        with pytest.raises(ValueError, match="HuggingFace API authentication error"):
+            ef(["Hello"])
+
+    @patch("huggingface_hub.InferenceClient")
+    def test_error_handling_model_not_found(self, mock_client_class):
+        """Test error handling for model not found error."""
+        mock_client = MagicMock()
+        mock_client_class.return_value = mock_client
+
+        mock_client.feature_extraction.side_effect = Exception("404 Not Found")
+
+        ef = HuggingFaceEmbeddingFunction(
+            api_key="test-key", model_name="nonexistent/model"
+        )
+
+        with pytest.raises(ValueError, match="HuggingFace model not found"):
+            ef(["Hello"])
+
+    @patch("huggingface_hub.InferenceClient")
+    def test_error_handling_generic_error(self, mock_client_class):
+        """Test error handling for generic API error."""
+        mock_client = MagicMock()
+        mock_client_class.return_value = mock_client
+
+        mock_client.feature_extraction.side_effect = Exception("Some unexpected error")
+
+        ef = HuggingFaceEmbeddingFunction(api_key="test-key")
+
+        with pytest.raises(ValueError, match="HuggingFace API error"):
+            ef(["Hello"])
+
+    @patch("huggingface_hub.InferenceClient")
+    def test_unexpected_result_shape_is_not_reported_as_api_error(
+        self, mock_client_class
+    ):
+        """Test that a malformed response is reported as a format error."""
+        mock_client = MagicMock()
+        mock_client_class.return_value = mock_client
+
+        # 4D payload: not a shape the embedding function knows how to pool
+        mock_client.feature_extraction.return_value = [[[[0.1, 0.2]]]]
+
+        ef = HuggingFaceEmbeddingFunction(api_key="test-key")
+
+        with pytest.raises(ValueError, match="^Unexpected embedding result shape"):
+            ef(["Hello"])
+
+    @patch("huggingface_hub.InferenceClient")
+    def test_empty_result_raises(self, mock_client_class):
+        """Test that an empty response is rejected instead of returned."""
+        mock_client = MagicMock()
+        mock_client_class.return_value = mock_client
+
+        mock_client.feature_extraction.return_value = []
+
+        ef = HuggingFaceEmbeddingFunction(api_key="test-key")
+
+        with pytest.raises(ValueError, match="empty embedding"):
+            ef(["Hello"])
+
+    @patch("huggingface_hub.InferenceClient")
+    def test_name_method(self, mock_client_class):
+        """Test the name() static method."""
+        assert HuggingFaceEmbeddingFunction.name() == "huggingface"
+
+    @patch("huggingface_hub.InferenceClient")
+    def test_get_config(self, mock_client_class):
+        """Test get_config method."""
+        mock_client = MagicMock()
+        mock_client_class.return_value = mock_client
+
+        ef = HuggingFaceEmbeddingFunction(
+            api_key="test-key",
+            model_name="custom/model",
+        )
+
+        config = ef.get_config()
+        assert config["model_name"] == "custom/model"
+        assert config["api_key"] == "test-key"
+
+
+class TestHuggingFaceEmbeddingFunctionIntegration:
+    """Integration tests for HuggingFace embedding function with RAGStorage."""
+
+    @patch("huggingface_hub.InferenceClient")
+    def test_embedding_function_works_with_rag_storage_validation(
+        self, mock_client_class
+    ):
+        """Test that the embedding function works with RAGStorage validation.
+
+        This test simulates the validation that happens in RAGStorage.__init__
+        where it calls embedding_function(["test"]) to verify the embedder works.
+        """
+        mock_client = MagicMock()
+        mock_client_class.return_value = mock_client
+
+        # Mock a valid embedding response
+        mock_embedding = [0.1] * 384  # 384 dimensions like all-MiniLM-L6-v2
+        mock_client.feature_extraction.return_value = mock_embedding
+
+        ef = HuggingFaceEmbeddingFunction(api_key="test-key")
+
+        # This is what RAGStorage does to validate the embedder
+        result = ef(["test"])
+
+        assert len(result) == 1
+        assert len(result[0]) == 384
+        # Values should be numeric (float or numpy float)
+        assert all(isinstance(x, (int, float)) or hasattr(x, "__float__") for x in result[0])
+
+    @patch("huggingface_hub.InferenceClient")
+    def test_embedding_function_returns_correct_format_for_chromadb(
+        self, mock_client_class
+    ):
+        """Test that embeddings are in the correct format for ChromaDB.
+
+        ChromaDB expects embeddings as a sequence of embedding vectors where each
+        inner element is a 1D embedding vector with numeric values.
+        """
+        mock_client = MagicMock()
+        mock_client_class.return_value = mock_client
+
+        mock_embedding = [0.1, 0.2, 0.3, 0.4, 0.5]
+        mock_client.feature_extraction.return_value = mock_embedding
+
+        ef = HuggingFaceEmbeddingFunction(api_key="test-key")
+        result = ef(["Hello", "World"])
+
+        # ChromaDB expects a sequence of embedding vectors
+        assert isinstance(result, list)
+        for embedding in result:
+            # Each embedding should be a sequence of numeric values
+            assert len(embedding) == 5
+            for value in embedding:
+                # Values should be numeric (float or numpy float)
+                assert isinstance(value, (int, float)) or hasattr(value, "__float__")