mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-16 04:18:35 +00:00
Compare commits
2 Commits
1.6.1
...
devin/1758
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4489baa149 | ||
|
|
1442f3e4b6 |
@@ -1,6 +1,7 @@
|
||||
"""Minimal embedding function factory for CrewAI."""
|
||||
|
||||
import os
|
||||
from typing import Any, Callable
|
||||
|
||||
from chromadb import EmbeddingFunction
|
||||
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
|
||||
@@ -46,6 +47,50 @@ from chromadb.utils.embedding_functions.text2vec_embedding_function import (
|
||||
from crewai.rag.embeddings.types import EmbeddingOptions
|
||||
|
||||
|
||||
def _create_watson_embedding_function(**config_dict) -> EmbeddingFunction:
|
||||
"""Create Watson embedding function with proper error handling."""
|
||||
try:
|
||||
import ibm_watsonx_ai.foundation_models as watson_models # type: ignore[import-not-found]
|
||||
from ibm_watsonx_ai import Credentials # type: ignore[import-not-found]
|
||||
from ibm_watsonx_ai.metanames import ( # type: ignore[import-not-found]
|
||||
EmbedTextParamsMetaNames as EmbedParams,
|
||||
)
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"IBM Watson dependencies are not installed. Please install them to use Watson embedding."
|
||||
) from e
|
||||
|
||||
class WatsonEmbeddingFunction(EmbeddingFunction):
|
||||
def __init__(self, **kwargs):
|
||||
self.config = kwargs
|
||||
|
||||
def __call__(self, input):
|
||||
if isinstance(input, str):
|
||||
input = [input]
|
||||
|
||||
embed_params = {
|
||||
EmbedParams.TRUNCATE_INPUT_TOKENS: 3,
|
||||
EmbedParams.RETURN_OPTIONS: {"input_text": True},
|
||||
}
|
||||
|
||||
embedding = watson_models.Embeddings(
|
||||
model_id=self.config.get("model_name") or self.config.get("model"),
|
||||
params=embed_params,
|
||||
credentials=Credentials(
|
||||
api_key=self.config.get("api_key"),
|
||||
url=self.config.get("api_url") or self.config.get("url")
|
||||
),
|
||||
project_id=self.config.get("project_id"),
|
||||
)
|
||||
|
||||
try:
|
||||
return embedding.embed_documents(input)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error during Watson embedding: {e}") from e
|
||||
|
||||
return WatsonEmbeddingFunction(**config_dict)
|
||||
|
||||
|
||||
def get_embedding_function(
|
||||
config: EmbeddingOptions | dict | None = None,
|
||||
) -> EmbeddingFunction:
|
||||
@@ -75,6 +120,7 @@ def get_embedding_function(
|
||||
- openclip: OpenCLIP embeddings for multimodal tasks
|
||||
- text2vec: Text2Vec embeddings
|
||||
- onnx: ONNX MiniLM-L6-v2 (no API key needed, included with ChromaDB)
|
||||
- watson: IBM Watson embeddings
|
||||
|
||||
Examples:
|
||||
# Use default OpenAI embedding
|
||||
@@ -108,6 +154,15 @@ def get_embedding_function(
|
||||
>>> embedder = get_embedding_function({
|
||||
... "provider": "onnx"
|
||||
... })
|
||||
|
||||
# Use Watson embeddings
|
||||
>>> embedder = get_embedding_function({
|
||||
... "provider": "watson",
|
||||
... "api_key": "your-watson-api-key",
|
||||
... "api_url": "your-watson-url",
|
||||
... "project_id": "your-project-id",
|
||||
... "model_name": "ibm/slate-125m-english-rtrvr"
|
||||
... })
|
||||
"""
|
||||
if config is None:
|
||||
return OpenAIEmbeddingFunction(
|
||||
@@ -122,7 +177,7 @@ def get_embedding_function(
|
||||
|
||||
provider = config_dict.pop("provider", "openai")
|
||||
|
||||
embedding_functions = {
|
||||
embedding_functions: dict[str, Callable[..., EmbeddingFunction]] = {
|
||||
"openai": OpenAIEmbeddingFunction,
|
||||
"cohere": CohereEmbeddingFunction,
|
||||
"ollama": OllamaEmbeddingFunction,
|
||||
@@ -138,6 +193,7 @@ def get_embedding_function(
|
||||
"openclip": OpenCLIPEmbeddingFunction,
|
||||
"text2vec": Text2VecEmbeddingFunction,
|
||||
"onnx": ONNXMiniLM_L6_V2,
|
||||
"watson": _create_watson_embedding_function,
|
||||
}
|
||||
|
||||
if provider not in embedding_functions:
|
||||
|
||||
@@ -22,6 +22,7 @@ EmbeddingProvider = Literal[
|
||||
"openclip",
|
||||
"text2vec",
|
||||
"onnx",
|
||||
"watson",
|
||||
]
|
||||
"""Supported embedding providers.
|
||||
|
||||
|
||||
@@ -248,3 +248,68 @@ def test_get_embedding_function_instructor() -> None:
|
||||
|
||||
mock_instructor.assert_called_once_with(model_name="hkunlp/instructor-large")
|
||||
assert result == mock_instance
|
||||
|
||||
|
||||
def test_get_embedding_function_watson() -> None:
|
||||
"""Test Watson embedding function."""
|
||||
with patch("crewai.rag.embeddings.factory._create_watson_embedding_function") as mock_watson:
|
||||
mock_instance = MagicMock()
|
||||
mock_watson.return_value = mock_instance
|
||||
|
||||
config = {
|
||||
"provider": "watson",
|
||||
"api_key": "watson-api-key",
|
||||
"api_url": "https://watson-url.com",
|
||||
"project_id": "watson-project-id",
|
||||
"model_name": "ibm/slate-125m-english-rtrvr",
|
||||
}
|
||||
|
||||
result = get_embedding_function(config)
|
||||
|
||||
mock_watson.assert_called_once_with(
|
||||
api_key="watson-api-key",
|
||||
api_url="https://watson-url.com",
|
||||
project_id="watson-project-id",
|
||||
model_name="ibm/slate-125m-english-rtrvr",
|
||||
)
|
||||
assert result == mock_instance
|
||||
|
||||
|
||||
def test_get_embedding_function_watson_missing_dependencies() -> None:
|
||||
"""Test Watson embedding function with missing dependencies."""
|
||||
with patch("crewai.rag.embeddings.factory._create_watson_embedding_function") as mock_watson:
|
||||
mock_watson.side_effect = ImportError(
|
||||
"IBM Watson dependencies are not installed. Please install them to use Watson embedding."
|
||||
)
|
||||
|
||||
config = {
|
||||
"provider": "watson",
|
||||
"api_key": "watson-api-key",
|
||||
"api_url": "https://watson-url.com",
|
||||
"project_id": "watson-project-id",
|
||||
"model_name": "ibm/slate-125m-english-rtrvr",
|
||||
}
|
||||
|
||||
with pytest.raises(ImportError, match="IBM Watson dependencies are not installed"):
|
||||
get_embedding_function(config)
|
||||
|
||||
|
||||
def test_get_embedding_function_watson_with_embedding_options() -> None:
|
||||
"""Test Watson embedding function with EmbeddingOptions object."""
|
||||
with patch("crewai.rag.embeddings.factory._create_watson_embedding_function") as mock_watson:
|
||||
mock_instance = MagicMock()
|
||||
mock_watson.return_value = mock_instance
|
||||
|
||||
options = EmbeddingOptions(
|
||||
provider="watson",
|
||||
api_key="watson-key",
|
||||
model_name="ibm/slate-125m-english-rtrvr"
|
||||
)
|
||||
|
||||
result = get_embedding_function(options)
|
||||
|
||||
call_kwargs = mock_watson.call_args.kwargs
|
||||
assert "api_key" in call_kwargs
|
||||
assert call_kwargs["api_key"].get_secret_value() == "watson-key"
|
||||
assert call_kwargs["model_name"] == "ibm/slate-125m-english-rtrvr"
|
||||
assert result == mock_instance
|
||||
|
||||
Reference in New Issue
Block a user