Compare commits

...

4 Commits

Author SHA1 Message Date
Devin AI
e23eed7f84 Fix lint errors: Remove unused imports
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-05-14 19:26:45 +00:00
Devin AI
6a1d4c1a73 Address PR feedback: Improve code quality and test coverage
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-05-14 19:24:55 +00:00
Devin AI
0c4bdbf379 Fix lint errors: Remove unused imports
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-05-14 19:20:29 +00:00
Devin AI
e4063baca7 Fix issue #2832: Add VoyageAI embedding function implementation
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-05-14 19:17:36 +00:00
6 changed files with 181 additions and 1 deletions

View File

@@ -0,0 +1,3 @@
from .utils.embedding_functions import VoyageAIEmbeddingFunction
__all__ = ["VoyageAIEmbeddingFunction"]

View File

@@ -0,0 +1,3 @@
from . import embedding_functions
__all__ = ["embedding_functions"]

View File

@@ -0,0 +1,3 @@
from .voyageai_embedding_function import VoyageAIEmbeddingFunction
__all__ = ["VoyageAIEmbeddingFunction"]

View File

@@ -0,0 +1,91 @@
import logging
from typing import List, Union
from chromadb.api.types import Documents, EmbeddingFunction
logger = logging.getLogger(__name__)
class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]):
    """
    VoyageAI embedding function for ChromaDB.

    Wraps the ``voyageai`` Python client so that VoyageAI embedding models
    (for example ``voyage-3``, ``voyage-3.5`` or ``voyage-3.5-lite``) can be
    used wherever ChromaDB expects an embedding function.

    Attributes:
        _api_key (str): The API key for VoyageAI.
        _model_name (str): The name of the VoyageAI model to use.
    """

    def __init__(self, api_key: str, model_name: str = "voyage-3"):
        """
        Initialize the VoyageAI embedding function.

        Args:
            api_key (str): The API key for VoyageAI.
            model_name (str, optional): The name of the VoyageAI model to use.
                Defaults to "voyage-3".

        Raises:
            ValueError: If the voyageai package is not installed or if the
                API key is empty.
        """
        self._ensure_voyageai_installed()
        if not api_key:
            raise ValueError("API key is required for VoyageAI embeddings")
        self._api_key = api_key
        self._model_name = model_name

    def _ensure_voyageai_installed(self):
        """
        Ensure that the voyageai package is installed.

        Raises:
            ValueError: If the voyageai package is not installed. The
                original ImportError is attached as ``__cause__``.
        """
        try:
            import voyageai  # noqa: F401
        except ImportError as exc:
            raise ValueError(
                "The voyageai python package is not installed. Please install it with `pip install voyageai`"
            ) from exc

    def __call__(self, input: Union[str, List[str]]) -> List[List[float]]:
        """
        Generate embeddings for the input text(s).

        Args:
            input (Union[str, List[str]]): The text or list of texts to
                generate embeddings for. ``None``, an empty string and an
                empty list all yield an empty result without calling the API.

        Returns:
            List[List[float]]: One embedding (a list of floats) per input
            text, in input order.

        Raises:
            ValueError: If the voyageai package is not installed, or if the
                input is not a string or a list of strings.
            voyageai.error.VoyageError: If the VoyageAI API reports an error.
        """
        self._ensure_voyageai_installed()

        # Validate the type *before* the emptiness shortcut so that falsy
        # values of the wrong type (0, {}, ()) are rejected instead of
        # silently producing an empty result.
        if input is None:
            return []
        if isinstance(input, str):
            texts = [input] if input else []
        elif isinstance(input, list):
            texts = input
        else:
            raise ValueError("Input must be a string or a list of strings")
        if not texts:
            return []
        if not all(isinstance(text, str) for text in texts):
            raise ValueError("Input must be a string or a list of strings")

        # Imported lazily so the module itself can be imported without the
        # optional voyageai dependency being installed.
        import voyageai
        from voyageai.error import VoyageError

        try:
            # Client.embed is the documented entry point; the module-level
            # get_embeddings helper is a legacy API.
            client = voyageai.Client(api_key=self._api_key)
            result = client.embed(texts, model=self._model_name)
        except VoyageError as e:
            logger.error("VoyageAI API error: %s", e)
            raise
        except Exception as e:
            # Logged for diagnostics only; the caller still sees the error.
            logger.error("Unexpected error during VoyageAI embedding: %s", e)
            raise
        return result.embeddings

