mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-01 15:22:37 +00:00
fix: use HuggingFaceEmbeddingFunction for embeddings, update keys and add tests (#4005)
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Check Documentation Broken Links / Check broken links (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Check Documentation Broken Links / Check broken links (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
This commit is contained in:
@@ -515,8 +515,7 @@ crew = Crew(
|
|||||||
"provider": "huggingface",
|
"provider": "huggingface",
|
||||||
"config": {
|
"config": {
|
||||||
"api_key": "your-hf-token", # Optional for public models
|
"api_key": "your-hf-token", # Optional for public models
|
||||||
"model": "sentence-transformers/all-MiniLM-L6-v2",
|
"model": "sentence-transformers/all-MiniLM-L6-v2"
|
||||||
"api_url": "https://api-inference.huggingface.co" # or your custom endpoint
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -515,8 +515,7 @@ crew = Crew(
|
|||||||
"provider": "huggingface",
|
"provider": "huggingface",
|
||||||
"config": {
|
"config": {
|
||||||
"api_key": "your-hf-token", # Optional for public models
|
"api_key": "your-hf-token", # Optional for public models
|
||||||
"model": "sentence-transformers/all-MiniLM-L6-v2",
|
"model": "sentence-transformers/all-MiniLM-L6-v2"
|
||||||
"api_url": "https://api-inference.huggingface.co" # or your custom endpoint
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -515,8 +515,7 @@ crew = Crew(
|
|||||||
"provider": "huggingface",
|
"provider": "huggingface",
|
||||||
"config": {
|
"config": {
|
||||||
"api_key": "your-hf-token", # Opcional para modelos públicos
|
"api_key": "your-hf-token", # Opcional para modelos públicos
|
||||||
"model": "sentence-transformers/all-MiniLM-L6-v2",
|
"model": "sentence-transformers/all-MiniLM-L6-v2"
|
||||||
"api_url": "https://api-inference.huggingface.co" # ou seu endpoint customizado
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,21 +1,35 @@
|
|||||||
"""HuggingFace embeddings provider."""
|
"""HuggingFace embeddings provider."""
|
||||||
|
|
||||||
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
|
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
|
||||||
HuggingFaceEmbeddingServer,
|
HuggingFaceEmbeddingFunction,
|
||||||
)
|
)
|
||||||
from pydantic import AliasChoices, Field
|
from pydantic import AliasChoices, Field
|
||||||
|
|
||||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||||
|
|
||||||
|
|
||||||
class HuggingFaceProvider(BaseEmbeddingsProvider[HuggingFaceEmbeddingServer]):
|
class HuggingFaceProvider(BaseEmbeddingsProvider[HuggingFaceEmbeddingFunction]):
|
||||||
"""HuggingFace embeddings provider."""
|
"""HuggingFace embeddings provider for the HuggingFace Inference API."""
|
||||||
|
|
||||||
embedding_callable: type[HuggingFaceEmbeddingServer] = Field(
|
embedding_callable: type[HuggingFaceEmbeddingFunction] = Field(
|
||||||
default=HuggingFaceEmbeddingServer,
|
default=HuggingFaceEmbeddingFunction,
|
||||||
description="HuggingFace embedding function class",
|
description="HuggingFace embedding function class",
|
||||||
)
|
)
|
||||||
url: str = Field(
|
api_key: str | None = Field(
|
||||||
description="HuggingFace API URL",
|
default=None,
|
||||||
validation_alias=AliasChoices("EMBEDDINGS_HUGGINGFACE_URL", "HUGGINGFACE_URL"),
|
description="HuggingFace API key",
|
||||||
|
validation_alias=AliasChoices(
|
||||||
|
"EMBEDDINGS_HUGGINGFACE_API_KEY",
|
||||||
|
"HUGGINGFACE_API_KEY",
|
||||||
|
"HF_TOKEN",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
model_name: str = Field(
|
||||||
|
default="sentence-transformers/all-MiniLM-L6-v2",
|
||||||
|
description="Model name to use for embeddings",
|
||||||
|
validation_alias=AliasChoices(
|
||||||
|
"EMBEDDINGS_HUGGINGFACE_MODEL_NAME",
|
||||||
|
"HUGGINGFACE_MODEL_NAME",
|
||||||
|
"model",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
"""Type definitions for HuggingFace embedding providers."""
|
"""Type definitions for HuggingFace embedding providers."""
|
||||||
|
|
||||||
from typing import Literal
|
from typing import Annotated, Literal
|
||||||
|
|
||||||
from typing_extensions import Required, TypedDict
|
from typing_extensions import Required, TypedDict
|
||||||
|
|
||||||
@@ -8,7 +8,11 @@ from typing_extensions import Required, TypedDict
|
|||||||
class HuggingFaceProviderConfig(TypedDict, total=False):
|
class HuggingFaceProviderConfig(TypedDict, total=False):
|
||||||
"""Configuration for HuggingFace provider."""
|
"""Configuration for HuggingFace provider."""
|
||||||
|
|
||||||
url: str
|
api_key: str
|
||||||
|
model: Annotated[
|
||||||
|
str, "sentence-transformers/all-MiniLM-L6-v2"
|
||||||
|
] # alias for model_name for backward compat
|
||||||
|
model_name: Annotated[str, "sentence-transformers/all-MiniLM-L6-v2"]
|
||||||
|
|
||||||
|
|
||||||
class HuggingFaceProviderSpec(TypedDict, total=False):
|
class HuggingFaceProviderSpec(TypedDict, total=False):
|
||||||
|
|||||||
@@ -99,6 +99,36 @@ class TestEmbeddingFactory:
|
|||||||
"crewai.rag.embeddings.providers.ollama.ollama_provider.OllamaProvider"
|
"crewai.rag.embeddings.providers.ollama.ollama_provider.OllamaProvider"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
|
||||||
|
def test_build_embedder_huggingface(self, mock_import):
|
||||||
|
"""Test building HuggingFace embedder."""
|
||||||
|
mock_provider_class = MagicMock()
|
||||||
|
mock_provider_instance = MagicMock()
|
||||||
|
mock_embedding_function = MagicMock()
|
||||||
|
|
||||||
|
mock_import.return_value = mock_provider_class
|
||||||
|
mock_provider_class.return_value = mock_provider_instance
|
||||||
|
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"provider": "huggingface",
|
||||||
|
"config": {
|
||||||
|
"api_key": "hf-test-key",
|
||||||
|
"model": "sentence-transformers/all-MiniLM-L6-v2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
build_embedder(config)
|
||||||
|
|
||||||
|
mock_import.assert_called_once_with(
|
||||||
|
"crewai.rag.embeddings.providers.huggingface.huggingface_provider.HuggingFaceProvider"
|
||||||
|
)
|
||||||
|
mock_provider_class.assert_called_once()
|
||||||
|
|
||||||
|
call_kwargs = mock_provider_class.call_args.kwargs
|
||||||
|
assert call_kwargs["api_key"] == "hf-test-key"
|
||||||
|
assert call_kwargs["model"] == "sentence-transformers/all-MiniLM-L6-v2"
|
||||||
|
|
||||||
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
|
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
|
||||||
def test_build_embedder_cohere(self, mock_import):
|
def test_build_embedder_cohere(self, mock_import):
|
||||||
"""Test building Cohere embedder."""
|
"""Test building Cohere embedder."""
|
||||||
|
|||||||
Reference in New Issue
Block a user