fix: rename watson to watsonx embedding provider and prefix env vars

- prefix provider env vars with embeddings_  
- rename watson → watsonx in providers  
- add deprecation warning and alias for legacy 'watson' key (to be removed in v1.0.0)
This commit is contained in:
Greyson LaLonde
2025-09-26 10:57:18 -04:00
committed by GitHub
parent 091d1267d8
commit 12fa7e2ff1
7 changed files with 112 additions and 66 deletions

View File

@@ -2,8 +2,11 @@
from __future__ import annotations
import warnings
from typing import TYPE_CHECKING, TypeVar, overload
from typing_extensions import deprecated
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
from crewai.utilities.import_utils import import_and_validate_definition
@@ -59,9 +62,12 @@ if TYPE_CHECKING:
HuggingFaceProviderSpec,
)
from crewai.rag.embeddings.providers.ibm.embedding_callable import (
WatsonEmbeddingFunction,
WatsonXEmbeddingFunction,
)
from crewai.rag.embeddings.providers.ibm.types import (
WatsonProviderSpec,
WatsonXProviderSpec,
)
from crewai.rag.embeddings.providers.ibm.types import WatsonProviderSpec
from crewai.rag.embeddings.providers.instructor.types import InstructorProviderSpec
from crewai.rag.embeddings.providers.jina.types import JinaProviderSpec
from crewai.rag.embeddings.providers.microsoft.types import AzureProviderSpec
@@ -100,7 +106,8 @@ PROVIDER_PATHS = {
"sentence-transformer": "crewai.rag.embeddings.providers.sentence_transformer.sentence_transformer_provider.SentenceTransformerProvider",
"text2vec": "crewai.rag.embeddings.providers.text2vec.text2vec_provider.Text2VecProvider",
"voyageai": "crewai.rag.embeddings.providers.voyageai.voyageai_provider.VoyageAIProvider",
"watson": "crewai.rag.embeddings.providers.ibm.watson.WatsonProvider",
"watson": "crewai.rag.embeddings.providers.ibm.watsonx.WatsonXProvider", # Deprecated alias
"watsonx": "crewai.rag.embeddings.providers.ibm.watsonx.WatsonXProvider",
}
@@ -169,7 +176,14 @@ def build_embedder_from_dict(
@overload
def build_embedder_from_dict(spec: WatsonProviderSpec) -> WatsonEmbeddingFunction: ...
def build_embedder_from_dict(spec: WatsonXProviderSpec) -> WatsonXEmbeddingFunction: ...
@overload
@deprecated(
'The "WatsonProviderSpec" provider spec is deprecated and will be removed in v1.0.0. Use "WatsonXProviderSpec" instead.'
)
def build_embedder_from_dict(spec: WatsonProviderSpec) -> WatsonXEmbeddingFunction: ...
@overload
@@ -233,6 +247,14 @@ def build_embedder_from_dict(spec):
if not provider_name:
raise ValueError("Missing 'provider' key in specification")
if provider_name == "watson":
warnings.warn(
'The "watson" provider key is deprecated and will be removed in v1.0.0. '
'Use "watsonx" instead.',
DeprecationWarning,
stacklevel=2,
)
if provider_name not in PROVIDER_PATHS:
raise ValueError(
f"Unknown provider: {provider_name}. Available providers: {list(PROVIDER_PATHS.keys())}"
@@ -300,7 +322,14 @@ def build_embedder(spec: VoyageAIProviderSpec) -> VoyageAIEmbeddingFunction: ...
@overload
def build_embedder(spec: WatsonProviderSpec) -> WatsonEmbeddingFunction: ...
def build_embedder(spec: WatsonXProviderSpec) -> WatsonXEmbeddingFunction: ...
@overload
@deprecated(
'The "WatsonProviderSpec" provider spec is deprecated and will be removed in v1.0.0. Use "WatsonXProviderSpec" instead.'
)
def build_embedder(spec: WatsonProviderSpec) -> WatsonXEmbeddingFunction: ...
@overload

View File

@@ -1,15 +1,17 @@
"""IBM embedding providers."""
from crewai.rag.embeddings.providers.ibm.types import (
WatsonProviderConfig,
WatsonProviderSpec,
WatsonXProviderConfig,
WatsonXProviderSpec,
)
from crewai.rag.embeddings.providers.ibm.watson import (
WatsonProvider,
from crewai.rag.embeddings.providers.ibm.watsonx import (
WatsonXProvider,
)
__all__ = [
"WatsonProvider",
"WatsonProviderConfig",
"WatsonProviderSpec",
"WatsonXProvider",
"WatsonXProviderConfig",
"WatsonXProviderSpec",
]

View File

@@ -1,4 +1,4 @@
"""IBM Watson embedding function implementation."""
"""IBM WatsonX embedding function implementation."""
from typing import cast
@@ -6,17 +6,17 @@ from typing_extensions import Unpack
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
from crewai.rag.core.types import Documents, Embeddings
from crewai.rag.embeddings.providers.ibm.types import WatsonProviderConfig
from crewai.rag.embeddings.providers.ibm.types import WatsonXProviderConfig
class WatsonEmbeddingFunction(EmbeddingFunction[Documents]):
"""Embedding function for IBM Watson models."""
class WatsonXEmbeddingFunction(EmbeddingFunction[Documents]):
"""Embedding function for IBM WatsonX models."""
def __init__(self, **kwargs: Unpack[WatsonProviderConfig]) -> None:
"""Initialize Watson embedding function.
def __init__(self, **kwargs: Unpack[WatsonXProviderConfig]) -> None:
"""Initialize WatsonX embedding function.
Args:
**kwargs: Configuration parameters for Watson Embeddings and Credentials.
**kwargs: Configuration parameters for WatsonX Embeddings and Credentials.
"""
self._config = kwargs
@@ -40,7 +40,7 @@ class WatsonEmbeddingFunction(EmbeddingFunction[Documents]):
except ImportError as e:
raise ImportError(
"ibm-watsonx-ai is required for watson embeddings. "
"ibm-watsonx-ai is required for watsonx embeddings. "
"Install it with: uv add ibm-watsonx-ai"
) from e
@@ -150,5 +150,5 @@ class WatsonEmbeddingFunction(EmbeddingFunction[Documents]):
embeddings = embedding.embed_documents(input)
return cast(Embeddings, embeddings)
except Exception as e:
print(f"Error during Watson embedding: {e}")
print(f"Error during WatsonX embedding: {e}")
raise

View File

@@ -1,12 +1,12 @@
"""Type definitions for IBM Watson embedding providers."""
"""Type definitions for IBM WatsonX embedding providers."""
from typing import Annotated, Any, Literal
from typing_extensions import Required, TypedDict
from typing_extensions import Required, TypedDict, deprecated
class WatsonProviderConfig(TypedDict, total=False):
"""Configuration for Watson provider."""
class WatsonXProviderConfig(TypedDict, total=False):
"""Configuration for WatsonX provider."""
model_id: str
url: str
@@ -37,8 +37,22 @@ class WatsonProviderConfig(TypedDict, total=False):
proxies: dict
class WatsonXProviderSpec(TypedDict, total=False):
"""WatsonX provider specification."""
provider: Required[Literal["watsonx"]]
config: WatsonXProviderConfig
@deprecated(
'The "WatsonProviderSpec" provider spec is deprecated and will be removed in v1.0.0. Use "WatsonXProviderSpec" instead.'
)
class WatsonProviderSpec(TypedDict, total=False):
"""Watson provider specification."""
"""Watson provider specification (deprecated).
Notes:
- This is deprecated. Use WatsonXProviderSpec with provider="watsonx" instead.
"""
provider: Required[Literal["watson"]]
config: WatsonProviderConfig
config: WatsonXProviderConfig

View File

@@ -1,4 +1,4 @@
"""IBM Watson embeddings provider."""
"""IBM WatsonX embeddings provider."""
from typing import Any
@@ -7,130 +7,130 @@ from typing_extensions import Self
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
from crewai.rag.embeddings.providers.ibm.embedding_callable import (
WatsonEmbeddingFunction,
WatsonXEmbeddingFunction,
)
class WatsonProvider(BaseEmbeddingsProvider[WatsonEmbeddingFunction]):
"""IBM Watson embeddings provider.
class WatsonXProvider(BaseEmbeddingsProvider[WatsonXEmbeddingFunction]):
"""IBM WatsonX embeddings provider.
Note: Requires custom implementation as Watson uses a different interface.
Note: Requires custom implementation as WatsonX uses a different interface.
"""
embedding_callable: type[WatsonEmbeddingFunction] = Field(
default=WatsonEmbeddingFunction, description="Watson embedding function class"
embedding_callable: type[WatsonXEmbeddingFunction] = Field(
default=WatsonXEmbeddingFunction, description="WatsonX embedding function class"
)
model_id: str = Field(
description="Watson model ID", validation_alias="EMBEDDINGS_WATSON_MODEL_ID"
description="WatsonX model ID", validation_alias="EMBEDDINGS_WATSONX_MODEL_ID"
)
params: dict[str, str | dict[str, str]] | None = Field(
default=None, description="Additional parameters"
)
credentials: Any | None = Field(default=None, description="Watson credentials")
credentials: Any | None = Field(default=None, description="WatsonX credentials")
project_id: str | None = Field(
default=None,
description="Watson project ID",
validation_alias="EMBEDDINGS_WATSON_PROJECT_ID",
description="WatsonX project ID",
validation_alias="EMBEDDINGS_WATSONX_PROJECT_ID",
)
space_id: str | None = Field(
default=None,
description="Watson space ID",
validation_alias="EMBEDDINGS_WATSON_SPACE_ID",
description="WatsonX space ID",
validation_alias="EMBEDDINGS_WATSONX_SPACE_ID",
)
api_client: Any | None = Field(default=None, description="Watson API client")
api_client: Any | None = Field(default=None, description="WatsonX API client")
verify: bool | str | None = Field(
default=None,
description="SSL verification",
validation_alias="EMBEDDINGS_WATSON_VERIFY",
validation_alias="EMBEDDINGS_WATSONX_VERIFY",
)
persistent_connection: bool = Field(
default=True,
description="Use persistent connection",
validation_alias="EMBEDDINGS_WATSON_PERSISTENT_CONNECTION",
validation_alias="EMBEDDINGS_WATSONX_PERSISTENT_CONNECTION",
)
batch_size: int = Field(
default=100,
description="Batch size for processing",
validation_alias="EMBEDDINGS_WATSON_BATCH_SIZE",
validation_alias="EMBEDDINGS_WATSONX_BATCH_SIZE",
)
concurrency_limit: int = Field(
default=10,
description="Concurrency limit",
validation_alias="EMBEDDINGS_WATSON_CONCURRENCY_LIMIT",
validation_alias="EMBEDDINGS_WATSONX_CONCURRENCY_LIMIT",
)
max_retries: int | None = Field(
default=None,
description="Maximum retries",
validation_alias="EMBEDDINGS_WATSON_MAX_RETRIES",
validation_alias="EMBEDDINGS_WATSONX_MAX_RETRIES",
)
delay_time: float | None = Field(
default=None,
description="Delay time between retries",
validation_alias="EMBEDDINGS_WATSON_DELAY_TIME",
validation_alias="EMBEDDINGS_WATSONX_DELAY_TIME",
)
retry_status_codes: list[int] | None = Field(
default=None, description="HTTP status codes to retry on"
)
url: str = Field(
description="Watson API URL", validation_alias="EMBEDDINGS_WATSON_URL"
description="WatsonX API URL", validation_alias="EMBEDDINGS_WATSONX_URL"
)
api_key: str = Field(
description="Watson API key", validation_alias="EMBEDDINGS_WATSON_API_KEY"
description="WatsonX API key", validation_alias="EMBEDDINGS_WATSONX_API_KEY"
)
name: str | None = Field(
default=None,
description="Service name",
validation_alias="EMBEDDINGS_WATSON_NAME",
validation_alias="EMBEDDINGS_WATSONX_NAME",
)
iam_serviceid_crn: str | None = Field(
default=None,
description="IAM service ID CRN",
validation_alias="EMBEDDINGS_WATSON_IAM_SERVICEID_CRN",
validation_alias="EMBEDDINGS_WATSONX_IAM_SERVICEID_CRN",
)
trusted_profile_id: str | None = Field(
default=None,
description="Trusted profile ID",
validation_alias="EMBEDDINGS_WATSON_TRUSTED_PROFILE_ID",
validation_alias="EMBEDDINGS_WATSONX_TRUSTED_PROFILE_ID",
)
token: str | None = Field(
default=None,
description="Bearer token",
validation_alias="EMBEDDINGS_WATSON_TOKEN",
validation_alias="EMBEDDINGS_WATSONX_TOKEN",
)
projects_token: str | None = Field(
default=None,
description="Projects token",
validation_alias="EMBEDDINGS_WATSON_PROJECTS_TOKEN",
validation_alias="EMBEDDINGS_WATSONX_PROJECTS_TOKEN",
)
username: str | None = Field(
default=None,
description="Username",
validation_alias="EMBEDDINGS_WATSON_USERNAME",
validation_alias="EMBEDDINGS_WATSONX_USERNAME",
)
password: str | None = Field(
default=None,
description="Password",
validation_alias="EMBEDDINGS_WATSON_PASSWORD",
validation_alias="EMBEDDINGS_WATSONX_PASSWORD",
)
instance_id: str | None = Field(
default=None,
description="Service instance ID",
validation_alias="EMBEDDINGS_WATSON_INSTANCE_ID",
validation_alias="EMBEDDINGS_WATSONX_INSTANCE_ID",
)
version: str | None = Field(
default=None,
description="API version",
validation_alias="EMBEDDINGS_WATSON_VERSION",
validation_alias="EMBEDDINGS_WATSONX_VERSION",
)
bedrock_url: str | None = Field(
default=None,
description="Bedrock URL",
validation_alias="EMBEDDINGS_WATSON_BEDROCK_URL",
validation_alias="EMBEDDINGS_WATSONX_BEDROCK_URL",
)
platform_url: str | None = Field(
default=None,
description="Platform URL",
validation_alias="EMBEDDINGS_WATSON_PLATFORM_URL",
validation_alias="EMBEDDINGS_WATSONX_PLATFORM_URL",
)
proxies: dict | None = Field(default=None, description="Proxy configuration")

View File

@@ -11,7 +11,7 @@ from crewai.rag.embeddings.providers.google.types import (
VertexAIProviderSpec,
)
from crewai.rag.embeddings.providers.huggingface.types import HuggingFaceProviderSpec
from crewai.rag.embeddings.providers.ibm.types import WatsonProviderSpec
from crewai.rag.embeddings.providers.ibm.types import WatsonXProviderSpec
from crewai.rag.embeddings.providers.instructor.types import InstructorProviderSpec
from crewai.rag.embeddings.providers.jina.types import JinaProviderSpec
from crewai.rag.embeddings.providers.microsoft.types import AzureProviderSpec
@@ -44,7 +44,7 @@ ProviderSpec = (
| Text2VecProviderSpec
| VertexAIProviderSpec
| VoyageAIProviderSpec
| WatsonProviderSpec
| WatsonXProviderSpec
)
AllowedEmbeddingProviders = Literal[
@@ -65,7 +65,8 @@ AllowedEmbeddingProviders = Literal[
"sentence-transformer",
"text2vec",
"voyageai",
"watson",
"watsonx",
"watson", # for backward compatibility until v1.0.0
]
EmbedderConfig: TypeAlias = (