fix: rename watson to watsonx embedding provider and prefix env vars

- prefix provider env vars with embeddings_  
- rename watson → watsonx in providers  
- add deprecation warning and alias for legacy 'watson' key (to be removed in v1.0.0)
This commit is contained in:
Greyson LaLonde
2025-09-26 10:57:18 -04:00
committed by GitHub
parent 091d1267d8
commit 12fa7e2ff1
7 changed files with 112 additions and 66 deletions

View File

@@ -2,8 +2,11 @@
from __future__ import annotations from __future__ import annotations
import warnings
from typing import TYPE_CHECKING, TypeVar, overload from typing import TYPE_CHECKING, TypeVar, overload
from typing_extensions import deprecated
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
from crewai.utilities.import_utils import import_and_validate_definition from crewai.utilities.import_utils import import_and_validate_definition
@@ -59,9 +62,12 @@ if TYPE_CHECKING:
HuggingFaceProviderSpec, HuggingFaceProviderSpec,
) )
from crewai.rag.embeddings.providers.ibm.embedding_callable import ( from crewai.rag.embeddings.providers.ibm.embedding_callable import (
WatsonEmbeddingFunction, WatsonXEmbeddingFunction,
)
from crewai.rag.embeddings.providers.ibm.types import (
WatsonProviderSpec,
WatsonXProviderSpec,
) )
from crewai.rag.embeddings.providers.ibm.types import WatsonProviderSpec
from crewai.rag.embeddings.providers.instructor.types import InstructorProviderSpec from crewai.rag.embeddings.providers.instructor.types import InstructorProviderSpec
from crewai.rag.embeddings.providers.jina.types import JinaProviderSpec from crewai.rag.embeddings.providers.jina.types import JinaProviderSpec
from crewai.rag.embeddings.providers.microsoft.types import AzureProviderSpec from crewai.rag.embeddings.providers.microsoft.types import AzureProviderSpec
@@ -100,7 +106,8 @@ PROVIDER_PATHS = {
"sentence-transformer": "crewai.rag.embeddings.providers.sentence_transformer.sentence_transformer_provider.SentenceTransformerProvider", "sentence-transformer": "crewai.rag.embeddings.providers.sentence_transformer.sentence_transformer_provider.SentenceTransformerProvider",
"text2vec": "crewai.rag.embeddings.providers.text2vec.text2vec_provider.Text2VecProvider", "text2vec": "crewai.rag.embeddings.providers.text2vec.text2vec_provider.Text2VecProvider",
"voyageai": "crewai.rag.embeddings.providers.voyageai.voyageai_provider.VoyageAIProvider", "voyageai": "crewai.rag.embeddings.providers.voyageai.voyageai_provider.VoyageAIProvider",
"watson": "crewai.rag.embeddings.providers.ibm.watson.WatsonProvider", "watson": "crewai.rag.embeddings.providers.ibm.watsonx.WatsonXProvider", # Deprecated alias
"watsonx": "crewai.rag.embeddings.providers.ibm.watsonx.WatsonXProvider",
} }
@@ -169,7 +176,14 @@ def build_embedder_from_dict(
@overload @overload
def build_embedder_from_dict(spec: WatsonProviderSpec) -> WatsonEmbeddingFunction: ... def build_embedder_from_dict(spec: WatsonXProviderSpec) -> WatsonXEmbeddingFunction: ...
@overload
@deprecated(
'The "WatsonProviderSpec" provider spec is deprecated and will be removed in v1.0.0. Use "WatsonXProviderSpec" instead.'
)
def build_embedder_from_dict(spec: WatsonProviderSpec) -> WatsonXEmbeddingFunction: ...
@overload @overload
@@ -233,6 +247,14 @@ def build_embedder_from_dict(spec):
if not provider_name: if not provider_name:
raise ValueError("Missing 'provider' key in specification") raise ValueError("Missing 'provider' key in specification")
if provider_name == "watson":
warnings.warn(
'The "watson" provider key is deprecated and will be removed in v1.0.0. '
'Use "watsonx" instead.',
DeprecationWarning,
stacklevel=2,
)
if provider_name not in PROVIDER_PATHS: if provider_name not in PROVIDER_PATHS:
raise ValueError( raise ValueError(
f"Unknown provider: {provider_name}. Available providers: {list(PROVIDER_PATHS.keys())}" f"Unknown provider: {provider_name}. Available providers: {list(PROVIDER_PATHS.keys())}"
@@ -300,7 +322,14 @@ def build_embedder(spec: VoyageAIProviderSpec) -> VoyageAIEmbeddingFunction: ...
@overload @overload
def build_embedder(spec: WatsonProviderSpec) -> WatsonEmbeddingFunction: ... def build_embedder(spec: WatsonXProviderSpec) -> WatsonXEmbeddingFunction: ...
@overload
@deprecated(
'The "WatsonProviderSpec" provider spec is deprecated and will be removed in v1.0.0. Use "WatsonXProviderSpec" instead.'
)
def build_embedder(spec: WatsonProviderSpec) -> WatsonXEmbeddingFunction: ...
@overload @overload

View File

@@ -1,15 +1,17 @@
"""IBM embedding providers.""" """IBM embedding providers."""
from crewai.rag.embeddings.providers.ibm.types import ( from crewai.rag.embeddings.providers.ibm.types import (
WatsonProviderConfig,
WatsonProviderSpec, WatsonProviderSpec,
WatsonXProviderConfig,
WatsonXProviderSpec,
) )
from crewai.rag.embeddings.providers.ibm.watson import ( from crewai.rag.embeddings.providers.ibm.watsonx import (
WatsonProvider, WatsonXProvider,
) )
__all__ = [ __all__ = [
"WatsonProvider",
"WatsonProviderConfig",
"WatsonProviderSpec", "WatsonProviderSpec",
"WatsonXProvider",
"WatsonXProviderConfig",
"WatsonXProviderSpec",
] ]

View File

@@ -1,4 +1,4 @@
"""IBM Watson embedding function implementation.""" """IBM WatsonX embedding function implementation."""
from typing import cast from typing import cast
@@ -6,17 +6,17 @@ from typing_extensions import Unpack
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
from crewai.rag.core.types import Documents, Embeddings from crewai.rag.core.types import Documents, Embeddings
from crewai.rag.embeddings.providers.ibm.types import WatsonProviderConfig from crewai.rag.embeddings.providers.ibm.types import WatsonXProviderConfig
class WatsonEmbeddingFunction(EmbeddingFunction[Documents]): class WatsonXEmbeddingFunction(EmbeddingFunction[Documents]):
"""Embedding function for IBM Watson models.""" """Embedding function for IBM WatsonX models."""
def __init__(self, **kwargs: Unpack[WatsonProviderConfig]) -> None: def __init__(self, **kwargs: Unpack[WatsonXProviderConfig]) -> None:
"""Initialize Watson embedding function. """Initialize WatsonX embedding function.
Args: Args:
**kwargs: Configuration parameters for Watson Embeddings and Credentials. **kwargs: Configuration parameters for WatsonX Embeddings and Credentials.
""" """
self._config = kwargs self._config = kwargs
@@ -40,7 +40,7 @@ class WatsonEmbeddingFunction(EmbeddingFunction[Documents]):
except ImportError as e: except ImportError as e:
raise ImportError( raise ImportError(
"ibm-watsonx-ai is required for watson embeddings. " "ibm-watsonx-ai is required for watsonx embeddings. "
"Install it with: uv add ibm-watsonx-ai" "Install it with: uv add ibm-watsonx-ai"
) from e ) from e
@@ -150,5 +150,5 @@ class WatsonEmbeddingFunction(EmbeddingFunction[Documents]):
embeddings = embedding.embed_documents(input) embeddings = embedding.embed_documents(input)
return cast(Embeddings, embeddings) return cast(Embeddings, embeddings)
except Exception as e: except Exception as e:
print(f"Error during Watson embedding: {e}") print(f"Error during WatsonX embedding: {e}")
raise raise

View File

@@ -1,12 +1,12 @@
"""Type definitions for IBM Watson embedding providers.""" """Type definitions for IBM WatsonX embedding providers."""
from typing import Annotated, Any, Literal from typing import Annotated, Any, Literal
from typing_extensions import Required, TypedDict from typing_extensions import Required, TypedDict, deprecated
class WatsonProviderConfig(TypedDict, total=False): class WatsonXProviderConfig(TypedDict, total=False):
"""Configuration for Watson provider.""" """Configuration for WatsonX provider."""
model_id: str model_id: str
url: str url: str
@@ -37,8 +37,22 @@ class WatsonProviderConfig(TypedDict, total=False):
proxies: dict proxies: dict
class WatsonXProviderSpec(TypedDict, total=False):
"""WatsonX provider specification."""
provider: Required[Literal["watsonx"]]
config: WatsonXProviderConfig
@deprecated(
'The "WatsonProviderSpec" provider spec is deprecated and will be removed in v1.0.0. Use "WatsonXProviderSpec" instead.'
)
class WatsonProviderSpec(TypedDict, total=False): class WatsonProviderSpec(TypedDict, total=False):
"""Watson provider specification.""" """Watson provider specification (deprecated).
Notes:
- This is deprecated. Use WatsonXProviderSpec with provider="watsonx" instead.
"""
provider: Required[Literal["watson"]] provider: Required[Literal["watson"]]
config: WatsonProviderConfig config: WatsonXProviderConfig

View File

@@ -1,4 +1,4 @@
"""IBM Watson embeddings provider.""" """IBM WatsonX embeddings provider."""
from typing import Any from typing import Any
@@ -7,130 +7,130 @@ from typing_extensions import Self
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
from crewai.rag.embeddings.providers.ibm.embedding_callable import ( from crewai.rag.embeddings.providers.ibm.embedding_callable import (
WatsonEmbeddingFunction, WatsonXEmbeddingFunction,
) )
class WatsonProvider(BaseEmbeddingsProvider[WatsonEmbeddingFunction]): class WatsonXProvider(BaseEmbeddingsProvider[WatsonXEmbeddingFunction]):
"""IBM Watson embeddings provider. """IBM WatsonX embeddings provider.
Note: Requires custom implementation as Watson uses a different interface. Note: Requires custom implementation as WatsonX uses a different interface.
""" """
embedding_callable: type[WatsonEmbeddingFunction] = Field( embedding_callable: type[WatsonXEmbeddingFunction] = Field(
default=WatsonEmbeddingFunction, description="Watson embedding function class" default=WatsonXEmbeddingFunction, description="WatsonX embedding function class"
) )
model_id: str = Field( model_id: str = Field(
description="Watson model ID", validation_alias="EMBEDDINGS_WATSON_MODEL_ID" description="WatsonX model ID", validation_alias="EMBEDDINGS_WATSONX_MODEL_ID"
) )
params: dict[str, str | dict[str, str]] | None = Field( params: dict[str, str | dict[str, str]] | None = Field(
default=None, description="Additional parameters" default=None, description="Additional parameters"
) )
credentials: Any | None = Field(default=None, description="Watson credentials") credentials: Any | None = Field(default=None, description="WatsonX credentials")
project_id: str | None = Field( project_id: str | None = Field(
default=None, default=None,
description="Watson project ID", description="WatsonX project ID",
validation_alias="EMBEDDINGS_WATSON_PROJECT_ID", validation_alias="EMBEDDINGS_WATSONX_PROJECT_ID",
) )
space_id: str | None = Field( space_id: str | None = Field(
default=None, default=None,
description="Watson space ID", description="WatsonX space ID",
validation_alias="EMBEDDINGS_WATSON_SPACE_ID", validation_alias="EMBEDDINGS_WATSONX_SPACE_ID",
) )
api_client: Any | None = Field(default=None, description="Watson API client") api_client: Any | None = Field(default=None, description="WatsonX API client")
verify: bool | str | None = Field( verify: bool | str | None = Field(
default=None, default=None,
description="SSL verification", description="SSL verification",
validation_alias="EMBEDDINGS_WATSON_VERIFY", validation_alias="EMBEDDINGS_WATSONX_VERIFY",
) )
persistent_connection: bool = Field( persistent_connection: bool = Field(
default=True, default=True,
description="Use persistent connection", description="Use persistent connection",
validation_alias="EMBEDDINGS_WATSON_PERSISTENT_CONNECTION", validation_alias="EMBEDDINGS_WATSONX_PERSISTENT_CONNECTION",
) )
batch_size: int = Field( batch_size: int = Field(
default=100, default=100,
description="Batch size for processing", description="Batch size for processing",
validation_alias="EMBEDDINGS_WATSON_BATCH_SIZE", validation_alias="EMBEDDINGS_WATSONX_BATCH_SIZE",
) )
concurrency_limit: int = Field( concurrency_limit: int = Field(
default=10, default=10,
description="Concurrency limit", description="Concurrency limit",
validation_alias="EMBEDDINGS_WATSON_CONCURRENCY_LIMIT", validation_alias="EMBEDDINGS_WATSONX_CONCURRENCY_LIMIT",
) )
max_retries: int | None = Field( max_retries: int | None = Field(
default=None, default=None,
description="Maximum retries", description="Maximum retries",
validation_alias="EMBEDDINGS_WATSON_MAX_RETRIES", validation_alias="EMBEDDINGS_WATSONX_MAX_RETRIES",
) )
delay_time: float | None = Field( delay_time: float | None = Field(
default=None, default=None,
description="Delay time between retries", description="Delay time between retries",
validation_alias="EMBEDDINGS_WATSON_DELAY_TIME", validation_alias="EMBEDDINGS_WATSONX_DELAY_TIME",
) )
retry_status_codes: list[int] | None = Field( retry_status_codes: list[int] | None = Field(
default=None, description="HTTP status codes to retry on" default=None, description="HTTP status codes to retry on"
) )
url: str = Field( url: str = Field(
description="Watson API URL", validation_alias="EMBEDDINGS_WATSON_URL" description="WatsonX API URL", validation_alias="EMBEDDINGS_WATSONX_URL"
) )
api_key: str = Field( api_key: str = Field(
description="Watson API key", validation_alias="EMBEDDINGS_WATSON_API_KEY" description="WatsonX API key", validation_alias="EMBEDDINGS_WATSONX_API_KEY"
) )
name: str | None = Field( name: str | None = Field(
default=None, default=None,
description="Service name", description="Service name",
validation_alias="EMBEDDINGS_WATSON_NAME", validation_alias="EMBEDDINGS_WATSONX_NAME",
) )
iam_serviceid_crn: str | None = Field( iam_serviceid_crn: str | None = Field(
default=None, default=None,
description="IAM service ID CRN", description="IAM service ID CRN",
validation_alias="EMBEDDINGS_WATSON_IAM_SERVICEID_CRN", validation_alias="EMBEDDINGS_WATSONX_IAM_SERVICEID_CRN",
) )
trusted_profile_id: str | None = Field( trusted_profile_id: str | None = Field(
default=None, default=None,
description="Trusted profile ID", description="Trusted profile ID",
validation_alias="EMBEDDINGS_WATSON_TRUSTED_PROFILE_ID", validation_alias="EMBEDDINGS_WATSONX_TRUSTED_PROFILE_ID",
) )
token: str | None = Field( token: str | None = Field(
default=None, default=None,
description="Bearer token", description="Bearer token",
validation_alias="EMBEDDINGS_WATSON_TOKEN", validation_alias="EMBEDDINGS_WATSONX_TOKEN",
) )
projects_token: str | None = Field( projects_token: str | None = Field(
default=None, default=None,
description="Projects token", description="Projects token",
validation_alias="EMBEDDINGS_WATSON_PROJECTS_TOKEN", validation_alias="EMBEDDINGS_WATSONX_PROJECTS_TOKEN",
) )
username: str | None = Field( username: str | None = Field(
default=None, default=None,
description="Username", description="Username",
validation_alias="EMBEDDINGS_WATSON_USERNAME", validation_alias="EMBEDDINGS_WATSONX_USERNAME",
) )
password: str | None = Field( password: str | None = Field(
default=None, default=None,
description="Password", description="Password",
validation_alias="EMBEDDINGS_WATSON_PASSWORD", validation_alias="EMBEDDINGS_WATSONX_PASSWORD",
) )
instance_id: str | None = Field( instance_id: str | None = Field(
default=None, default=None,
description="Service instance ID", description="Service instance ID",
validation_alias="EMBEDDINGS_WATSON_INSTANCE_ID", validation_alias="EMBEDDINGS_WATSONX_INSTANCE_ID",
) )
version: str | None = Field( version: str | None = Field(
default=None, default=None,
description="API version", description="API version",
validation_alias="EMBEDDINGS_WATSON_VERSION", validation_alias="EMBEDDINGS_WATSONX_VERSION",
) )
bedrock_url: str | None = Field( bedrock_url: str | None = Field(
default=None, default=None,
description="Bedrock URL", description="Bedrock URL",
validation_alias="EMBEDDINGS_WATSON_BEDROCK_URL", validation_alias="EMBEDDINGS_WATSONX_BEDROCK_URL",
) )
platform_url: str | None = Field( platform_url: str | None = Field(
default=None, default=None,
description="Platform URL", description="Platform URL",
validation_alias="EMBEDDINGS_WATSON_PLATFORM_URL", validation_alias="EMBEDDINGS_WATSONX_PLATFORM_URL",
) )
proxies: dict | None = Field(default=None, description="Proxy configuration") proxies: dict | None = Field(default=None, description="Proxy configuration")

View File

@@ -11,7 +11,7 @@ from crewai.rag.embeddings.providers.google.types import (
VertexAIProviderSpec, VertexAIProviderSpec,
) )
from crewai.rag.embeddings.providers.huggingface.types import HuggingFaceProviderSpec from crewai.rag.embeddings.providers.huggingface.types import HuggingFaceProviderSpec
from crewai.rag.embeddings.providers.ibm.types import WatsonProviderSpec from crewai.rag.embeddings.providers.ibm.types import WatsonXProviderSpec
from crewai.rag.embeddings.providers.instructor.types import InstructorProviderSpec from crewai.rag.embeddings.providers.instructor.types import InstructorProviderSpec
from crewai.rag.embeddings.providers.jina.types import JinaProviderSpec from crewai.rag.embeddings.providers.jina.types import JinaProviderSpec
from crewai.rag.embeddings.providers.microsoft.types import AzureProviderSpec from crewai.rag.embeddings.providers.microsoft.types import AzureProviderSpec
@@ -44,7 +44,7 @@ ProviderSpec = (
| Text2VecProviderSpec | Text2VecProviderSpec
| VertexAIProviderSpec | VertexAIProviderSpec
| VoyageAIProviderSpec | VoyageAIProviderSpec
| WatsonProviderSpec | WatsonXProviderSpec
) )
AllowedEmbeddingProviders = Literal[ AllowedEmbeddingProviders = Literal[
@@ -65,7 +65,8 @@ AllowedEmbeddingProviders = Literal[
"sentence-transformer", "sentence-transformer",
"text2vec", "text2vec",
"voyageai", "voyageai",
"watson", "watsonx",
"watson", # for backward compatibility until v1.0.0
] ]
EmbedderConfig: TypeAlias = ( EmbedderConfig: TypeAlias = (

View File

@@ -150,8 +150,8 @@ class TestEmbeddingFactory:
) )
@patch("crewai.rag.embeddings.factory.import_and_validate_definition") @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
def test_build_embedder_watson(self, mock_import): def test_build_embedder_watsonx(self, mock_import):
"""Test building Watson embedder.""" """Test building WatsonX embedder."""
mock_provider_class = MagicMock() mock_provider_class = MagicMock()
mock_provider_instance = MagicMock() mock_provider_instance = MagicMock()
mock_embedding_function = MagicMock() mock_embedding_function = MagicMock()
@@ -161,10 +161,10 @@ class TestEmbeddingFactory:
mock_provider_instance.embedding_callable.return_value = mock_embedding_function mock_provider_instance.embedding_callable.return_value = mock_embedding_function
config = { config = {
"provider": "watson", "provider": "watsonx",
"config": { "config": {
"model_id": "ibm/slate-125m-english-rtrvr", "model_id": "ibm/slate-125m-english-rtrvr",
"api_key": "watson-key", "api_key": "watsonx-key",
"url": "https://us-south.ml.cloud.ibm.com", "url": "https://us-south.ml.cloud.ibm.com",
"project_id": "test-project", "project_id": "test-project",
}, },
@@ -173,7 +173,7 @@ class TestEmbeddingFactory:
build_embedder(config) build_embedder(config)
mock_import.assert_called_once_with( mock_import.assert_called_once_with(
"crewai.rag.embeddings.providers.ibm.watson.WatsonProvider" "crewai.rag.embeddings.providers.ibm.watsonx.WatsonXProvider"
) )
def test_build_embedder_unknown_provider(self): def test_build_embedder_unknown_provider(self):