mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-28 17:48:13 +00:00
Compare commits
4 Commits
devin/1768
...
devin/1747
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e23eed7f84 | ||
|
|
6a1d4c1a73 | ||
|
|
0c4bdbf379 | ||
|
|
e4063baca7 |
3
src/crewai/knowledge/embedder/chromadb/__init__.py
Normal file
3
src/crewai/knowledge/embedder/chromadb/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .utils.embedding_functions import VoyageAIEmbeddingFunction
|
||||||
|
|
||||||
|
# Public API of this package: only the VoyageAI embedding function is exported.
__all__ = ["VoyageAIEmbeddingFunction"]
|
||||||
3
src/crewai/knowledge/embedder/chromadb/utils/__init__.py
Normal file
3
src/crewai/knowledge/embedder/chromadb/utils/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from . import embedding_functions
|
||||||
|
|
||||||
|
# Public API of this package: the embedding_functions submodule imported above.
__all__ = ["embedding_functions"]
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
from .voyageai_embedding_function import VoyageAIEmbeddingFunction
|
||||||
|
|
||||||
|
# Public API: re-export the class defined in voyageai_embedding_function.
__all__ = ["VoyageAIEmbeddingFunction"]
|
||||||
@@ -0,0 +1,91 @@
|
|||||||
|
import logging
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
from chromadb.api.types import Documents, EmbeddingFunction
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]):
    """
    VoyageAI embedding function for ChromaDB.

    This class provides integration with VoyageAI's embedding models for use with ChromaDB.
    It supports various VoyageAI models including voyage-3, voyage-3.5, and voyage-3.5-lite.

    The ``voyageai`` package is imported inside the methods rather than at
    module level, so importing this module does not require it; creating an
    instance or calling it does.

    Attributes:
        _api_key (str): The API key for VoyageAI.
        _model_name (str): The name of the VoyageAI model to use.
    """

    def __init__(self, api_key: str, model_name: str = "voyage-3"):
        """
        Initialize the VoyageAI embedding function.

        Args:
            api_key (str): The API key for VoyageAI.
            model_name (str, optional): The name of the VoyageAI model to use.
                Defaults to "voyage-3".

        Raises:
            ValueError: If the voyageai package is not installed or if the API key is empty.
        """
        self._ensure_voyageai_installed()

        if not api_key:
            raise ValueError("API key is required for VoyageAI embeddings")

        self._api_key = api_key
        self._model_name = model_name

    def _ensure_voyageai_installed(self) -> None:
        """
        Ensure that the voyageai package is installed.

        Raises:
            ValueError: If the voyageai package is not installed. The original
                ImportError is attached as the cause.
        """
        try:
            import voyageai  # noqa: F401
        except ImportError as err:
            # Chain the ImportError so the root cause stays in the traceback.
            raise ValueError(
                "The voyageai python package is not installed. Please install it with `pip install voyageai`"
            ) from err

    def __call__(self, input: Union[str, List[str]]) -> List[List[float]]:
        """
        Generate embeddings for the input text(s).

        Args:
            input (Union[str, List[str]]): The text or list of texts to generate embeddings for.

        Returns:
            List[List[float]]: A list of embeddings, where each embedding is a list of floats.
                An empty string or empty list yields an empty list without calling the API.

        Raises:
            ValueError: If the input is not a string or list of strings.
            voyageai.VoyageError: If there is an error with the VoyageAI API.
        """
        self._ensure_voyageai_installed()
        import voyageai

        # Check the type before the emptiness shortcut: otherwise falsy values
        # of the wrong type (None, 0, an empty tuple, ...) would silently
        # return [] instead of raising the documented ValueError.
        if not isinstance(input, (str, list)):
            raise ValueError("Input must be a string or a list of strings")

        if not input:
            return []

        if isinstance(input, str):
            input = [input]
        elif not all(isinstance(text, str) for text in input):
            # Reject mixed/non-text lists here rather than sending them to the API.
            raise ValueError("Input must be a string or a list of strings")

        try:
            # NOTE(review): relies on the module-level ``voyageai.get_embeddings``
            # helper and ``voyageai.VoyageError`` - confirm both are exposed by
            # the voyageai version this project pins.
            return voyageai.get_embeddings(
                input, model=self._model_name, api_key=self._api_key
            )
        except voyageai.VoyageError as e:
            logger.error("VoyageAI API error: %s", e)
            raise
        except Exception as e:
            logger.error("Unexpected error during VoyageAI embedding: %s", e)
            raise
