fix: add Watson embedding support to factory

- Add Watson to EmbeddingProvider type definition
- Implement _create_watson_embedding_function in factory.py
- Add Watson to embedding_functions dictionary
- Add comprehensive tests for Watson embedding functionality
- Ensure proper error handling for missing IBM Watson dependencies

Fixes #3582

Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
Devin AI
2025-09-23 15:41:56 +00:00
parent 3e97393f58
commit 1442f3e4b6
3 changed files with 122 additions and 0 deletions

View File

@@ -46,6 +46,51 @@ from chromadb.utils.embedding_functions.text2vec_embedding_function import (
from crewai.rag.embeddings.types import EmbeddingOptions
def _create_watson_embedding_function(**config_dict) -> EmbeddingFunction:
"""Create Watson embedding function with proper error handling."""
try:
import ibm_watsonx_ai.foundation_models as watson_models # type: ignore[import-not-found]
from ibm_watsonx_ai import Credentials # type: ignore[import-not-found]
from ibm_watsonx_ai.metanames import ( # type: ignore[import-not-found]
EmbedTextParamsMetaNames as EmbedParams,
)
except ImportError as e:
raise ImportError(
"IBM Watson dependencies are not installed. Please install them to use Watson embedding."
) from e
class WatsonEmbeddingFunction(EmbeddingFunction):
def __init__(self, **kwargs):
self.config = kwargs
def __call__(self, input):
if isinstance(input, str):
input = [input]
embed_params = {
EmbedParams.TRUNCATE_INPUT_TOKENS: 3,
EmbedParams.RETURN_OPTIONS: {"input_text": True},
}
embedding = watson_models.Embeddings(
model_id=self.config.get("model_name") or self.config.get("model"),
params=embed_params,
credentials=Credentials(
api_key=self.config.get("api_key"),
url=self.config.get("api_url") or self.config.get("url")
),
project_id=self.config.get("project_id"),
)
try:
embeddings = embedding.embed_documents(input)
return embeddings
except Exception as e:
raise RuntimeError(f"Error during Watson embedding: {e}") from e
return WatsonEmbeddingFunction(**config_dict)
def get_embedding_function(
config: EmbeddingOptions | dict | None = None,
) -> EmbeddingFunction:
@@ -75,6 +120,7 @@ def get_embedding_function(
- openclip: OpenCLIP embeddings for multimodal tasks
- text2vec: Text2Vec embeddings
- onnx: ONNX MiniLM-L6-v2 (no API key needed, included with ChromaDB)
- watson: IBM Watson embeddings
Examples:
# Use default OpenAI embedding
@@ -108,6 +154,15 @@ def get_embedding_function(
>>> embedder = get_embedding_function({
... "provider": "onnx"
... })
# Use Watson embeddings
>>> embedder = get_embedding_function({
... "provider": "watson",
... "api_key": "your-watson-api-key",
... "api_url": "your-watson-url",
... "project_id": "your-project-id",
... "model_name": "ibm/slate-125m-english-rtrvr"
... })
"""
if config is None:
return OpenAIEmbeddingFunction(
@@ -138,6 +193,7 @@ def get_embedding_function(
"openclip": OpenCLIPEmbeddingFunction,
"text2vec": Text2VecEmbeddingFunction,
"onnx": ONNXMiniLM_L6_V2,
"watson": _create_watson_embedding_function,
}
if provider not in embedding_functions:

View File

@@ -22,6 +22,7 @@ EmbeddingProvider = Literal[
"openclip",
"text2vec",
"onnx",
"watson",
]
"""Supported embedding providers.