Compare commits

...

2 Commits

Author SHA1 Message Date
Devin AI
4489baa149 fix: resolve lint and type-checker issues
- Fix RET504 lint error by removing unnecessary assignment before return
- Add proper type annotations for embedding_functions dictionary
- Import Callable and Any from typing to resolve mypy errors

Co-Authored-By: João <joao@crewai.com>
2025-09-23 15:47:29 +00:00
Devin AI
1442f3e4b6 fix: add Watson embedding support to factory
- Add Watson to EmbeddingProvider type definition
- Implement _create_watson_embedding_function in factory.py
- Add Watson to embedding_functions dictionary
- Add comprehensive tests for Watson embedding functionality
- Ensure proper error handling for missing IBM Watson dependencies

Fixes #3582

Co-Authored-By: João <joao@crewai.com>
2025-09-23 15:41:56 +00:00
4 changed files with 3376 additions and 3412 deletions

View File

@@ -1,6 +1,7 @@
"""Minimal embedding function factory for CrewAI.""" """Minimal embedding function factory for CrewAI."""
import os import os
from typing import Any, Callable
from chromadb import EmbeddingFunction from chromadb import EmbeddingFunction
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import ( from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
@@ -46,6 +47,50 @@ from chromadb.utils.embedding_functions.text2vec_embedding_function import (
from crewai.rag.embeddings.types import EmbeddingOptions from crewai.rag.embeddings.types import EmbeddingOptions
def _create_watson_embedding_function(**config_dict) -> EmbeddingFunction:
"""Create Watson embedding function with proper error handling."""
try:
import ibm_watsonx_ai.foundation_models as watson_models # type: ignore[import-not-found]
from ibm_watsonx_ai import Credentials # type: ignore[import-not-found]
from ibm_watsonx_ai.metanames import ( # type: ignore[import-not-found]
EmbedTextParamsMetaNames as EmbedParams,
)
except ImportError as e:
raise ImportError(
"IBM Watson dependencies are not installed. Please install them to use Watson embedding."
) from e
class WatsonEmbeddingFunction(EmbeddingFunction):
def __init__(self, **kwargs):
self.config = kwargs
def __call__(self, input):
if isinstance(input, str):
input = [input]
embed_params = {
EmbedParams.TRUNCATE_INPUT_TOKENS: 3,
EmbedParams.RETURN_OPTIONS: {"input_text": True},
}
embedding = watson_models.Embeddings(
model_id=self.config.get("model_name") or self.config.get("model"),
params=embed_params,
credentials=Credentials(
api_key=self.config.get("api_key"),
url=self.config.get("api_url") or self.config.get("url")
),
project_id=self.config.get("project_id"),
)
try:
return embedding.embed_documents(input)
except Exception as e:
raise RuntimeError(f"Error during Watson embedding: {e}") from e
return WatsonEmbeddingFunction(**config_dict)
def get_embedding_function( def get_embedding_function(
config: EmbeddingOptions | dict | None = None, config: EmbeddingOptions | dict | None = None,
) -> EmbeddingFunction: ) -> EmbeddingFunction:
@@ -75,6 +120,7 @@ def get_embedding_function(
- openclip: OpenCLIP embeddings for multimodal tasks - openclip: OpenCLIP embeddings for multimodal tasks
- text2vec: Text2Vec embeddings - text2vec: Text2Vec embeddings
- onnx: ONNX MiniLM-L6-v2 (no API key needed, included with ChromaDB) - onnx: ONNX MiniLM-L6-v2 (no API key needed, included with ChromaDB)
- watson: IBM Watson embeddings
Examples: Examples:
# Use default OpenAI embedding # Use default OpenAI embedding
@@ -108,6 +154,15 @@ def get_embedding_function(
>>> embedder = get_embedding_function({ >>> embedder = get_embedding_function({
... "provider": "onnx" ... "provider": "onnx"
... }) ... })
# Use Watson embeddings
>>> embedder = get_embedding_function({
... "provider": "watson",
... "api_key": "your-watson-api-key",
... "api_url": "your-watson-url",
... "project_id": "your-project-id",
... "model_name": "ibm/slate-125m-english-rtrvr"
... })
""" """
if config is None: if config is None:
return OpenAIEmbeddingFunction( return OpenAIEmbeddingFunction(
@@ -122,7 +177,7 @@ def get_embedding_function(
provider = config_dict.pop("provider", "openai") provider = config_dict.pop("provider", "openai")
embedding_functions = { embedding_functions: dict[str, Callable[..., EmbeddingFunction]] = {
"openai": OpenAIEmbeddingFunction, "openai": OpenAIEmbeddingFunction,
"cohere": CohereEmbeddingFunction, "cohere": CohereEmbeddingFunction,
"ollama": OllamaEmbeddingFunction, "ollama": OllamaEmbeddingFunction,
@@ -138,6 +193,7 @@ def get_embedding_function(
"openclip": OpenCLIPEmbeddingFunction, "openclip": OpenCLIPEmbeddingFunction,
"text2vec": Text2VecEmbeddingFunction, "text2vec": Text2VecEmbeddingFunction,
"onnx": ONNXMiniLM_L6_V2, "onnx": ONNXMiniLM_L6_V2,
"watson": _create_watson_embedding_function,
} }
if provider not in embedding_functions: if provider not in embedding_functions:

View File

@@ -22,6 +22,7 @@ EmbeddingProvider = Literal[
"openclip", "openclip",
"text2vec", "text2vec",
"onnx", "onnx",
"watson",
] ]
"""Supported embedding providers. """Supported embedding providers.

View File

@@ -248,3 +248,68 @@ def test_get_embedding_function_instructor() -> None:
mock_instructor.assert_called_once_with(model_name="hkunlp/instructor-large") mock_instructor.assert_called_once_with(model_name="hkunlp/instructor-large")
assert result == mock_instance assert result == mock_instance
def test_get_embedding_function_watson() -> None:
"""Test Watson embedding function."""
with patch("crewai.rag.embeddings.factory._create_watson_embedding_function") as mock_watson:
mock_instance = MagicMock()
mock_watson.return_value = mock_instance
config = {
"provider": "watson",
"api_key": "watson-api-key",
"api_url": "https://watson-url.com",
"project_id": "watson-project-id",
"model_name": "ibm/slate-125m-english-rtrvr",
}
result = get_embedding_function(config)
mock_watson.assert_called_once_with(
api_key="watson-api-key",
api_url="https://watson-url.com",
project_id="watson-project-id",
model_name="ibm/slate-125m-english-rtrvr",
)
assert result == mock_instance
def test_get_embedding_function_watson_missing_dependencies() -> None:
"""Test Watson embedding function with missing dependencies."""
with patch("crewai.rag.embeddings.factory._create_watson_embedding_function") as mock_watson:
mock_watson.side_effect = ImportError(
"IBM Watson dependencies are not installed. Please install them to use Watson embedding."
)
config = {
"provider": "watson",
"api_key": "watson-api-key",
"api_url": "https://watson-url.com",
"project_id": "watson-project-id",
"model_name": "ibm/slate-125m-english-rtrvr",
}
with pytest.raises(ImportError, match="IBM Watson dependencies are not installed"):
get_embedding_function(config)
def test_get_embedding_function_watson_with_embedding_options() -> None:
"""Test Watson embedding function with EmbeddingOptions object."""
with patch("crewai.rag.embeddings.factory._create_watson_embedding_function") as mock_watson:
mock_instance = MagicMock()
mock_watson.return_value = mock_instance
options = EmbeddingOptions(
provider="watson",
api_key="watson-key",
model_name="ibm/slate-125m-english-rtrvr"
)
result = get_embedding_function(options)
call_kwargs = mock_watson.call_args.kwargs
assert "api_key" in call_kwargs
assert call_kwargs["api_key"].get_secret_value() == "watson-key"
assert call_kwargs["model_name"] == "ibm/slate-125m-english-rtrvr"
assert result == mock_instance

6664
uv.lock generated

File diff suppressed because it is too large Load Diff