|
||||||
@@ -140,7 +140,7 @@ class EmbeddingConfigurator:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _configure_voyageai(config, model_name):
|
def _configure_voyageai(config, model_name):
|
||||||
from chromadb.utils.embedding_functions.voyageai_embedding_function import (
|
from crewai.knowledge.embedder.chromadb.utils.embedding_functions.voyageai_embedding_function import (
|
||||||
VoyageAIEmbeddingFunction,
|
VoyageAIEmbeddingFunction,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
80
tests/utilities/test_embedding_configurator.py
Normal file
80
tests/utilities/test_embedding_configurator.py
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
from crewai.utilities.embedding_configurator import EmbeddingConfigurator
|
||||||
|
|
||||||
|
|
||||||
|
class TestEmbeddingConfigurator:
    """Tests for the VoyageAI provider handling in ``EmbeddingConfigurator``."""

    # ``_configure_voyageai`` imports VoyageAIEmbeddingFunction inside the
    # function body, so the name is not an attribute of
    # ``crewai.utilities.embedding_configurator`` and cannot be patched there.
    # Patch it on the module it is imported from; the function-local import
    # then resolves to the mock at call time.
    _VOYAGEAI_TARGET = (
        "crewai.knowledge.embedder.chromadb.utils.embedding_functions"
        ".voyageai_embedding_function.VoyageAIEmbeddingFunction"
    )

    def test_configure_voyageai_embedder(self):
        """Test that the VoyageAI embedder is configured correctly."""
        with patch(self._VOYAGEAI_TARGET) as mock_voyageai:
            mock_instance = MagicMock()
            mock_voyageai.return_value = mock_instance

            config = {"api_key": "test-key"}
            model_name = "voyage-3"

            configurator = EmbeddingConfigurator()
            embedder = configurator._configure_voyageai(config, model_name)

            mock_voyageai.assert_called_once_with(
                model_name=model_name, api_key="test-key"
            )
            assert embedder == mock_instance

    def test_configure_embedder_with_voyageai(self):
        """Test that the embedder configurator correctly handles VoyageAI provider."""
        with patch(self._VOYAGEAI_TARGET) as mock_voyageai:
            mock_instance = MagicMock()
            mock_voyageai.return_value = mock_instance

            embedder_config = {
                "provider": "voyageai",
                "config": {"api_key": "test-key", "model": "voyage-3"},
            }

            configurator = EmbeddingConfigurator()
            embedder = configurator.configure_embedder(embedder_config)

            mock_voyageai.assert_called_once_with(
                model_name="voyage-3", api_key="test-key"
            )
            assert embedder == mock_instance

    def test_configure_voyageai_embedder_missing_api_key(self):
        """Test that the VoyageAI embedder raises an error when API key is missing."""
        with patch(self._VOYAGEAI_TARGET) as mock_voyageai:
            # The mocked constructor reproduces the error the real class raises
            # for an empty API key.
            mock_voyageai.side_effect = ValueError("API key is required for VoyageAI embeddings")

            config = {}  # Empty config without API key
            model_name = "voyage-3"

            configurator = EmbeddingConfigurator()

            with pytest.raises(ValueError, match="API key is required"):
                configurator._configure_voyageai(config, model_name)

    def test_configure_voyageai_embedder_custom_model(self):
        """Test that the VoyageAI embedder works with different model names."""
        with patch(self._VOYAGEAI_TARGET) as mock_voyageai:
            mock_instance = MagicMock()
            mock_voyageai.return_value = mock_instance

            config = {"api_key": "test-key"}
            model_name = "voyage-3.5-lite"  # Using a different model

            configurator = EmbeddingConfigurator()
            embedder = configurator._configure_voyageai(config, model_name)

            mock_voyageai.assert_called_once_with(
                model_name=model_name, api_key="test-key"
            )
            assert embedder == mock_instance
|
||||||
Reference in New Issue
Block a user