View File

@@ -140,7 +140,7 @@ class EmbeddingConfigurator:
@staticmethod
def _configure_voyageai(config, model_name):
from chromadb.utils.embedding_functions.voyageai_embedding_function import (
from crewai.knowledge.embedder.chromadb.utils.embedding_functions.voyageai_embedding_function import (
VoyageAIEmbeddingFunction,
)

View File

@@ -0,0 +1,80 @@
import pytest
from unittest.mock import patch, MagicMock
from crewai.utilities.embedding_configurator import EmbeddingConfigurator
class TestEmbeddingConfigurator:
    """Tests for the VoyageAI provider wiring of ``EmbeddingConfigurator``."""

    # ``_configure_voyageai`` imports VoyageAIEmbeddingFunction inside the
    # function body, so the name is never an attribute of
    # ``crewai.utilities.embedding_configurator``. Patching it there would
    # raise AttributeError; patch the defining module instead, which is where
    # the function-local ``from ... import`` looks the class up at call time.
    VOYAGEAI_TARGET = (
        "crewai.knowledge.embedder.chromadb.utils.embedding_functions"
        ".voyageai_embedding_function.VoyageAIEmbeddingFunction"
    )

    def test_configure_voyageai_embedder(self):
        """Test that the VoyageAI embedder is configured correctly."""
        with patch(self.VOYAGEAI_TARGET) as mock_voyageai:
            mock_instance = MagicMock()
            mock_voyageai.return_value = mock_instance

            config = {"api_key": "test-key"}
            model_name = "voyage-3"

            configurator = EmbeddingConfigurator()
            embedder = configurator._configure_voyageai(config, model_name)

            mock_voyageai.assert_called_once_with(
                model_name=model_name, api_key="test-key"
            )
            assert embedder == mock_instance

    def test_configure_embedder_with_voyageai(self):
        """Test that the embedder configurator correctly handles VoyageAI provider."""
        with patch(self.VOYAGEAI_TARGET) as mock_voyageai:
            mock_instance = MagicMock()
            mock_voyageai.return_value = mock_instance

            embedder_config = {
                "provider": "voyageai",
                "config": {"api_key": "test-key", "model": "voyage-3"},
            }

            configurator = EmbeddingConfigurator()
            embedder = configurator.configure_embedder(embedder_config)

            mock_voyageai.assert_called_once_with(
                model_name="voyage-3", api_key="test-key"
            )
            assert embedder == mock_instance

    def test_configure_voyageai_embedder_missing_api_key(self):
        """Test that the VoyageAI embedder raises an error when API key is missing."""
        with patch(self.VOYAGEAI_TARGET) as mock_voyageai:
            # The mock stands in for the real constructor's validation.
            mock_voyageai.side_effect = ValueError(
                "API key is required for VoyageAI embeddings"
            )

            config = {}  # Empty config without API key
            model_name = "voyage-3"

            configurator = EmbeddingConfigurator()
            with pytest.raises(ValueError, match="API key is required"):
                configurator._configure_voyageai(config, model_name)

    def test_configure_voyageai_embedder_custom_model(self):
        """Test that the VoyageAI embedder works with different model names."""
        with patch(self.VOYAGEAI_TARGET) as mock_voyageai:
            mock_instance = MagicMock()
            mock_voyageai.return_value = mock_instance

            config = {"api_key": "test-key"}
            model_name = "voyage-3.5-lite"  # Using a different model

            configurator = EmbeddingConfigurator()
            embedder = configurator._configure_voyageai(config, model_name)

            mock_voyageai.assert_called_once_with(
                model_name=model_name, api_key="test-key"
            )
            assert embedder == mock_instance