mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 00:58:30 +00:00
- prefix provider env vars with embeddings_
- rename watson → watsonx in providers
- add deprecation warning and alias for legacy 'watson' key (to be removed in v1.0.0)
143 lines
4.8 KiB
Python
143 lines
4.8 KiB
Python
"""IBM WatsonX embeddings provider."""
|
|
|
|
from typing import Any
|
|
|
|
from pydantic import Field, model_validator
|
|
from typing_extensions import Self
|
|
|
|
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
|
from crewai.rag.embeddings.providers.ibm.embedding_callable import (
|
|
WatsonXEmbeddingFunction,
|
|
)
|
|
|
|
|
|
class WatsonXProvider(BaseEmbeddingsProvider[WatsonXEmbeddingFunction]):
    """Configuration model for IBM WatsonX embeddings.

    Every attribute below is a declared setting, and most of them declare
    an ``EMBEDDINGS_WATSONX_*`` validation alias. ``model_id``, ``url`` and
    ``api_key`` have no default, and at least one of ``space_id`` /
    ``project_id`` must be set (checked once field validation is done).

    Note: WatsonX uses a different interface, so this provider relies on
    the custom ``WatsonXEmbeddingFunction`` implementation.
    """

    # --- Embedding function class -----------------------------------------
    embedding_callable: type[WatsonXEmbeddingFunction] = Field(
        description="WatsonX embedding function class",
        default=WatsonXEmbeddingFunction,
    )

    # --- Model selection and extra parameters -----------------------------
    model_id: str = Field(
        validation_alias="EMBEDDINGS_WATSONX_MODEL_ID",
        description="WatsonX model ID",
    )
    params: dict[str, str | dict[str, str]] | None = Field(
        description="Additional parameters",
        default=None,
    )

    # --- Credentials object, scope identifiers and client -----------------
    credentials: Any | None = Field(description="WatsonX credentials", default=None)
    project_id: str | None = Field(
        validation_alias="EMBEDDINGS_WATSONX_PROJECT_ID",
        description="WatsonX project ID",
        default=None,
    )
    space_id: str | None = Field(
        validation_alias="EMBEDDINGS_WATSONX_SPACE_ID",
        description="WatsonX space ID",
        default=None,
    )
    api_client: Any | None = Field(description="WatsonX API client", default=None)

    # --- Connection behaviour ---------------------------------------------
    verify: bool | str | None = Field(
        validation_alias="EMBEDDINGS_WATSONX_VERIFY",
        description="SSL verification",
        default=None,
    )
    persistent_connection: bool = Field(
        validation_alias="EMBEDDINGS_WATSONX_PERSISTENT_CONNECTION",
        description="Use persistent connection",
        default=True,
    )

    # --- Batching and concurrency -----------------------------------------
    batch_size: int = Field(
        validation_alias="EMBEDDINGS_WATSONX_BATCH_SIZE",
        description="Batch size for processing",
        default=100,
    )
    concurrency_limit: int = Field(
        validation_alias="EMBEDDINGS_WATSONX_CONCURRENCY_LIMIT",
        description="Concurrency limit",
        default=10,
    )

    # --- Retry settings ---------------------------------------------------
    max_retries: int | None = Field(
        validation_alias="EMBEDDINGS_WATSONX_MAX_RETRIES",
        description="Maximum retries",
        default=None,
    )
    delay_time: float | None = Field(
        validation_alias="EMBEDDINGS_WATSONX_DELAY_TIME",
        description="Delay time between retries",
        default=None,
    )
    retry_status_codes: list[int] | None = Field(
        description="HTTP status codes to retry on",
        default=None,
    )

    # --- Service endpoint and API key (both required) ---------------------
    url: str = Field(
        validation_alias="EMBEDDINGS_WATSONX_URL",
        description="WatsonX API URL",
    )
    api_key: str = Field(
        validation_alias="EMBEDDINGS_WATSONX_API_KEY",
        description="WatsonX API key",
    )

    # --- Optional identity, token and login values ------------------------
    name: str | None = Field(
        validation_alias="EMBEDDINGS_WATSONX_NAME",
        description="Service name",
        default=None,
    )
    iam_serviceid_crn: str | None = Field(
        validation_alias="EMBEDDINGS_WATSONX_IAM_SERVICEID_CRN",
        description="IAM service ID CRN",
        default=None,
    )
    trusted_profile_id: str | None = Field(
        validation_alias="EMBEDDINGS_WATSONX_TRUSTED_PROFILE_ID",
        description="Trusted profile ID",
        default=None,
    )
    token: str | None = Field(
        validation_alias="EMBEDDINGS_WATSONX_TOKEN",
        description="Bearer token",
        default=None,
    )
    projects_token: str | None = Field(
        validation_alias="EMBEDDINGS_WATSONX_PROJECTS_TOKEN",
        description="Projects token",
        default=None,
    )
    username: str | None = Field(
        validation_alias="EMBEDDINGS_WATSONX_USERNAME",
        description="Username",
        default=None,
    )
    password: str | None = Field(
        validation_alias="EMBEDDINGS_WATSONX_PASSWORD",
        description="Password",
        default=None,
    )

    # --- Instance, API version and additional URLs ------------------------
    instance_id: str | None = Field(
        validation_alias="EMBEDDINGS_WATSONX_INSTANCE_ID",
        description="Service instance ID",
        default=None,
    )
    version: str | None = Field(
        validation_alias="EMBEDDINGS_WATSONX_VERSION",
        description="API version",
        default=None,
    )
    bedrock_url: str | None = Field(
        validation_alias="EMBEDDINGS_WATSONX_BEDROCK_URL",
        description="Bedrock URL",
        default=None,
    )
    platform_url: str | None = Field(
        validation_alias="EMBEDDINGS_WATSONX_PLATFORM_URL",
        description="Platform URL",
        default=None,
    )
    proxies: dict | None = Field(description="Proxy configuration", default=None)

    @model_validator(mode="after")
    def validate_space_or_project(self) -> Self:
        """Require a scope: ``space_id`` or ``project_id`` must be set.

        Runs after field validation. A value counts as set only when it is
        truthy, so an empty string is treated the same as ``None``.

        Returns:
            This instance, unchanged, when at least one identifier is set.

        Raises:
            ValueError: If both ``space_id`` and ``project_id`` are falsy.
        """
        # Guard clause: either identifier is enough to pass validation.
        if self.space_id or self.project_id:
            return self
        raise ValueError("One of 'space_id' or 'project_id' must be provided")
|