mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-29 02:38:29 +00:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7d5cd4d3e2 | ||
|
|
73e932bfee | ||
|
|
12fa7e2ff1 | ||
|
|
091d1267d8 |
@@ -40,7 +40,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:
|
||||
|
||||
_suppress_pydantic_deprecation_warnings()
|
||||
|
||||
__version__ = "0.201.0"
|
||||
__version__ = "0.201.1"
|
||||
_telemetry_submitted = False
|
||||
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]>=0.201.0,<1.0.0"
|
||||
"crewai[tools]>=0.201.1,<1.0.0"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]>=0.201.0,<1.0.0",
|
||||
"crewai[tools]>=0.201.1,<1.0.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "Power up your crews with {{folder_name}}"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]>=0.201.0"
|
||||
"crewai[tools]>=0.201.1"
|
||||
]
|
||||
|
||||
[tool.crewai]
|
||||
|
||||
@@ -140,3 +140,10 @@ class EmbeddingFunction(Protocol[D]):
|
||||
return validate_embeddings(normalized)
|
||||
|
||||
cls.__call__ = wrapped_call # type: ignore[method-assign]
|
||||
|
||||
def embed_query(self, input: D) -> Embeddings:
|
||||
"""
|
||||
Get the embeddings for a query input.
|
||||
This method is optional, and if not implemented, the default behavior is to call __call__.
|
||||
"""
|
||||
return self.__call__(input=input)
|
||||
|
||||
@@ -2,8 +2,11 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, TypeVar, overload
|
||||
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.utilities.import_utils import import_and_validate_definition
|
||||
@@ -59,9 +62,12 @@ if TYPE_CHECKING:
|
||||
HuggingFaceProviderSpec,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.ibm.embedding_callable import (
|
||||
WatsonEmbeddingFunction,
|
||||
WatsonXEmbeddingFunction,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.ibm.types import (
|
||||
WatsonProviderSpec,
|
||||
WatsonXProviderSpec,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.ibm.types import WatsonProviderSpec
|
||||
from crewai.rag.embeddings.providers.instructor.types import InstructorProviderSpec
|
||||
from crewai.rag.embeddings.providers.jina.types import JinaProviderSpec
|
||||
from crewai.rag.embeddings.providers.microsoft.types import AzureProviderSpec
|
||||
@@ -100,7 +106,8 @@ PROVIDER_PATHS = {
|
||||
"sentence-transformer": "crewai.rag.embeddings.providers.sentence_transformer.sentence_transformer_provider.SentenceTransformerProvider",
|
||||
"text2vec": "crewai.rag.embeddings.providers.text2vec.text2vec_provider.Text2VecProvider",
|
||||
"voyageai": "crewai.rag.embeddings.providers.voyageai.voyageai_provider.VoyageAIProvider",
|
||||
"watson": "crewai.rag.embeddings.providers.ibm.watson.WatsonProvider",
|
||||
"watson": "crewai.rag.embeddings.providers.ibm.watsonx.WatsonXProvider", # Deprecated alias
|
||||
"watsonx": "crewai.rag.embeddings.providers.ibm.watsonx.WatsonXProvider",
|
||||
}
|
||||
|
||||
|
||||
@@ -169,7 +176,14 @@ def build_embedder_from_dict(
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder_from_dict(spec: WatsonProviderSpec) -> WatsonEmbeddingFunction: ...
|
||||
def build_embedder_from_dict(spec: WatsonXProviderSpec) -> WatsonXEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
@deprecated(
|
||||
'The "WatsonProviderSpec" provider spec is deprecated and will be removed in v1.0.0. Use "WatsonXProviderSpec" instead.'
|
||||
)
|
||||
def build_embedder_from_dict(spec: WatsonProviderSpec) -> WatsonXEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
@@ -233,6 +247,14 @@ def build_embedder_from_dict(spec):
|
||||
if not provider_name:
|
||||
raise ValueError("Missing 'provider' key in specification")
|
||||
|
||||
if provider_name == "watson":
|
||||
warnings.warn(
|
||||
'The "watson" provider key is deprecated and will be removed in v1.0.0. '
|
||||
'Use "watsonx" instead.',
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
if provider_name not in PROVIDER_PATHS:
|
||||
raise ValueError(
|
||||
f"Unknown provider: {provider_name}. Available providers: {list(PROVIDER_PATHS.keys())}"
|
||||
@@ -300,7 +322,14 @@ def build_embedder(spec: VoyageAIProviderSpec) -> VoyageAIEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder(spec: WatsonProviderSpec) -> WatsonEmbeddingFunction: ...
|
||||
def build_embedder(spec: WatsonXProviderSpec) -> WatsonXEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
@deprecated(
|
||||
'The "WatsonProviderSpec" provider spec is deprecated and will be removed in v1.0.0. Use "WatsonXProviderSpec" instead.'
|
||||
)
|
||||
def build_embedder(spec: WatsonProviderSpec) -> WatsonXEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
|
||||
@@ -46,7 +46,7 @@ class BedrockProvider(BaseEmbeddingsProvider[AmazonBedrockEmbeddingFunction]):
|
||||
model_name: str = Field(
|
||||
default="amazon.titan-embed-text-v1",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias="BEDROCK_MODEL_NAME",
|
||||
validation_alias="EMBEDDINGS_BEDROCK_MODEL_NAME",
|
||||
)
|
||||
session: Any = Field(
|
||||
default_factory=create_aws_session, description="AWS session object"
|
||||
|
||||
@@ -15,10 +15,10 @@ class CohereProvider(BaseEmbeddingsProvider[CohereEmbeddingFunction]):
|
||||
default=CohereEmbeddingFunction, description="Cohere embedding function class"
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="Cohere API key", validation_alias="COHERE_API_KEY"
|
||||
description="Cohere API key", validation_alias="EMBEDDINGS_COHERE_API_KEY"
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="large",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias="COHERE_MODEL_NAME",
|
||||
validation_alias="EMBEDDINGS_COHERE_MODEL_NAME",
|
||||
)
|
||||
|
||||
@@ -18,13 +18,13 @@ class GenerativeAiProvider(BaseEmbeddingsProvider[GoogleGenerativeAiEmbeddingFun
|
||||
model_name: str = Field(
|
||||
default="models/embedding-001",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias="GOOGLE_GENERATIVE_AI_MODEL_NAME",
|
||||
validation_alias="EMBEDDINGS_GOOGLE_GENERATIVE_AI_MODEL_NAME",
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="Google API key", validation_alias="GOOGLE_API_KEY"
|
||||
description="Google API key", validation_alias="EMBEDDINGS_GOOGLE_API_KEY"
|
||||
)
|
||||
task_type: str = Field(
|
||||
default="RETRIEVAL_DOCUMENT",
|
||||
description="Task type for embeddings",
|
||||
validation_alias="GOOGLE_GENERATIVE_AI_TASK_TYPE",
|
||||
validation_alias="EMBEDDINGS_GOOGLE_GENERATIVE_AI_TASK_TYPE",
|
||||
)
|
||||
|
||||
@@ -18,18 +18,18 @@ class VertexAIProvider(BaseEmbeddingsProvider[GoogleVertexEmbeddingFunction]):
|
||||
model_name: str = Field(
|
||||
default="textembedding-gecko",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias="GOOGLE_VERTEX_MODEL_NAME",
|
||||
validation_alias="EMBEDDINGS_GOOGLE_VERTEX_MODEL_NAME",
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="Google API key", validation_alias="GOOGLE_CLOUD_API_KEY"
|
||||
description="Google API key", validation_alias="EMBEDDINGS_GOOGLE_CLOUD_API_KEY"
|
||||
)
|
||||
project_id: str = Field(
|
||||
default="cloud-large-language-models",
|
||||
description="GCP project ID",
|
||||
validation_alias="GOOGLE_CLOUD_PROJECT",
|
||||
validation_alias="EMBEDDINGS_GOOGLE_CLOUD_PROJECT",
|
||||
)
|
||||
region: str = Field(
|
||||
default="us-central1",
|
||||
description="GCP region",
|
||||
validation_alias="GOOGLE_CLOUD_REGION",
|
||||
validation_alias="EMBEDDINGS_GOOGLE_CLOUD_REGION",
|
||||
)
|
||||
|
||||
@@ -16,5 +16,5 @@ class HuggingFaceProvider(BaseEmbeddingsProvider[HuggingFaceEmbeddingServer]):
|
||||
description="HuggingFace embedding function class",
|
||||
)
|
||||
url: str = Field(
|
||||
description="HuggingFace API URL", validation_alias="HUGGINGFACE_URL"
|
||||
description="HuggingFace API URL", validation_alias="EMBEDDINGS_HUGGINGFACE_URL"
|
||||
)
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
"""IBM embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.ibm.types import (
|
||||
WatsonProviderConfig,
|
||||
WatsonProviderSpec,
|
||||
WatsonXProviderConfig,
|
||||
WatsonXProviderSpec,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.ibm.watson import (
|
||||
WatsonProvider,
|
||||
from crewai.rag.embeddings.providers.ibm.watsonx import (
|
||||
WatsonXProvider,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"WatsonProvider",
|
||||
"WatsonProviderConfig",
|
||||
"WatsonProviderSpec",
|
||||
"WatsonXProvider",
|
||||
"WatsonXProviderConfig",
|
||||
"WatsonXProviderSpec",
|
||||
]
|
||||
|
||||
@@ -1,25 +1,30 @@
|
||||
"""IBM Watson embedding function implementation."""
|
||||
"""IBM WatsonX embedding function implementation."""
|
||||
|
||||
from typing import cast
|
||||
|
||||
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
||||
from crewai.rag.core.types import Documents, Embeddings
|
||||
from crewai.rag.embeddings.providers.ibm.types import WatsonProviderConfig
|
||||
from crewai.rag.embeddings.providers.ibm.types import WatsonXProviderConfig
|
||||
|
||||
|
||||
class WatsonEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
"""Embedding function for IBM Watson models."""
|
||||
class WatsonXEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
"""Embedding function for IBM WatsonX models."""
|
||||
|
||||
def __init__(self, **kwargs: Unpack[WatsonProviderConfig]) -> None:
|
||||
"""Initialize Watson embedding function.
|
||||
def __init__(self, **kwargs: Unpack[WatsonXProviderConfig]) -> None:
|
||||
"""Initialize WatsonX embedding function.
|
||||
|
||||
Args:
|
||||
**kwargs: Configuration parameters for Watson Embeddings and Credentials.
|
||||
**kwargs: Configuration parameters for WatsonX Embeddings and Credentials.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._config = kwargs
|
||||
|
||||
@staticmethod
|
||||
def name() -> str:
|
||||
"""Return the name of the embedding function for ChromaDB compatibility."""
|
||||
return "watsonx"
|
||||
|
||||
def __call__(self, input: Documents) -> Embeddings:
|
||||
"""Generate embeddings for input documents.
|
||||
|
||||
@@ -40,7 +45,7 @@ class WatsonEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"ibm-watsonx-ai is required for watson embeddings. "
|
||||
"ibm-watsonx-ai is required for watsonx embeddings. "
|
||||
"Install it with: uv add ibm-watsonx-ai"
|
||||
) from e
|
||||
|
||||
@@ -150,5 +155,5 @@ class WatsonEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
embeddings = embedding.embed_documents(input)
|
||||
return cast(Embeddings, embeddings)
|
||||
except Exception as e:
|
||||
print(f"Error during Watson embedding: {e}")
|
||||
print(f"Error during WatsonX embedding: {e}")
|
||||
raise
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
"""Type definitions for IBM Watson embedding providers."""
|
||||
"""Type definitions for IBM WatsonX embedding providers."""
|
||||
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
from typing_extensions import Required, TypedDict, deprecated
|
||||
|
||||
|
||||
class WatsonProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for Watson provider."""
|
||||
class WatsonXProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for WatsonX provider."""
|
||||
|
||||
model_id: str
|
||||
url: str
|
||||
@@ -37,8 +37,22 @@ class WatsonProviderConfig(TypedDict, total=False):
|
||||
proxies: dict
|
||||
|
||||
|
||||
class WatsonXProviderSpec(TypedDict, total=False):
|
||||
"""WatsonX provider specification."""
|
||||
|
||||
provider: Required[Literal["watsonx"]]
|
||||
config: WatsonXProviderConfig
|
||||
|
||||
|
||||
@deprecated(
|
||||
'The "WatsonProviderSpec" provider spec is deprecated and will be removed in v1.0.0. Use "WatsonXProviderSpec" instead.'
|
||||
)
|
||||
class WatsonProviderSpec(TypedDict, total=False):
|
||||
"""Watson provider specification."""
|
||||
"""Watson provider specification (deprecated).
|
||||
|
||||
Notes:
|
||||
- This is deprecated. Use WatsonXProviderSpec with provider="watsonx" instead.
|
||||
"""
|
||||
|
||||
provider: Required[Literal["watson"]]
|
||||
config: WatsonProviderConfig
|
||||
config: WatsonXProviderConfig
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""IBM Watson embeddings provider."""
|
||||
"""IBM WatsonX embeddings provider."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
@@ -7,110 +7,130 @@ from typing_extensions import Self
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.rag.embeddings.providers.ibm.embedding_callable import (
|
||||
WatsonEmbeddingFunction,
|
||||
WatsonXEmbeddingFunction,
|
||||
)
|
||||
|
||||
|
||||
class WatsonProvider(BaseEmbeddingsProvider[WatsonEmbeddingFunction]):
|
||||
"""IBM Watson embeddings provider.
|
||||
class WatsonXProvider(BaseEmbeddingsProvider[WatsonXEmbeddingFunction]):
|
||||
"""IBM WatsonX embeddings provider.
|
||||
|
||||
Note: Requires custom implementation as Watson uses a different interface.
|
||||
Note: Requires custom implementation as WatsonX uses a different interface.
|
||||
"""
|
||||
|
||||
embedding_callable: type[WatsonEmbeddingFunction] = Field(
|
||||
default=WatsonEmbeddingFunction, description="Watson embedding function class"
|
||||
embedding_callable: type[WatsonXEmbeddingFunction] = Field(
|
||||
default=WatsonXEmbeddingFunction, description="WatsonX embedding function class"
|
||||
)
|
||||
model_id: str = Field(
|
||||
description="Watson model ID", validation_alias="WATSON_MODEL_ID"
|
||||
description="WatsonX model ID", validation_alias="EMBEDDINGS_WATSONX_MODEL_ID"
|
||||
)
|
||||
params: dict[str, str | dict[str, str]] | None = Field(
|
||||
default=None, description="Additional parameters"
|
||||
)
|
||||
credentials: Any | None = Field(default=None, description="Watson credentials")
|
||||
credentials: Any | None = Field(default=None, description="WatsonX credentials")
|
||||
project_id: str | None = Field(
|
||||
default=None,
|
||||
description="Watson project ID",
|
||||
validation_alias="WATSON_PROJECT_ID",
|
||||
description="WatsonX project ID",
|
||||
validation_alias="EMBEDDINGS_WATSONX_PROJECT_ID",
|
||||
)
|
||||
space_id: str | None = Field(
|
||||
default=None, description="Watson space ID", validation_alias="WATSON_SPACE_ID"
|
||||
default=None,
|
||||
description="WatsonX space ID",
|
||||
validation_alias="EMBEDDINGS_WATSONX_SPACE_ID",
|
||||
)
|
||||
api_client: Any | None = Field(default=None, description="Watson API client")
|
||||
api_client: Any | None = Field(default=None, description="WatsonX API client")
|
||||
verify: bool | str | None = Field(
|
||||
default=None, description="SSL verification", validation_alias="WATSON_VERIFY"
|
||||
default=None,
|
||||
description="SSL verification",
|
||||
validation_alias="EMBEDDINGS_WATSONX_VERIFY",
|
||||
)
|
||||
persistent_connection: bool = Field(
|
||||
default=True,
|
||||
description="Use persistent connection",
|
||||
validation_alias="WATSON_PERSISTENT_CONNECTION",
|
||||
validation_alias="EMBEDDINGS_WATSONX_PERSISTENT_CONNECTION",
|
||||
)
|
||||
batch_size: int = Field(
|
||||
default=100,
|
||||
description="Batch size for processing",
|
||||
validation_alias="WATSON_BATCH_SIZE",
|
||||
validation_alias="EMBEDDINGS_WATSONX_BATCH_SIZE",
|
||||
)
|
||||
concurrency_limit: int = Field(
|
||||
default=10,
|
||||
description="Concurrency limit",
|
||||
validation_alias="WATSON_CONCURRENCY_LIMIT",
|
||||
validation_alias="EMBEDDINGS_WATSONX_CONCURRENCY_LIMIT",
|
||||
)
|
||||
max_retries: int | None = Field(
|
||||
default=None,
|
||||
description="Maximum retries",
|
||||
validation_alias="WATSON_MAX_RETRIES",
|
||||
validation_alias="EMBEDDINGS_WATSONX_MAX_RETRIES",
|
||||
)
|
||||
delay_time: float | None = Field(
|
||||
default=None,
|
||||
description="Delay time between retries",
|
||||
validation_alias="WATSON_DELAY_TIME",
|
||||
validation_alias="EMBEDDINGS_WATSONX_DELAY_TIME",
|
||||
)
|
||||
retry_status_codes: list[int] | None = Field(
|
||||
default=None, description="HTTP status codes to retry on"
|
||||
)
|
||||
url: str = Field(description="Watson API URL", validation_alias="WATSON_URL")
|
||||
url: str = Field(
|
||||
description="WatsonX API URL", validation_alias="EMBEDDINGS_WATSONX_URL"
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="Watson API key", validation_alias="WATSON_API_KEY"
|
||||
description="WatsonX API key", validation_alias="EMBEDDINGS_WATSONX_API_KEY"
|
||||
)
|
||||
name: str | None = Field(
|
||||
default=None, description="Service name", validation_alias="WATSON_NAME"
|
||||
default=None,
|
||||
description="Service name",
|
||||
validation_alias="EMBEDDINGS_WATSONX_NAME",
|
||||
)
|
||||
iam_serviceid_crn: str | None = Field(
|
||||
default=None,
|
||||
description="IAM service ID CRN",
|
||||
validation_alias="WATSON_IAM_SERVICEID_CRN",
|
||||
validation_alias="EMBEDDINGS_WATSONX_IAM_SERVICEID_CRN",
|
||||
)
|
||||
trusted_profile_id: str | None = Field(
|
||||
default=None,
|
||||
description="Trusted profile ID",
|
||||
validation_alias="WATSON_TRUSTED_PROFILE_ID",
|
||||
validation_alias="EMBEDDINGS_WATSONX_TRUSTED_PROFILE_ID",
|
||||
)
|
||||
token: str | None = Field(
|
||||
default=None, description="Bearer token", validation_alias="WATSON_TOKEN"
|
||||
default=None,
|
||||
description="Bearer token",
|
||||
validation_alias="EMBEDDINGS_WATSONX_TOKEN",
|
||||
)
|
||||
projects_token: str | None = Field(
|
||||
default=None,
|
||||
description="Projects token",
|
||||
validation_alias="WATSON_PROJECTS_TOKEN",
|
||||
validation_alias="EMBEDDINGS_WATSONX_PROJECTS_TOKEN",
|
||||
)
|
||||
username: str | None = Field(
|
||||
default=None, description="Username", validation_alias="WATSON_USERNAME"
|
||||
default=None,
|
||||
description="Username",
|
||||
validation_alias="EMBEDDINGS_WATSONX_USERNAME",
|
||||
)
|
||||
password: str | None = Field(
|
||||
default=None, description="Password", validation_alias="WATSON_PASSWORD"
|
||||
default=None,
|
||||
description="Password",
|
||||
validation_alias="EMBEDDINGS_WATSONX_PASSWORD",
|
||||
)
|
||||
instance_id: str | None = Field(
|
||||
default=None,
|
||||
description="Service instance ID",
|
||||
validation_alias="WATSON_INSTANCE_ID",
|
||||
validation_alias="EMBEDDINGS_WATSONX_INSTANCE_ID",
|
||||
)
|
||||
version: str | None = Field(
|
||||
default=None, description="API version", validation_alias="WATSON_VERSION"
|
||||
default=None,
|
||||
description="API version",
|
||||
validation_alias="EMBEDDINGS_WATSONX_VERSION",
|
||||
)
|
||||
bedrock_url: str | None = Field(
|
||||
default=None, description="Bedrock URL", validation_alias="WATSON_BEDROCK_URL"
|
||||
default=None,
|
||||
description="Bedrock URL",
|
||||
validation_alias="EMBEDDINGS_WATSONX_BEDROCK_URL",
|
||||
)
|
||||
platform_url: str | None = Field(
|
||||
default=None, description="Platform URL", validation_alias="WATSON_PLATFORM_URL"
|
||||
default=None,
|
||||
description="Platform URL",
|
||||
validation_alias="EMBEDDINGS_WATSONX_PLATFORM_URL",
|
||||
)
|
||||
proxies: dict | None = Field(default=None, description="Proxy configuration")
|
||||
|
||||
@@ -18,15 +18,15 @@ class InstructorProvider(BaseEmbeddingsProvider[InstructorEmbeddingFunction]):
|
||||
model_name: str = Field(
|
||||
default="hkunlp/instructor-base",
|
||||
description="Model name to use",
|
||||
validation_alias="INSTRUCTOR_MODEL_NAME",
|
||||
validation_alias="EMBEDDINGS_INSTRUCTOR_MODEL_NAME",
|
||||
)
|
||||
device: str = Field(
|
||||
default="cpu",
|
||||
description="Device to run model on (cpu or cuda)",
|
||||
validation_alias="INSTRUCTOR_DEVICE",
|
||||
validation_alias="EMBEDDINGS_INSTRUCTOR_DEVICE",
|
||||
)
|
||||
instruction: str | None = Field(
|
||||
default=None,
|
||||
description="Instruction for embeddings",
|
||||
validation_alias="INSTRUCTOR_INSTRUCTION",
|
||||
validation_alias="EMBEDDINGS_INSTRUCTOR_INSTRUCTION",
|
||||
)
|
||||
|
||||
@@ -14,9 +14,11 @@ class JinaProvider(BaseEmbeddingsProvider[JinaEmbeddingFunction]):
|
||||
embedding_callable: type[JinaEmbeddingFunction] = Field(
|
||||
default=JinaEmbeddingFunction, description="Jina embedding function class"
|
||||
)
|
||||
api_key: str = Field(description="Jina API key", validation_alias="JINA_API_KEY")
|
||||
api_key: str = Field(
|
||||
description="Jina API key", validation_alias="EMBEDDINGS_JINA_API_KEY"
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="jina-embeddings-v2-base-en",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias="JINA_MODEL_NAME",
|
||||
validation_alias="EMBEDDINGS_JINA_MODEL_NAME",
|
||||
)
|
||||
|
||||
@@ -17,26 +17,28 @@ class AzureProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
|
||||
default=OpenAIEmbeddingFunction,
|
||||
description="Azure OpenAI embedding function class",
|
||||
)
|
||||
api_key: str = Field(description="Azure API key", validation_alias="OPENAI_API_KEY")
|
||||
api_key: str = Field(
|
||||
description="Azure API key", validation_alias="EMBEDDINGS_OPENAI_API_KEY"
|
||||
)
|
||||
api_base: str | None = Field(
|
||||
default=None,
|
||||
description="Azure endpoint URL",
|
||||
validation_alias="OPENAI_API_BASE",
|
||||
validation_alias="EMBEDDINGS_OPENAI_API_BASE",
|
||||
)
|
||||
api_type: str = Field(
|
||||
default="azure",
|
||||
description="API type for Azure",
|
||||
validation_alias="OPENAI_API_TYPE",
|
||||
validation_alias="EMBEDDINGS_OPENAI_API_TYPE",
|
||||
)
|
||||
api_version: str | None = Field(
|
||||
default=None,
|
||||
description="Azure API version",
|
||||
validation_alias="OPENAI_API_VERSION",
|
||||
validation_alias="EMBEDDINGS_OPENAI_API_VERSION",
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="text-embedding-ada-002",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias="OPENAI_MODEL_NAME",
|
||||
validation_alias="EMBEDDINGS_OPENAI_MODEL_NAME",
|
||||
)
|
||||
default_headers: dict[str, Any] | None = Field(
|
||||
default=None, description="Default headers for API requests"
|
||||
@@ -44,15 +46,15 @@ class AzureProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
|
||||
dimensions: int | None = Field(
|
||||
default=None,
|
||||
description="Embedding dimensions",
|
||||
validation_alias="OPENAI_DIMENSIONS",
|
||||
validation_alias="EMBEDDINGS_OPENAI_DIMENSIONS",
|
||||
)
|
||||
deployment_id: str | None = Field(
|
||||
default=None,
|
||||
description="Azure deployment ID",
|
||||
validation_alias="OPENAI_DEPLOYMENT_ID",
|
||||
validation_alias="EMBEDDINGS_OPENAI_DEPLOYMENT_ID",
|
||||
)
|
||||
organization_id: str | None = Field(
|
||||
default=None,
|
||||
description="Organization ID",
|
||||
validation_alias="OPENAI_ORGANIZATION_ID",
|
||||
validation_alias="EMBEDDINGS_OPENAI_ORGANIZATION_ID",
|
||||
)
|
||||
|
||||
@@ -17,9 +17,9 @@ class OllamaProvider(BaseEmbeddingsProvider[OllamaEmbeddingFunction]):
|
||||
url: str = Field(
|
||||
default="http://localhost:11434/api/embeddings",
|
||||
description="Ollama API endpoint URL",
|
||||
validation_alias="OLLAMA_URL",
|
||||
validation_alias="EMBEDDINGS_OLLAMA_URL",
|
||||
)
|
||||
model_name: str = Field(
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias="OLLAMA_MODEL_NAME",
|
||||
validation_alias="EMBEDDINGS_OLLAMA_MODEL_NAME",
|
||||
)
|
||||
|
||||
@@ -15,5 +15,5 @@ class ONNXProvider(BaseEmbeddingsProvider[ONNXMiniLM_L6_V2]):
|
||||
preferred_providers: list[str] | None = Field(
|
||||
default=None,
|
||||
description="Preferred ONNX execution providers",
|
||||
validation_alias="ONNX_PREFERRED_PROVIDERS",
|
||||
validation_alias="EMBEDDINGS_ONNX_PREFERRED_PROVIDERS",
|
||||
)
|
||||
|
||||
@@ -18,25 +18,29 @@ class OpenAIProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
|
||||
description="OpenAI embedding function class",
|
||||
)
|
||||
api_key: str | None = Field(
|
||||
default=None, description="OpenAI API key", validation_alias="OPENAI_API_KEY"
|
||||
default=None,
|
||||
description="OpenAI API key",
|
||||
validation_alias="EMBEDDINGS_OPENAI_API_KEY",
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="text-embedding-ada-002",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias="OPENAI_MODEL_NAME",
|
||||
validation_alias="EMBEDDINGS_OPENAI_MODEL_NAME",
|
||||
)
|
||||
api_base: str | None = Field(
|
||||
default=None,
|
||||
description="Base URL for API requests",
|
||||
validation_alias="OPENAI_API_BASE",
|
||||
validation_alias="EMBEDDINGS_OPENAI_API_BASE",
|
||||
)
|
||||
api_type: str | None = Field(
|
||||
default=None,
|
||||
description="API type (e.g., 'azure')",
|
||||
validation_alias="OPENAI_API_TYPE",
|
||||
validation_alias="EMBEDDINGS_OPENAI_API_TYPE",
|
||||
)
|
||||
api_version: str | None = Field(
|
||||
default=None, description="API version", validation_alias="OPENAI_API_VERSION"
|
||||
default=None,
|
||||
description="API version",
|
||||
validation_alias="EMBEDDINGS_OPENAI_API_VERSION",
|
||||
)
|
||||
default_headers: dict[str, Any] | None = Field(
|
||||
default=None, description="Default headers for API requests"
|
||||
@@ -44,15 +48,15 @@ class OpenAIProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
|
||||
dimensions: int | None = Field(
|
||||
default=None,
|
||||
description="Embedding dimensions",
|
||||
validation_alias="OPENAI_DIMENSIONS",
|
||||
validation_alias="EMBEDDINGS_OPENAI_DIMENSIONS",
|
||||
)
|
||||
deployment_id: str | None = Field(
|
||||
default=None,
|
||||
description="Azure deployment ID",
|
||||
validation_alias="OPENAI_DEPLOYMENT_ID",
|
||||
validation_alias="EMBEDDINGS_OPENAI_DEPLOYMENT_ID",
|
||||
)
|
||||
organization_id: str | None = Field(
|
||||
default=None,
|
||||
description="OpenAI organization ID",
|
||||
validation_alias="OPENAI_ORGANIZATION_ID",
|
||||
validation_alias="EMBEDDINGS_OPENAI_ORGANIZATION_ID",
|
||||
)
|
||||
|
||||
@@ -18,15 +18,15 @@ class OpenCLIPProvider(BaseEmbeddingsProvider[OpenCLIPEmbeddingFunction]):
|
||||
model_name: str = Field(
|
||||
default="ViT-B-32",
|
||||
description="Model name to use",
|
||||
validation_alias="OPENCLIP_MODEL_NAME",
|
||||
validation_alias="EMBEDDINGS_OPENCLIP_MODEL_NAME",
|
||||
)
|
||||
checkpoint: str = Field(
|
||||
default="laion2b_s34b_b79k",
|
||||
description="Model checkpoint",
|
||||
validation_alias="OPENCLIP_CHECKPOINT",
|
||||
validation_alias="EMBEDDINGS_OPENCLIP_CHECKPOINT",
|
||||
)
|
||||
device: str | None = Field(
|
||||
default="cpu",
|
||||
description="Device to run model on",
|
||||
validation_alias="OPENCLIP_DEVICE",
|
||||
validation_alias="EMBEDDINGS_OPENCLIP_DEVICE",
|
||||
)
|
||||
|
||||
@@ -16,10 +16,12 @@ class RoboflowProvider(BaseEmbeddingsProvider[RoboflowEmbeddingFunction]):
|
||||
description="Roboflow embedding function class",
|
||||
)
|
||||
api_key: str = Field(
|
||||
default="", description="Roboflow API key", validation_alias="ROBOFLOW_API_KEY"
|
||||
default="",
|
||||
description="Roboflow API key",
|
||||
validation_alias="EMBEDDINGS_ROBOFLOW_API_KEY",
|
||||
)
|
||||
api_url: str = Field(
|
||||
default="https://infer.roboflow.com",
|
||||
description="Roboflow API URL",
|
||||
validation_alias="ROBOFLOW_API_URL",
|
||||
validation_alias="EMBEDDINGS_ROBOFLOW_API_URL",
|
||||
)
|
||||
|
||||
@@ -20,15 +20,15 @@ class SentenceTransformerProvider(
|
||||
model_name: str = Field(
|
||||
default="all-MiniLM-L6-v2",
|
||||
description="Model name to use",
|
||||
validation_alias="SENTENCE_TRANSFORMER_MODEL_NAME",
|
||||
validation_alias="EMBEDDINGS_SENTENCE_TRANSFORMER_MODEL_NAME",
|
||||
)
|
||||
device: str = Field(
|
||||
default="cpu",
|
||||
description="Device to run model on (cpu or cuda)",
|
||||
validation_alias="SENTENCE_TRANSFORMER_DEVICE",
|
||||
validation_alias="EMBEDDINGS_SENTENCE_TRANSFORMER_DEVICE",
|
||||
)
|
||||
normalize_embeddings: bool = Field(
|
||||
default=False,
|
||||
description="Whether to normalize embeddings",
|
||||
validation_alias="SENTENCE_TRANSFORMER_NORMALIZE_EMBEDDINGS",
|
||||
validation_alias="EMBEDDINGS_SENTENCE_TRANSFORMER_NORMALIZE_EMBEDDINGS",
|
||||
)
|
||||
|
||||
@@ -18,5 +18,5 @@ class Text2VecProvider(BaseEmbeddingsProvider[Text2VecEmbeddingFunction]):
|
||||
model_name: str = Field(
|
||||
default="shibing624/text2vec-base-chinese",
|
||||
description="Model name to use",
|
||||
validation_alias="TEXT2VEC_MODEL_NAME",
|
||||
validation_alias="EMBEDDINGS_TEXT2VEC_MODEL_NAME",
|
||||
)
|
||||
|
||||
@@ -2,10 +2,9 @@
|
||||
|
||||
from typing import cast
|
||||
|
||||
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
||||
from crewai.rag.core.types import Documents, Embeddings
|
||||
from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderConfig
|
||||
|
||||
|
||||
@@ -33,6 +32,11 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
timeout=kwargs.get("timeout"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def name() -> str:
|
||||
"""Return the name of the embedding function for ChromaDB compatibility."""
|
||||
return "voyageai"
|
||||
|
||||
def __call__(self, input: Documents) -> Embeddings:
|
||||
"""Generate embeddings for input documents.
|
||||
|
||||
|
||||
@@ -18,38 +18,38 @@ class VoyageAIProvider(BaseEmbeddingsProvider[VoyageAIEmbeddingFunction]):
|
||||
model: str = Field(
|
||||
default="voyage-2",
|
||||
description="Model to use for embeddings",
|
||||
validation_alias="VOYAGEAI_MODEL",
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_MODEL",
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="Voyage AI API key", validation_alias="VOYAGEAI_API_KEY"
|
||||
description="Voyage AI API key", validation_alias="EMBEDDINGS_VOYAGEAI_API_KEY"
|
||||
)
|
||||
input_type: str | None = Field(
|
||||
default=None,
|
||||
description="Input type for embeddings",
|
||||
validation_alias="VOYAGEAI_INPUT_TYPE",
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_INPUT_TYPE",
|
||||
)
|
||||
truncation: bool = Field(
|
||||
default=True,
|
||||
description="Whether to truncate inputs",
|
||||
validation_alias="VOYAGEAI_TRUNCATION",
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_TRUNCATION",
|
||||
)
|
||||
output_dtype: str | None = Field(
|
||||
default=None,
|
||||
description="Output data type",
|
||||
validation_alias="VOYAGEAI_OUTPUT_DTYPE",
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_OUTPUT_DTYPE",
|
||||
)
|
||||
output_dimension: int | None = Field(
|
||||
default=None,
|
||||
description="Output dimension",
|
||||
validation_alias="VOYAGEAI_OUTPUT_DIMENSION",
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_OUTPUT_DIMENSION",
|
||||
)
|
||||
max_retries: int = Field(
|
||||
default=0,
|
||||
description="Maximum retries for API calls",
|
||||
validation_alias="VOYAGEAI_MAX_RETRIES",
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_MAX_RETRIES",
|
||||
)
|
||||
timeout: float | None = Field(
|
||||
default=None,
|
||||
description="Timeout for API calls",
|
||||
validation_alias="VOYAGEAI_TIMEOUT",
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_TIMEOUT",
|
||||
)
|
||||
|
||||
@@ -11,7 +11,10 @@ from crewai.rag.embeddings.providers.google.types import (
|
||||
VertexAIProviderSpec,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.huggingface.types import HuggingFaceProviderSpec
|
||||
from crewai.rag.embeddings.providers.ibm.types import WatsonProviderSpec
|
||||
from crewai.rag.embeddings.providers.ibm.types import (
|
||||
WatsonProviderSpec,
|
||||
WatsonXProviderSpec,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.instructor.types import InstructorProviderSpec
|
||||
from crewai.rag.embeddings.providers.jina.types import JinaProviderSpec
|
||||
from crewai.rag.embeddings.providers.microsoft.types import AzureProviderSpec
|
||||
@@ -44,7 +47,8 @@ ProviderSpec = (
|
||||
| Text2VecProviderSpec
|
||||
| VertexAIProviderSpec
|
||||
| VoyageAIProviderSpec
|
||||
| WatsonProviderSpec
|
||||
| WatsonProviderSpec # Deprecated, use WatsonXProviderSpec
|
||||
| WatsonXProviderSpec
|
||||
)
|
||||
|
||||
AllowedEmbeddingProviders = Literal[
|
||||
@@ -65,7 +69,8 @@ AllowedEmbeddingProviders = Literal[
|
||||
"sentence-transformer",
|
||||
"text2vec",
|
||||
"voyageai",
|
||||
"watson",
|
||||
"watsonx",
|
||||
"watson", # for backward compatibility until v1.0.0
|
||||
]
|
||||
|
||||
EmbedderConfig: TypeAlias = (
|
||||
|
||||
@@ -150,8 +150,8 @@ class TestEmbeddingFactory:
|
||||
)
|
||||
|
||||
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
|
||||
def test_build_embedder_watson(self, mock_import):
|
||||
"""Test building Watson embedder."""
|
||||
def test_build_embedder_watsonx(self, mock_import):
|
||||
"""Test building WatsonX embedder."""
|
||||
mock_provider_class = MagicMock()
|
||||
mock_provider_instance = MagicMock()
|
||||
mock_embedding_function = MagicMock()
|
||||
@@ -161,10 +161,10 @@ class TestEmbeddingFactory:
|
||||
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
|
||||
|
||||
config = {
|
||||
"provider": "watson",
|
||||
"provider": "watsonx",
|
||||
"config": {
|
||||
"model_id": "ibm/slate-125m-english-rtrvr",
|
||||
"api_key": "watson-key",
|
||||
"api_key": "watsonx-key",
|
||||
"url": "https://us-south.ml.cloud.ibm.com",
|
||||
"project_id": "test-project",
|
||||
},
|
||||
@@ -173,7 +173,7 @@ class TestEmbeddingFactory:
|
||||
build_embedder(config)
|
||||
|
||||
mock_import.assert_called_once_with(
|
||||
"crewai.rag.embeddings.providers.ibm.watson.WatsonProvider"
|
||||
"crewai.rag.embeddings.providers.ibm.watsonx.WatsonXProvider"
|
||||
)
|
||||
|
||||
def test_build_embedder_unknown_provider(self):
|
||||
|
||||
Reference in New Issue
Block a user