mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-26 00:28:13 +00:00
fix: add Watson embedding support to factory
- Add Watson to EmbeddingProvider type definition - Implement _create_watson_embedding_function in factory.py - Add Watson to embedding_functions dictionary - Add comprehensive tests for Watson embedding functionality - Ensure proper error handling for missing IBM Watson dependencies Fixes #3582 Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -248,3 +248,68 @@ def test_get_embedding_function_instructor() -> None:
|
||||
|
||||
mock_instructor.assert_called_once_with(model_name="hkunlp/instructor-large")
|
||||
assert result == mock_instance
|
||||
|
||||
|
||||
def test_get_embedding_function_watson() -> None:
|
||||
"""Test Watson embedding function."""
|
||||
with patch("crewai.rag.embeddings.factory._create_watson_embedding_function") as mock_watson:
|
||||
mock_instance = MagicMock()
|
||||
mock_watson.return_value = mock_instance
|
||||
|
||||
config = {
|
||||
"provider": "watson",
|
||||
"api_key": "watson-api-key",
|
||||
"api_url": "https://watson-url.com",
|
||||
"project_id": "watson-project-id",
|
||||
"model_name": "ibm/slate-125m-english-rtrvr",
|
||||
}
|
||||
|
||||
result = get_embedding_function(config)
|
||||
|
||||
mock_watson.assert_called_once_with(
|
||||
api_key="watson-api-key",
|
||||
api_url="https://watson-url.com",
|
||||
project_id="watson-project-id",
|
||||
model_name="ibm/slate-125m-english-rtrvr",
|
||||
)
|
||||
assert result == mock_instance
|
||||
|
||||
|
||||
def test_get_embedding_function_watson_missing_dependencies() -> None:
|
||||
"""Test Watson embedding function with missing dependencies."""
|
||||
with patch("crewai.rag.embeddings.factory._create_watson_embedding_function") as mock_watson:
|
||||
mock_watson.side_effect = ImportError(
|
||||
"IBM Watson dependencies are not installed. Please install them to use Watson embedding."
|
||||
)
|
||||
|
||||
config = {
|
||||
"provider": "watson",
|
||||
"api_key": "watson-api-key",
|
||||
"api_url": "https://watson-url.com",
|
||||
"project_id": "watson-project-id",
|
||||
"model_name": "ibm/slate-125m-english-rtrvr",
|
||||
}
|
||||
|
||||
with pytest.raises(ImportError, match="IBM Watson dependencies are not installed"):
|
||||
get_embedding_function(config)
|
||||
|
||||
|
||||
def test_get_embedding_function_watson_with_embedding_options() -> None:
|
||||
"""Test Watson embedding function with EmbeddingOptions object."""
|
||||
with patch("crewai.rag.embeddings.factory._create_watson_embedding_function") as mock_watson:
|
||||
mock_instance = MagicMock()
|
||||
mock_watson.return_value = mock_instance
|
||||
|
||||
options = EmbeddingOptions(
|
||||
provider="watson",
|
||||
api_key="watson-key",
|
||||
model_name="ibm/slate-125m-english-rtrvr"
|
||||
)
|
||||
|
||||
result = get_embedding_function(options)
|
||||
|
||||
call_kwargs = mock_watson.call_args.kwargs
|
||||
assert "api_key" in call_kwargs
|
||||
assert call_kwargs["api_key"].get_secret_value() == "watson-key"
|
||||
assert call_kwargs["model_name"] == "ibm/slate-125m-english-rtrvr"
|
||||
assert result == mock_instance
|
||||
|
||||
Reference in New Issue
Block a user