mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 00:58:30 +00:00
fix: rename watson to watsonx embedding provider and prefix env vars
- prefix provider env vars with embeddings_ - rename watson → watsonx in providers - add deprecation warning and alias for legacy 'watson' key (to be removed in v1.0.0)
This commit is contained in:
@@ -2,8 +2,11 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import warnings
|
||||||
from typing import TYPE_CHECKING, TypeVar, overload
|
from typing import TYPE_CHECKING, TypeVar, overload
|
||||||
|
|
||||||
|
from typing_extensions import deprecated
|
||||||
|
|
||||||
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
||||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||||
from crewai.utilities.import_utils import import_and_validate_definition
|
from crewai.utilities.import_utils import import_and_validate_definition
|
||||||
@@ -59,9 +62,12 @@ if TYPE_CHECKING:
|
|||||||
HuggingFaceProviderSpec,
|
HuggingFaceProviderSpec,
|
||||||
)
|
)
|
||||||
from crewai.rag.embeddings.providers.ibm.embedding_callable import (
|
from crewai.rag.embeddings.providers.ibm.embedding_callable import (
|
||||||
WatsonEmbeddingFunction,
|
WatsonXEmbeddingFunction,
|
||||||
|
)
|
||||||
|
from crewai.rag.embeddings.providers.ibm.types import (
|
||||||
|
WatsonProviderSpec,
|
||||||
|
WatsonXProviderSpec,
|
||||||
)
|
)
|
||||||
from crewai.rag.embeddings.providers.ibm.types import WatsonProviderSpec
|
|
||||||
from crewai.rag.embeddings.providers.instructor.types import InstructorProviderSpec
|
from crewai.rag.embeddings.providers.instructor.types import InstructorProviderSpec
|
||||||
from crewai.rag.embeddings.providers.jina.types import JinaProviderSpec
|
from crewai.rag.embeddings.providers.jina.types import JinaProviderSpec
|
||||||
from crewai.rag.embeddings.providers.microsoft.types import AzureProviderSpec
|
from crewai.rag.embeddings.providers.microsoft.types import AzureProviderSpec
|
||||||
@@ -100,7 +106,8 @@ PROVIDER_PATHS = {
|
|||||||
"sentence-transformer": "crewai.rag.embeddings.providers.sentence_transformer.sentence_transformer_provider.SentenceTransformerProvider",
|
"sentence-transformer": "crewai.rag.embeddings.providers.sentence_transformer.sentence_transformer_provider.SentenceTransformerProvider",
|
||||||
"text2vec": "crewai.rag.embeddings.providers.text2vec.text2vec_provider.Text2VecProvider",
|
"text2vec": "crewai.rag.embeddings.providers.text2vec.text2vec_provider.Text2VecProvider",
|
||||||
"voyageai": "crewai.rag.embeddings.providers.voyageai.voyageai_provider.VoyageAIProvider",
|
"voyageai": "crewai.rag.embeddings.providers.voyageai.voyageai_provider.VoyageAIProvider",
|
||||||
"watson": "crewai.rag.embeddings.providers.ibm.watson.WatsonProvider",
|
"watson": "crewai.rag.embeddings.providers.ibm.watsonx.WatsonXProvider", # Deprecated alias
|
||||||
|
"watsonx": "crewai.rag.embeddings.providers.ibm.watsonx.WatsonXProvider",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -169,7 +176,14 @@ def build_embedder_from_dict(
|
|||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def build_embedder_from_dict(spec: WatsonProviderSpec) -> WatsonEmbeddingFunction: ...
|
def build_embedder_from_dict(spec: WatsonXProviderSpec) -> WatsonXEmbeddingFunction: ...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
@deprecated(
|
||||||
|
'The "WatsonProviderSpec" provider spec is deprecated and will be removed in v1.0.0. Use "WatsonXProviderSpec" instead.'
|
||||||
|
)
|
||||||
|
def build_embedder_from_dict(spec: WatsonProviderSpec) -> WatsonXEmbeddingFunction: ...
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
@@ -233,6 +247,14 @@ def build_embedder_from_dict(spec):
|
|||||||
if not provider_name:
|
if not provider_name:
|
||||||
raise ValueError("Missing 'provider' key in specification")
|
raise ValueError("Missing 'provider' key in specification")
|
||||||
|
|
||||||
|
if provider_name == "watson":
|
||||||
|
warnings.warn(
|
||||||
|
'The "watson" provider key is deprecated and will be removed in v1.0.0. '
|
||||||
|
'Use "watsonx" instead.',
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
|
||||||
if provider_name not in PROVIDER_PATHS:
|
if provider_name not in PROVIDER_PATHS:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unknown provider: {provider_name}. Available providers: {list(PROVIDER_PATHS.keys())}"
|
f"Unknown provider: {provider_name}. Available providers: {list(PROVIDER_PATHS.keys())}"
|
||||||
@@ -300,7 +322,14 @@ def build_embedder(spec: VoyageAIProviderSpec) -> VoyageAIEmbeddingFunction: ...
|
|||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def build_embedder(spec: WatsonProviderSpec) -> WatsonEmbeddingFunction: ...
|
def build_embedder(spec: WatsonXProviderSpec) -> WatsonXEmbeddingFunction: ...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
@deprecated(
|
||||||
|
'The "WatsonProviderSpec" provider spec is deprecated and will be removed in v1.0.0. Use "WatsonXProviderSpec" instead.'
|
||||||
|
)
|
||||||
|
def build_embedder(spec: WatsonProviderSpec) -> WatsonXEmbeddingFunction: ...
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
|
|||||||
@@ -1,15 +1,17 @@
|
|||||||
"""IBM embedding providers."""
|
"""IBM embedding providers."""
|
||||||
|
|
||||||
from crewai.rag.embeddings.providers.ibm.types import (
|
from crewai.rag.embeddings.providers.ibm.types import (
|
||||||
WatsonProviderConfig,
|
|
||||||
WatsonProviderSpec,
|
WatsonProviderSpec,
|
||||||
|
WatsonXProviderConfig,
|
||||||
|
WatsonXProviderSpec,
|
||||||
)
|
)
|
||||||
from crewai.rag.embeddings.providers.ibm.watson import (
|
from crewai.rag.embeddings.providers.ibm.watsonx import (
|
||||||
WatsonProvider,
|
WatsonXProvider,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"WatsonProvider",
|
|
||||||
"WatsonProviderConfig",
|
|
||||||
"WatsonProviderSpec",
|
"WatsonProviderSpec",
|
||||||
|
"WatsonXProvider",
|
||||||
|
"WatsonXProviderConfig",
|
||||||
|
"WatsonXProviderSpec",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
"""IBM Watson embedding function implementation."""
|
"""IBM WatsonX embedding function implementation."""
|
||||||
|
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
@@ -6,17 +6,17 @@ from typing_extensions import Unpack
|
|||||||
|
|
||||||
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
||||||
from crewai.rag.core.types import Documents, Embeddings
|
from crewai.rag.core.types import Documents, Embeddings
|
||||||
from crewai.rag.embeddings.providers.ibm.types import WatsonProviderConfig
|
from crewai.rag.embeddings.providers.ibm.types import WatsonXProviderConfig
|
||||||
|
|
||||||
|
|
||||||
class WatsonEmbeddingFunction(EmbeddingFunction[Documents]):
|
class WatsonXEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||||
"""Embedding function for IBM Watson models."""
|
"""Embedding function for IBM WatsonX models."""
|
||||||
|
|
||||||
def __init__(self, **kwargs: Unpack[WatsonProviderConfig]) -> None:
|
def __init__(self, **kwargs: Unpack[WatsonXProviderConfig]) -> None:
|
||||||
"""Initialize Watson embedding function.
|
"""Initialize WatsonX embedding function.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
**kwargs: Configuration parameters for Watson Embeddings and Credentials.
|
**kwargs: Configuration parameters for WatsonX Embeddings and Credentials.
|
||||||
"""
|
"""
|
||||||
self._config = kwargs
|
self._config = kwargs
|
||||||
|
|
||||||
@@ -40,7 +40,7 @@ class WatsonEmbeddingFunction(EmbeddingFunction[Documents]):
|
|||||||
|
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"ibm-watsonx-ai is required for watson embeddings. "
|
"ibm-watsonx-ai is required for watsonx embeddings. "
|
||||||
"Install it with: uv add ibm-watsonx-ai"
|
"Install it with: uv add ibm-watsonx-ai"
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
@@ -150,5 +150,5 @@ class WatsonEmbeddingFunction(EmbeddingFunction[Documents]):
|
|||||||
embeddings = embedding.embed_documents(input)
|
embeddings = embedding.embed_documents(input)
|
||||||
return cast(Embeddings, embeddings)
|
return cast(Embeddings, embeddings)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error during Watson embedding: {e}")
|
print(f"Error during WatsonX embedding: {e}")
|
||||||
raise
|
raise
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
"""Type definitions for IBM Watson embedding providers."""
|
"""Type definitions for IBM WatsonX embedding providers."""
|
||||||
|
|
||||||
from typing import Annotated, Any, Literal
|
from typing import Annotated, Any, Literal
|
||||||
|
|
||||||
from typing_extensions import Required, TypedDict
|
from typing_extensions import Required, TypedDict, deprecated
|
||||||
|
|
||||||
|
|
||||||
class WatsonProviderConfig(TypedDict, total=False):
|
class WatsonXProviderConfig(TypedDict, total=False):
|
||||||
"""Configuration for Watson provider."""
|
"""Configuration for WatsonX provider."""
|
||||||
|
|
||||||
model_id: str
|
model_id: str
|
||||||
url: str
|
url: str
|
||||||
@@ -37,8 +37,22 @@ class WatsonProviderConfig(TypedDict, total=False):
|
|||||||
proxies: dict
|
proxies: dict
|
||||||
|
|
||||||
|
|
||||||
|
class WatsonXProviderSpec(TypedDict, total=False):
|
||||||
|
"""WatsonX provider specification."""
|
||||||
|
|
||||||
|
provider: Required[Literal["watsonx"]]
|
||||||
|
config: WatsonXProviderConfig
|
||||||
|
|
||||||
|
|
||||||
|
@deprecated(
|
||||||
|
'The "WatsonProviderSpec" provider spec is deprecated and will be removed in v1.0.0. Use "WatsonXProviderSpec" instead.'
|
||||||
|
)
|
||||||
class WatsonProviderSpec(TypedDict, total=False):
|
class WatsonProviderSpec(TypedDict, total=False):
|
||||||
"""Watson provider specification."""
|
"""Watson provider specification (deprecated).
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- This is deprecated. Use WatsonXProviderSpec with provider="watsonx" instead.
|
||||||
|
"""
|
||||||
|
|
||||||
provider: Required[Literal["watson"]]
|
provider: Required[Literal["watson"]]
|
||||||
config: WatsonProviderConfig
|
config: WatsonXProviderConfig
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
"""IBM Watson embeddings provider."""
|
"""IBM WatsonX embeddings provider."""
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -7,130 +7,130 @@ from typing_extensions import Self
|
|||||||
|
|
||||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||||
from crewai.rag.embeddings.providers.ibm.embedding_callable import (
|
from crewai.rag.embeddings.providers.ibm.embedding_callable import (
|
||||||
WatsonEmbeddingFunction,
|
WatsonXEmbeddingFunction,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class WatsonProvider(BaseEmbeddingsProvider[WatsonEmbeddingFunction]):
|
class WatsonXProvider(BaseEmbeddingsProvider[WatsonXEmbeddingFunction]):
|
||||||
"""IBM Watson embeddings provider.
|
"""IBM WatsonX embeddings provider.
|
||||||
|
|
||||||
Note: Requires custom implementation as Watson uses a different interface.
|
Note: Requires custom implementation as WatsonX uses a different interface.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
embedding_callable: type[WatsonEmbeddingFunction] = Field(
|
embedding_callable: type[WatsonXEmbeddingFunction] = Field(
|
||||||
default=WatsonEmbeddingFunction, description="Watson embedding function class"
|
default=WatsonXEmbeddingFunction, description="WatsonX embedding function class"
|
||||||
)
|
)
|
||||||
model_id: str = Field(
|
model_id: str = Field(
|
||||||
description="Watson model ID", validation_alias="EMBEDDINGS_WATSON_MODEL_ID"
|
description="WatsonX model ID", validation_alias="EMBEDDINGS_WATSONX_MODEL_ID"
|
||||||
)
|
)
|
||||||
params: dict[str, str | dict[str, str]] | None = Field(
|
params: dict[str, str | dict[str, str]] | None = Field(
|
||||||
default=None, description="Additional parameters"
|
default=None, description="Additional parameters"
|
||||||
)
|
)
|
||||||
credentials: Any | None = Field(default=None, description="Watson credentials")
|
credentials: Any | None = Field(default=None, description="WatsonX credentials")
|
||||||
project_id: str | None = Field(
|
project_id: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Watson project ID",
|
description="WatsonX project ID",
|
||||||
validation_alias="EMBEDDINGS_WATSON_PROJECT_ID",
|
validation_alias="EMBEDDINGS_WATSONX_PROJECT_ID",
|
||||||
)
|
)
|
||||||
space_id: str | None = Field(
|
space_id: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Watson space ID",
|
description="WatsonX space ID",
|
||||||
validation_alias="EMBEDDINGS_WATSON_SPACE_ID",
|
validation_alias="EMBEDDINGS_WATSONX_SPACE_ID",
|
||||||
)
|
)
|
||||||
api_client: Any | None = Field(default=None, description="Watson API client")
|
api_client: Any | None = Field(default=None, description="WatsonX API client")
|
||||||
verify: bool | str | None = Field(
|
verify: bool | str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="SSL verification",
|
description="SSL verification",
|
||||||
validation_alias="EMBEDDINGS_WATSON_VERIFY",
|
validation_alias="EMBEDDINGS_WATSONX_VERIFY",
|
||||||
)
|
)
|
||||||
persistent_connection: bool = Field(
|
persistent_connection: bool = Field(
|
||||||
default=True,
|
default=True,
|
||||||
description="Use persistent connection",
|
description="Use persistent connection",
|
||||||
validation_alias="EMBEDDINGS_WATSON_PERSISTENT_CONNECTION",
|
validation_alias="EMBEDDINGS_WATSONX_PERSISTENT_CONNECTION",
|
||||||
)
|
)
|
||||||
batch_size: int = Field(
|
batch_size: int = Field(
|
||||||
default=100,
|
default=100,
|
||||||
description="Batch size for processing",
|
description="Batch size for processing",
|
||||||
validation_alias="EMBEDDINGS_WATSON_BATCH_SIZE",
|
validation_alias="EMBEDDINGS_WATSONX_BATCH_SIZE",
|
||||||
)
|
)
|
||||||
concurrency_limit: int = Field(
|
concurrency_limit: int = Field(
|
||||||
default=10,
|
default=10,
|
||||||
description="Concurrency limit",
|
description="Concurrency limit",
|
||||||
validation_alias="EMBEDDINGS_WATSON_CONCURRENCY_LIMIT",
|
validation_alias="EMBEDDINGS_WATSONX_CONCURRENCY_LIMIT",
|
||||||
)
|
)
|
||||||
max_retries: int | None = Field(
|
max_retries: int | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Maximum retries",
|
description="Maximum retries",
|
||||||
validation_alias="EMBEDDINGS_WATSON_MAX_RETRIES",
|
validation_alias="EMBEDDINGS_WATSONX_MAX_RETRIES",
|
||||||
)
|
)
|
||||||
delay_time: float | None = Field(
|
delay_time: float | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Delay time between retries",
|
description="Delay time between retries",
|
||||||
validation_alias="EMBEDDINGS_WATSON_DELAY_TIME",
|
validation_alias="EMBEDDINGS_WATSONX_DELAY_TIME",
|
||||||
)
|
)
|
||||||
retry_status_codes: list[int] | None = Field(
|
retry_status_codes: list[int] | None = Field(
|
||||||
default=None, description="HTTP status codes to retry on"
|
default=None, description="HTTP status codes to retry on"
|
||||||
)
|
)
|
||||||
url: str = Field(
|
url: str = Field(
|
||||||
description="Watson API URL", validation_alias="EMBEDDINGS_WATSON_URL"
|
description="WatsonX API URL", validation_alias="EMBEDDINGS_WATSONX_URL"
|
||||||
)
|
)
|
||||||
api_key: str = Field(
|
api_key: str = Field(
|
||||||
description="Watson API key", validation_alias="EMBEDDINGS_WATSON_API_KEY"
|
description="WatsonX API key", validation_alias="EMBEDDINGS_WATSONX_API_KEY"
|
||||||
)
|
)
|
||||||
name: str | None = Field(
|
name: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Service name",
|
description="Service name",
|
||||||
validation_alias="EMBEDDINGS_WATSON_NAME",
|
validation_alias="EMBEDDINGS_WATSONX_NAME",
|
||||||
)
|
)
|
||||||
iam_serviceid_crn: str | None = Field(
|
iam_serviceid_crn: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="IAM service ID CRN",
|
description="IAM service ID CRN",
|
||||||
validation_alias="EMBEDDINGS_WATSON_IAM_SERVICEID_CRN",
|
validation_alias="EMBEDDINGS_WATSONX_IAM_SERVICEID_CRN",
|
||||||
)
|
)
|
||||||
trusted_profile_id: str | None = Field(
|
trusted_profile_id: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Trusted profile ID",
|
description="Trusted profile ID",
|
||||||
validation_alias="EMBEDDINGS_WATSON_TRUSTED_PROFILE_ID",
|
validation_alias="EMBEDDINGS_WATSONX_TRUSTED_PROFILE_ID",
|
||||||
)
|
)
|
||||||
token: str | None = Field(
|
token: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Bearer token",
|
description="Bearer token",
|
||||||
validation_alias="EMBEDDINGS_WATSON_TOKEN",
|
validation_alias="EMBEDDINGS_WATSONX_TOKEN",
|
||||||
)
|
)
|
||||||
projects_token: str | None = Field(
|
projects_token: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Projects token",
|
description="Projects token",
|
||||||
validation_alias="EMBEDDINGS_WATSON_PROJECTS_TOKEN",
|
validation_alias="EMBEDDINGS_WATSONX_PROJECTS_TOKEN",
|
||||||
)
|
)
|
||||||
username: str | None = Field(
|
username: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Username",
|
description="Username",
|
||||||
validation_alias="EMBEDDINGS_WATSON_USERNAME",
|
validation_alias="EMBEDDINGS_WATSONX_USERNAME",
|
||||||
)
|
)
|
||||||
password: str | None = Field(
|
password: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Password",
|
description="Password",
|
||||||
validation_alias="EMBEDDINGS_WATSON_PASSWORD",
|
validation_alias="EMBEDDINGS_WATSONX_PASSWORD",
|
||||||
)
|
)
|
||||||
instance_id: str | None = Field(
|
instance_id: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Service instance ID",
|
description="Service instance ID",
|
||||||
validation_alias="EMBEDDINGS_WATSON_INSTANCE_ID",
|
validation_alias="EMBEDDINGS_WATSONX_INSTANCE_ID",
|
||||||
)
|
)
|
||||||
version: str | None = Field(
|
version: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="API version",
|
description="API version",
|
||||||
validation_alias="EMBEDDINGS_WATSON_VERSION",
|
validation_alias="EMBEDDINGS_WATSONX_VERSION",
|
||||||
)
|
)
|
||||||
bedrock_url: str | None = Field(
|
bedrock_url: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Bedrock URL",
|
description="Bedrock URL",
|
||||||
validation_alias="EMBEDDINGS_WATSON_BEDROCK_URL",
|
validation_alias="EMBEDDINGS_WATSONX_BEDROCK_URL",
|
||||||
)
|
)
|
||||||
platform_url: str | None = Field(
|
platform_url: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Platform URL",
|
description="Platform URL",
|
||||||
validation_alias="EMBEDDINGS_WATSON_PLATFORM_URL",
|
validation_alias="EMBEDDINGS_WATSONX_PLATFORM_URL",
|
||||||
)
|
)
|
||||||
proxies: dict | None = Field(default=None, description="Proxy configuration")
|
proxies: dict | None = Field(default=None, description="Proxy configuration")
|
||||||
|
|
||||||
@@ -11,7 +11,7 @@ from crewai.rag.embeddings.providers.google.types import (
|
|||||||
VertexAIProviderSpec,
|
VertexAIProviderSpec,
|
||||||
)
|
)
|
||||||
from crewai.rag.embeddings.providers.huggingface.types import HuggingFaceProviderSpec
|
from crewai.rag.embeddings.providers.huggingface.types import HuggingFaceProviderSpec
|
||||||
from crewai.rag.embeddings.providers.ibm.types import WatsonProviderSpec
|
from crewai.rag.embeddings.providers.ibm.types import WatsonXProviderSpec
|
||||||
from crewai.rag.embeddings.providers.instructor.types import InstructorProviderSpec
|
from crewai.rag.embeddings.providers.instructor.types import InstructorProviderSpec
|
||||||
from crewai.rag.embeddings.providers.jina.types import JinaProviderSpec
|
from crewai.rag.embeddings.providers.jina.types import JinaProviderSpec
|
||||||
from crewai.rag.embeddings.providers.microsoft.types import AzureProviderSpec
|
from crewai.rag.embeddings.providers.microsoft.types import AzureProviderSpec
|
||||||
@@ -44,7 +44,7 @@ ProviderSpec = (
|
|||||||
| Text2VecProviderSpec
|
| Text2VecProviderSpec
|
||||||
| VertexAIProviderSpec
|
| VertexAIProviderSpec
|
||||||
| VoyageAIProviderSpec
|
| VoyageAIProviderSpec
|
||||||
| WatsonProviderSpec
|
| WatsonXProviderSpec
|
||||||
)
|
)
|
||||||
|
|
||||||
AllowedEmbeddingProviders = Literal[
|
AllowedEmbeddingProviders = Literal[
|
||||||
@@ -65,7 +65,8 @@ AllowedEmbeddingProviders = Literal[
|
|||||||
"sentence-transformer",
|
"sentence-transformer",
|
||||||
"text2vec",
|
"text2vec",
|
||||||
"voyageai",
|
"voyageai",
|
||||||
"watson",
|
"watsonx",
|
||||||
|
"watson", # for backward compatibility until v1.0.0
|
||||||
]
|
]
|
||||||
|
|
||||||
EmbedderConfig: TypeAlias = (
|
EmbedderConfig: TypeAlias = (
|
||||||
|
|||||||
@@ -150,8 +150,8 @@ class TestEmbeddingFactory:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
|
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
|
||||||
def test_build_embedder_watson(self, mock_import):
|
def test_build_embedder_watsonx(self, mock_import):
|
||||||
"""Test building Watson embedder."""
|
"""Test building WatsonX embedder."""
|
||||||
mock_provider_class = MagicMock()
|
mock_provider_class = MagicMock()
|
||||||
mock_provider_instance = MagicMock()
|
mock_provider_instance = MagicMock()
|
||||||
mock_embedding_function = MagicMock()
|
mock_embedding_function = MagicMock()
|
||||||
@@ -161,10 +161,10 @@ class TestEmbeddingFactory:
|
|||||||
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
|
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
"provider": "watson",
|
"provider": "watsonx",
|
||||||
"config": {
|
"config": {
|
||||||
"model_id": "ibm/slate-125m-english-rtrvr",
|
"model_id": "ibm/slate-125m-english-rtrvr",
|
||||||
"api_key": "watson-key",
|
"api_key": "watsonx-key",
|
||||||
"url": "https://us-south.ml.cloud.ibm.com",
|
"url": "https://us-south.ml.cloud.ibm.com",
|
||||||
"project_id": "test-project",
|
"project_id": "test-project",
|
||||||
},
|
},
|
||||||
@@ -173,7 +173,7 @@ class TestEmbeddingFactory:
|
|||||||
build_embedder(config)
|
build_embedder(config)
|
||||||
|
|
||||||
mock_import.assert_called_once_with(
|
mock_import.assert_called_once_with(
|
||||||
"crewai.rag.embeddings.providers.ibm.watson.WatsonProvider"
|
"crewai.rag.embeddings.providers.ibm.watsonx.WatsonXProvider"
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_build_embedder_unknown_provider(self):
|
def test_build_embedder_unknown_provider(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user