Files
crewAI/src/crewai/rag/embeddings/providers/ibm/watsonx.py
Greyson LaLonde 12fa7e2ff1 fix: rename watson to watsonx embedding provider and prefix env vars
- prefix provider env vars with embeddings_  
- rename watson → watsonx in providers  
- add deprecation warning and alias for legacy 'watson' key (to be removed in v1.0.0)
2025-09-26 10:57:18 -04:00

143 lines
4.8 KiB
Python

"""IBM WatsonX embeddings provider."""
from typing import Any
from pydantic import Field, model_validator
from typing_extensions import Self
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
from crewai.rag.embeddings.providers.ibm.embedding_callable import (
WatsonXEmbeddingFunction,
)
class WatsonXProvider(BaseEmbeddingsProvider[WatsonXEmbeddingFunction]):
"""IBM WatsonX embeddings provider.
Note: Requires custom implementation as WatsonX uses a different interface.
"""
embedding_callable: type[WatsonXEmbeddingFunction] = Field(
default=WatsonXEmbeddingFunction, description="WatsonX embedding function class"
)
model_id: str = Field(
description="WatsonX model ID", validation_alias="EMBEDDINGS_WATSONX_MODEL_ID"
)
params: dict[str, str | dict[str, str]] | None = Field(
default=None, description="Additional parameters"
)
credentials: Any | None = Field(default=None, description="WatsonX credentials")
project_id: str | None = Field(
default=None,
description="WatsonX project ID",
validation_alias="EMBEDDINGS_WATSONX_PROJECT_ID",
)
space_id: str | None = Field(
default=None,
description="WatsonX space ID",
validation_alias="EMBEDDINGS_WATSONX_SPACE_ID",
)
api_client: Any | None = Field(default=None, description="WatsonX API client")
verify: bool | str | None = Field(
default=None,
description="SSL verification",
validation_alias="EMBEDDINGS_WATSONX_VERIFY",
)
persistent_connection: bool = Field(
default=True,
description="Use persistent connection",
validation_alias="EMBEDDINGS_WATSONX_PERSISTENT_CONNECTION",
)
batch_size: int = Field(
default=100,
description="Batch size for processing",
validation_alias="EMBEDDINGS_WATSONX_BATCH_SIZE",
)
concurrency_limit: int = Field(
default=10,
description="Concurrency limit",
validation_alias="EMBEDDINGS_WATSONX_CONCURRENCY_LIMIT",
)
max_retries: int | None = Field(
default=None,
description="Maximum retries",
validation_alias="EMBEDDINGS_WATSONX_MAX_RETRIES",
)
delay_time: float | None = Field(
default=None,
description="Delay time between retries",
validation_alias="EMBEDDINGS_WATSONX_DELAY_TIME",
)
retry_status_codes: list[int] | None = Field(
default=None, description="HTTP status codes to retry on"
)
url: str = Field(
description="WatsonX API URL", validation_alias="EMBEDDINGS_WATSONX_URL"
)
api_key: str = Field(
description="WatsonX API key", validation_alias="EMBEDDINGS_WATSONX_API_KEY"
)
name: str | None = Field(
default=None,
description="Service name",
validation_alias="EMBEDDINGS_WATSONX_NAME",
)
iam_serviceid_crn: str | None = Field(
default=None,
description="IAM service ID CRN",
validation_alias="EMBEDDINGS_WATSONX_IAM_SERVICEID_CRN",
)
trusted_profile_id: str | None = Field(
default=None,
description="Trusted profile ID",
validation_alias="EMBEDDINGS_WATSONX_TRUSTED_PROFILE_ID",
)
token: str | None = Field(
default=None,
description="Bearer token",
validation_alias="EMBEDDINGS_WATSONX_TOKEN",
)
projects_token: str | None = Field(
default=None,
description="Projects token",
validation_alias="EMBEDDINGS_WATSONX_PROJECTS_TOKEN",
)
username: str | None = Field(
default=None,
description="Username",
validation_alias="EMBEDDINGS_WATSONX_USERNAME",
)
password: str | None = Field(
default=None,
description="Password",
validation_alias="EMBEDDINGS_WATSONX_PASSWORD",
)
instance_id: str | None = Field(
default=None,
description="Service instance ID",
validation_alias="EMBEDDINGS_WATSONX_INSTANCE_ID",
)
version: str | None = Field(
default=None,
description="API version",
validation_alias="EMBEDDINGS_WATSONX_VERSION",
)
bedrock_url: str | None = Field(
default=None,
description="Bedrock URL",
validation_alias="EMBEDDINGS_WATSONX_BEDROCK_URL",
)
platform_url: str | None = Field(
default=None,
description="Platform URL",
validation_alias="EMBEDDINGS_WATSONX_PLATFORM_URL",
)
proxies: dict | None = Field(default=None, description="Proxy configuration")
@model_validator(mode="after")
def validate_space_or_project(self) -> Self:
"""Validate that either space_id or project_id is provided."""
if not self.space_id and not self.project_id:
raise ValueError("One of 'space_id' or 'project_id' must be provided")
return self