Compare commits

...

2 Commits

Author SHA1 Message Date
Devin AI
4489baa149 fix: resolve lint and type-checker issues
- Fix RET504 lint error by removing unnecessary assignment before return
- Add proper type annotations for embedding_functions dictionary
- Import Callable and Any from typing to resolve mypy errors

Co-Authored-By: João <joao@crewai.com>
2025-09-23 15:47:29 +00:00
Devin AI
1442f3e4b6 fix: add Watson embedding support to factory
- Add Watson to EmbeddingProvider type definition
- Implement _create_watson_embedding_function in factory.py
- Add Watson to embedding_functions dictionary
- Add comprehensive tests for Watson embedding functionality
- Ensure proper error handling for missing IBM Watson dependencies

Fixes #3582

Co-Authored-By: João <joao@crewai.com>
2025-09-23 15:41:56 +00:00
4 changed files with 3376 additions and 3412 deletions

View File

@@ -1,6 +1,7 @@
"""Minimal embedding function factory for CrewAI."""
import os
from typing import Any, Callable
from chromadb import EmbeddingFunction
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
@@ -46,6 +47,50 @@ from chromadb.utils.embedding_functions.text2vec_embedding_function import (
from crewai.rag.embeddings.types import EmbeddingOptions
def _create_watson_embedding_function(**config_dict) -> EmbeddingFunction:
"""Create Watson embedding function with proper error handling."""
try:
import ibm_watsonx_ai.foundation_models as watson_models # type: ignore[import-not-found]
from ibm_watsonx_ai import Credentials # type: ignore[import-not-found]
from ibm_watsonx_ai.metanames import ( # type: ignore[import-not-found]
EmbedTextParamsMetaNames as EmbedParams,
)
except ImportError as e:
raise ImportError(
"IBM Watson dependencies are not installed. Please install them to use Watson embedding."
) from e
class WatsonEmbeddingFunction(EmbeddingFunction):
def __init__(self, **kwargs):
self.config = kwargs
def __call__(self, input):
if isinstance(input, str):
input = [input]
embed_params = {
EmbedParams.TRUNCATE_INPUT_TOKENS: 3,
EmbedParams.RETURN_OPTIONS: {"input_text": True},
}
embedding = watson_models.Embeddings(
model_id=self.config.get("model_name") or self.config.get("model"),
params=embed_params,
credentials=Credentials(
api_key=self.config.get("api_key"),
url=self.config.get("api_url") or self.config.get("url")
),
project_id=self.config.get("project_id"),
)
try:
return embedding.embed_documents(input)
except Exception as e:
raise RuntimeError(f"Error during Watson embedding: {e}") from e
return WatsonEmbeddingFunction(**config_dict)
def get_embedding_function(
config: EmbeddingOptions | dict | None = None,
) -> EmbeddingFunction:
@@ -75,6 +120,7 @@ def get_embedding_function(
- openclip: OpenCLIP embeddings for multimodal tasks
- text2vec: Text2Vec embeddings
- onnx: ONNX MiniLM-L6-v2 (no API key needed, included with ChromaDB)
- watson: IBM Watson embeddings
Examples:
# Use default OpenAI embedding
@@ -108,6 +154,15 @@ def get_embedding_function(
>>> embedder = get_embedding_function({
... "provider": "onnx"
... })
# Use Watson embeddings
>>> embedder = get_embedding_function({
... "provider": "watson",
... "api_key": "your-watson-api-key",
... "api_url": "your-watson-url",
... "project_id": "your-project-id",
... "model_name": "ibm/slate-125m-english-rtrvr"
... })
"""
if config is None:
return OpenAIEmbeddingFunction(
@@ -122,7 +177,7 @@ def get_embedding_function(
provider = config_dict.pop("provider", "openai")
embedding_functions = {
embedding_functions: dict[str, Callable[..., EmbeddingFunction]] = {
"openai": OpenAIEmbeddingFunction,
"cohere": CohereEmbeddingFunction,
"ollama": OllamaEmbeddingFunction,
@@ -138,6 +193,7 @@ def get_embedding_function(
"openclip": OpenCLIPEmbeddingFunction,
"text2vec": Text2VecEmbeddingFunction,
"onnx": ONNXMiniLM_L6_V2,
"watson": _create_watson_embedding_function,
}
if provider not in embedding_functions:

View File

@@ -22,6 +22,7 @@ EmbeddingProvider = Literal[
"openclip",
"text2vec",
"onnx",
"watson",
]
"""Supported embedding providers.

View File

@@ -248,3 +248,68 @@ def test_get_embedding_function_instructor() -> None:
mock_instructor.assert_called_once_with(model_name="hkunlp/instructor-large")
assert result == mock_instance
def test_get_embedding_function_watson() -> None:
"""Test Watson embedding function."""
with patch("crewai.rag.embeddings.factory._create_watson_embedding_function") as mock_watson:
mock_instance = MagicMock()
mock_watson.return_value = mock_instance
config = {
"provider": "watson",
"api_key": "watson-api-key",
"api_url": "https://watson-url.com",
"project_id": "watson-project-id",
"model_name": "ibm/slate-125m-english-rtrvr",
}
result = get_embedding_function(config)
mock_watson.assert_called_once_with(
api_key="watson-api-key",
api_url="https://watson-url.com",
project_id="watson-project-id",
model_name="ibm/slate-125m-english-rtrvr",
)
assert result == mock_instance
def test_get_embedding_function_watson_missing_dependencies() -> None:
"""Test Watson embedding function with missing dependencies."""
with patch("crewai.rag.embeddings.factory._create_watson_embedding_function") as mock_watson:
mock_watson.side_effect = ImportError(
"IBM Watson dependencies are not installed. Please install them to use Watson embedding."
)
config = {
"provider": "watson",
"api_key": "watson-api-key",
"api_url": "https://watson-url.com",
"project_id": "watson-project-id",
"model_name": "ibm/slate-125m-english-rtrvr",
}
with pytest.raises(ImportError, match="IBM Watson dependencies are not installed"):
get_embedding_function(config)
def test_get_embedding_function_watson_with_embedding_options() -> None:
"""Test Watson embedding function with EmbeddingOptions object."""
with patch("crewai.rag.embeddings.factory._create_watson_embedding_function") as mock_watson:
mock_instance = MagicMock()
mock_watson.return_value = mock_instance
options = EmbeddingOptions(
provider="watson",
api_key="watson-key",
model_name="ibm/slate-125m-english-rtrvr"
)
result = get_embedding_function(options)
call_kwargs = mock_watson.call_args.kwargs
assert "api_key" in call_kwargs
assert call_kwargs["api_key"].get_secret_value() == "watson-key"
assert call_kwargs["model_name"] == "ibm/slate-125m-english-rtrvr"
assert result == mock_instance

6664
uv.lock generated

File diff suppressed because it is too large Load Diff