From 1442f3e4b6f8dc51a7003dcbd051b014034ee81c Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Tue, 23 Sep 2025 15:41:56 +0000 Subject: [PATCH] fix: add Watson embedding support to factory MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add Watson to EmbeddingProvider type definition - Implement _create_watson_embedding_function in factory.py - Add Watson to embedding_functions dictionary - Add comprehensive tests for Watson embedding functionality - Ensure proper error handling for missing IBM Watson dependencies Fixes #3582 Co-Authored-By: João --- src/crewai/rag/embeddings/factory.py | 56 ++++++++++++++++ src/crewai/rag/embeddings/types.py | 1 + tests/rag/embeddings/test_factory_enhanced.py | 65 +++++++++++++++++++ 3 files changed, 122 insertions(+) diff --git a/src/crewai/rag/embeddings/factory.py b/src/crewai/rag/embeddings/factory.py index 0b76ef36a..f8e6f0f4c 100644 --- a/src/crewai/rag/embeddings/factory.py +++ b/src/crewai/rag/embeddings/factory.py @@ -46,6 +46,51 @@ from chromadb.utils.embedding_functions.text2vec_embedding_function import ( from crewai.rag.embeddings.types import EmbeddingOptions +def _create_watson_embedding_function(**config_dict) -> EmbeddingFunction: + """Create Watson embedding function with proper error handling.""" + try: + import ibm_watsonx_ai.foundation_models as watson_models # type: ignore[import-not-found] + from ibm_watsonx_ai import Credentials # type: ignore[import-not-found] + from ibm_watsonx_ai.metanames import ( # type: ignore[import-not-found] + EmbedTextParamsMetaNames as EmbedParams, + ) + except ImportError as e: + raise ImportError( + "IBM Watson dependencies are not installed. Please install them to use Watson embedding." + ) from e + + class WatsonEmbeddingFunction(EmbeddingFunction): + def __init__(self, **kwargs): + self.config = kwargs + + def __call__(self, input): + if isinstance(input, str): + input = [input] + + embed_params = { + EmbedParams.TRUNCATE_INPUT_TOKENS: 3, + EmbedParams.RETURN_OPTIONS: {"input_text": True}, + } + + embedding = watson_models.Embeddings( + model_id=self.config.get("model_name") or self.config.get("model"), + params=embed_params, + credentials=Credentials( + api_key=self.config.get("api_key"), + url=self.config.get("api_url") or self.config.get("url") + ), + project_id=self.config.get("project_id"), + ) + + try: + embeddings = embedding.embed_documents(input) + return embeddings + except Exception as e: + raise RuntimeError(f"Error during Watson embedding: {e}") from e + + return WatsonEmbeddingFunction(**config_dict) + + def get_embedding_function( config: EmbeddingOptions | dict | None = None, ) -> EmbeddingFunction: @@ -75,6 +120,7 @@ def get_embedding_function( - openclip: OpenCLIP embeddings for multimodal tasks - text2vec: Text2Vec embeddings - onnx: ONNX MiniLM-L6-v2 (no API key needed, included with ChromaDB) + - watson: IBM Watson embeddings Examples: # Use default OpenAI embedding @@ -108,6 +154,15 @@ def get_embedding_function( >>> embedder = get_embedding_function({ ... "provider": "onnx" ... }) + + # Use Watson embeddings + >>> embedder = get_embedding_function({ + ... "provider": "watson", + ... "api_key": "your-watson-api-key", + ... "api_url": "your-watson-url", + ... "project_id": "your-project-id", + ... "model_name": "ibm/slate-125m-english-rtrvr" + ... }) """ if config is None: return OpenAIEmbeddingFunction( @@ -138,6 +193,7 @@ def get_embedding_function( "openclip": OpenCLIPEmbeddingFunction, "text2vec": Text2VecEmbeddingFunction, "onnx": ONNXMiniLM_L6_V2, + "watson": _create_watson_embedding_function, } if provider not in embedding_functions: diff --git a/src/crewai/rag/embeddings/types.py b/src/crewai/rag/embeddings/types.py index 5024d5513..19ad8e56e 100644 --- a/src/crewai/rag/embeddings/types.py +++ b/src/crewai/rag/embeddings/types.py @@ -22,6 +22,7 @@ EmbeddingProvider = Literal[ "openclip", "text2vec", "onnx", + "watson", ] """Supported embedding providers. diff --git a/tests/rag/embeddings/test_factory_enhanced.py b/tests/rag/embeddings/test_factory_enhanced.py index 489064826..440a3714a 100644 --- a/tests/rag/embeddings/test_factory_enhanced.py +++ b/tests/rag/embeddings/test_factory_enhanced.py @@ -248,3 +248,68 @@ def test_get_embedding_function_instructor() -> None: mock_instructor.assert_called_once_with(model_name="hkunlp/instructor-large") assert result == mock_instance + + +def test_get_embedding_function_watson() -> None: + """Test Watson embedding function.""" + with patch("crewai.rag.embeddings.factory._create_watson_embedding_function") as mock_watson: + mock_instance = MagicMock() + mock_watson.return_value = mock_instance + + config = { + "provider": "watson", + "api_key": "watson-api-key", + "api_url": "https://watson-url.com", + "project_id": "watson-project-id", + "model_name": "ibm/slate-125m-english-rtrvr", + } + + result = get_embedding_function(config) + + mock_watson.assert_called_once_with( + api_key="watson-api-key", + api_url="https://watson-url.com", + project_id="watson-project-id", + model_name="ibm/slate-125m-english-rtrvr", + ) + assert result == mock_instance + + +def test_get_embedding_function_watson_missing_dependencies() -> None: + """Test Watson embedding function with missing dependencies.""" + with patch("crewai.rag.embeddings.factory._create_watson_embedding_function") as mock_watson: + mock_watson.side_effect = ImportError( + "IBM Watson dependencies are not installed. Please install them to use Watson embedding." + ) + + config = { + "provider": "watson", + "api_key": "watson-api-key", + "api_url": "https://watson-url.com", + "project_id": "watson-project-id", + "model_name": "ibm/slate-125m-english-rtrvr", + } + + with pytest.raises(ImportError, match="IBM Watson dependencies are not installed"): + get_embedding_function(config) + + +def test_get_embedding_function_watson_with_embedding_options() -> None: + """Test Watson embedding function with EmbeddingOptions object.""" + with patch("crewai.rag.embeddings.factory._create_watson_embedding_function") as mock_watson: + mock_instance = MagicMock() + mock_watson.return_value = mock_instance + + options = EmbeddingOptions( + provider="watson", + api_key="watson-key", + model_name="ibm/slate-125m-english-rtrvr" + ) + + result = get_embedding_function(options) + + call_kwargs = mock_watson.call_args.kwargs + assert "api_key" in call_kwargs + assert call_kwargs["api_key"].get_secret_value() == "watson-key" + assert call_kwargs["model_name"] == "ibm/slate-125m-english-rtrvr" + assert result == mock_instance