mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-22 22:58:13 +00:00
Compare commits
2 Commits
devin/1768
...
devin/1758
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4489baa149 | ||
|
|
1442f3e4b6 |
@@ -1,6 +1,7 @@
|
|||||||
"""Minimal embedding function factory for CrewAI."""
|
"""Minimal embedding function factory for CrewAI."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
from typing import Any, Callable
|
||||||
|
|
||||||
from chromadb import EmbeddingFunction
|
from chromadb import EmbeddingFunction
|
||||||
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
|
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
|
||||||
@@ -46,6 +47,50 @@ from chromadb.utils.embedding_functions.text2vec_embedding_function import (
|
|||||||
from crewai.rag.embeddings.types import EmbeddingOptions
|
from crewai.rag.embeddings.types import EmbeddingOptions
|
||||||
|
|
||||||
|
|
||||||
|
def _create_watson_embedding_function(**config_dict: Any) -> EmbeddingFunction:
    """Create an IBM Watson embedding function from keyword configuration.

    The IBM SDK is imported inside this function, so it is only required
    when the Watson provider is actually requested.

    Args:
        **config_dict: Provider settings. The keys read are ``model_name``
            (falling back to ``model``), ``api_key``, ``api_url`` (falling
            back to ``url``), ``project_id`` and the optional
            ``truncate_input_tokens``.

    Returns:
        An embedding function that embeds documents through Watson.

    Raises:
        ImportError: If the IBM Watson dependencies are not installed.
    """
    try:
        import ibm_watsonx_ai.foundation_models as watson_models  # type: ignore[import-not-found]
        from ibm_watsonx_ai import Credentials  # type: ignore[import-not-found]
        from ibm_watsonx_ai.metanames import (  # type: ignore[import-not-found]
            EmbedTextParamsMetaNames as EmbedParams,
        )
    except ImportError as e:
        raise ImportError(
            "IBM Watson dependencies are not installed. Please install them to use Watson embedding."
        ) from e

    class WatsonEmbeddingFunction(EmbeddingFunction):
        """Embedding function that delegates to the Watson embeddings client."""

        def __init__(self, **kwargs: Any) -> None:
            # Keep the raw settings; the Watson client is built on each call.
            self.config = kwargs

        # The parameter is named ``input`` (shadowing the builtin) because
        # callers pass it under that name; renaming would change the interface.
        def __call__(self, input):
            """Embed one string or a list of strings.

            Raises:
                RuntimeError: If the Watson embedding request fails.
            """
            if isinstance(input, str):
                input = [input]

            # The key may arrive wrapped in a secret type exposing
            # ``get_secret_value()`` (e.g. when the config was built from an
            # options model); the Watson credentials need the plain value.
            api_key = self.config.get("api_key")
            reveal = getattr(api_key, "get_secret_value", None)
            if callable(reveal):
                api_key = reveal()

            # NOTE(review): 3 is the historical hard-coded limit and looks
            # far too small for real documents -- confirm the intended
            # default. Callers can override it via ``truncate_input_tokens``.
            truncate_tokens = self.config.get("truncate_input_tokens")
            if truncate_tokens is None:
                truncate_tokens = 3

            embed_params = {
                EmbedParams.TRUNCATE_INPUT_TOKENS: truncate_tokens,
                EmbedParams.RETURN_OPTIONS: {"input_text": True},
            }

            embedding = watson_models.Embeddings(
                model_id=self.config.get("model_name") or self.config.get("model"),
                params=embed_params,
                credentials=Credentials(
                    api_key=api_key,
                    url=self.config.get("api_url") or self.config.get("url"),
                ),
                project_id=self.config.get("project_id"),
            )

            try:
                return embedding.embed_documents(input)
            except Exception as e:
                raise RuntimeError(f"Error during Watson embedding: {e}") from e

    return WatsonEmbeddingFunction(**config_dict)
|
||||||
|
|
||||||
|
|
||||||
def get_embedding_function(
|
def get_embedding_function(
|
||||||
config: EmbeddingOptions | dict | None = None,
|
config: EmbeddingOptions | dict | None = None,
|
||||||
) -> EmbeddingFunction:
|
) -> EmbeddingFunction:
|
||||||
@@ -75,6 +120,7 @@ def get_embedding_function(
|
|||||||
- openclip: OpenCLIP embeddings for multimodal tasks
|
- openclip: OpenCLIP embeddings for multimodal tasks
|
||||||
- text2vec: Text2Vec embeddings
|
- text2vec: Text2Vec embeddings
|
||||||
- onnx: ONNX MiniLM-L6-v2 (no API key needed, included with ChromaDB)
|
- onnx: ONNX MiniLM-L6-v2 (no API key needed, included with ChromaDB)
|
||||||
|
- watson: IBM Watson embeddings
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
# Use default OpenAI embedding
|
# Use default OpenAI embedding
|
||||||
@@ -108,6 +154,15 @@ def get_embedding_function(
|
|||||||
>>> embedder = get_embedding_function({
|
>>> embedder = get_embedding_function({
|
||||||
... "provider": "onnx"
|
... "provider": "onnx"
|
||||||
... })
|
... })
|
||||||
|
|
||||||
|
# Use Watson embeddings
|
||||||
|
>>> embedder = get_embedding_function({
|
||||||
|
... "provider": "watson",
|
||||||
|
... "api_key": "your-watson-api-key",
|
||||||
|
... "api_url": "your-watson-url",
|
||||||
|
... "project_id": "your-project-id",
|
||||||
|
... "model_name": "ibm/slate-125m-english-rtrvr"
|
||||||
|
... })
|
||||||
"""
|
"""
|
||||||
if config is None:
|
if config is None:
|
||||||
return OpenAIEmbeddingFunction(
|
return OpenAIEmbeddingFunction(
|
||||||
@@ -122,7 +177,7 @@ def get_embedding_function(
|
|||||||
|
|
||||||
provider = config_dict.pop("provider", "openai")
|
provider = config_dict.pop("provider", "openai")
|
||||||
|
|
||||||
embedding_functions = {
|
embedding_functions: dict[str, Callable[..., EmbeddingFunction]] = {
|
||||||
"openai": OpenAIEmbeddingFunction,
|
"openai": OpenAIEmbeddingFunction,
|
||||||
"cohere": CohereEmbeddingFunction,
|
"cohere": CohereEmbeddingFunction,
|
||||||
"ollama": OllamaEmbeddingFunction,
|
"ollama": OllamaEmbeddingFunction,
|
||||||
@@ -138,6 +193,7 @@ def get_embedding_function(
|
|||||||
"openclip": OpenCLIPEmbeddingFunction,
|
"openclip": OpenCLIPEmbeddingFunction,
|
||||||
"text2vec": Text2VecEmbeddingFunction,
|
"text2vec": Text2VecEmbeddingFunction,
|
||||||
"onnx": ONNXMiniLM_L6_V2,
|
"onnx": ONNXMiniLM_L6_V2,
|
||||||
|
"watson": _create_watson_embedding_function,
|
||||||
}
|
}
|
||||||
|
|
||||||
if provider not in embedding_functions:
|
if provider not in embedding_functions:
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ EmbeddingProvider = Literal[
|
|||||||
"openclip",
|
"openclip",
|
||||||
"text2vec",
|
"text2vec",
|
||||||
"onnx",
|
"onnx",
|
||||||
|
"watson",
|
||||||
]
|
]
|
||||||
"""Supported embedding providers.
|
"""Supported embedding providers.
|
||||||
|
|
||||||
|
|||||||
@@ -248,3 +248,68 @@ def test_get_embedding_function_instructor() -> None:
|
|||||||
|
|
||||||
mock_instructor.assert_called_once_with(model_name="hkunlp/instructor-large")
|
mock_instructor.assert_called_once_with(model_name="hkunlp/instructor-large")
|
||||||
assert result == mock_instance
|
assert result == mock_instance
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_embedding_function_watson() -> None:
    """Verify the factory forwards Watson settings to the Watson builder."""
    watson_settings = {
        "api_key": "watson-api-key",
        "api_url": "https://watson-url.com",
        "project_id": "watson-project-id",
        "model_name": "ibm/slate-125m-english-rtrvr",
    }
    builder_path = "crewai.rag.embeddings.factory._create_watson_embedding_function"

    with patch(builder_path) as mock_watson:
        fake_embedder = MagicMock()
        mock_watson.return_value = fake_embedder

        # The provider key selects the builder and must not be forwarded to it.
        result = get_embedding_function({"provider": "watson", **watson_settings})

        mock_watson.assert_called_once_with(**watson_settings)
        assert result == fake_embedder
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_embedding_function_watson_missing_dependencies() -> None:
    """Test Watson embedding function with missing dependencies.

    The real Watson builder is exercised rather than a mock, so the test
    fails if the import guard or its error message regresses.
    """
    config = {
        "provider": "watson",
        "api_key": "watson-api-key",
        "api_url": "https://watson-url.com",
        "project_id": "watson-project-id",
        "model_name": "ibm/slate-125m-english-rtrvr",
    }

    # A ``None`` entry in ``sys.modules`` makes the import statement raise
    # ImportError, simulating an environment without the IBM SDK even when
    # the package happens to be installed.
    blocked_modules = {
        "ibm_watsonx_ai": None,
        "ibm_watsonx_ai.foundation_models": None,
        "ibm_watsonx_ai.metanames": None,
    }

    with patch.dict("sys.modules", blocked_modules):
        with pytest.raises(ImportError, match="IBM Watson dependencies are not installed"):
            get_embedding_function(config)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_embedding_function_watson_with_embedding_options() -> None:
    """Verify Watson settings given as an EmbeddingOptions object reach the builder."""
    builder_path = "crewai.rag.embeddings.factory._create_watson_embedding_function"

    with patch(builder_path) as mock_watson:
        fake_embedder = MagicMock()
        mock_watson.return_value = fake_embedder

        result = get_embedding_function(
            EmbeddingOptions(
                provider="watson",
                api_key="watson-key",
                model_name="ibm/slate-125m-english-rtrvr",
            )
        )

        forwarded = mock_watson.call_args.kwargs
        assert "api_key" in forwarded
        # The key is forwarded in its secret wrapper, not as a plain string.
        assert forwarded["api_key"].get_secret_value() == "watson-key"
        assert forwarded["model_name"] == "ibm/slate-125m-english-rtrvr"
        assert result == fake_embedder
|
||||||
|
|||||||
Reference in New Issue
Block a user