mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 08:38:30 +00:00
fix: prefix embedding provider env vars with EMBEDDINGS_
This commit is contained in:
@@ -46,7 +46,7 @@ class BedrockProvider(BaseEmbeddingsProvider[AmazonBedrockEmbeddingFunction]):
|
|||||||
model_name: str = Field(
|
model_name: str = Field(
|
||||||
default="amazon.titan-embed-text-v1",
|
default="amazon.titan-embed-text-v1",
|
||||||
description="Model name to use for embeddings",
|
description="Model name to use for embeddings",
|
||||||
validation_alias="BEDROCK_MODEL_NAME",
|
validation_alias="EMBEDDINGS_BEDROCK_MODEL_NAME",
|
||||||
)
|
)
|
||||||
session: Any = Field(
|
session: Any = Field(
|
||||||
default_factory=create_aws_session, description="AWS session object"
|
default_factory=create_aws_session, description="AWS session object"
|
||||||
|
|||||||
@@ -15,10 +15,10 @@ class CohereProvider(BaseEmbeddingsProvider[CohereEmbeddingFunction]):
|
|||||||
default=CohereEmbeddingFunction, description="Cohere embedding function class"
|
default=CohereEmbeddingFunction, description="Cohere embedding function class"
|
||||||
)
|
)
|
||||||
api_key: str = Field(
|
api_key: str = Field(
|
||||||
description="Cohere API key", validation_alias="COHERE_API_KEY"
|
description="Cohere API key", validation_alias="EMBEDDINGS_COHERE_API_KEY"
|
||||||
)
|
)
|
||||||
model_name: str = Field(
|
model_name: str = Field(
|
||||||
default="large",
|
default="large",
|
||||||
description="Model name to use for embeddings",
|
description="Model name to use for embeddings",
|
||||||
validation_alias="COHERE_MODEL_NAME",
|
validation_alias="EMBEDDINGS_COHERE_MODEL_NAME",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -18,13 +18,13 @@ class GenerativeAiProvider(BaseEmbeddingsProvider[GoogleGenerativeAiEmbeddingFun
|
|||||||
model_name: str = Field(
|
model_name: str = Field(
|
||||||
default="models/embedding-001",
|
default="models/embedding-001",
|
||||||
description="Model name to use for embeddings",
|
description="Model name to use for embeddings",
|
||||||
validation_alias="GOOGLE_GENERATIVE_AI_MODEL_NAME",
|
validation_alias="EMBEDDINGS_GOOGLE_GENERATIVE_AI_MODEL_NAME",
|
||||||
)
|
)
|
||||||
api_key: str = Field(
|
api_key: str = Field(
|
||||||
description="Google API key", validation_alias="GOOGLE_API_KEY"
|
description="Google API key", validation_alias="EMBEDDINGS_GOOGLE_API_KEY"
|
||||||
)
|
)
|
||||||
task_type: str = Field(
|
task_type: str = Field(
|
||||||
default="RETRIEVAL_DOCUMENT",
|
default="RETRIEVAL_DOCUMENT",
|
||||||
description="Task type for embeddings",
|
description="Task type for embeddings",
|
||||||
validation_alias="GOOGLE_GENERATIVE_AI_TASK_TYPE",
|
validation_alias="EMBEDDINGS_GOOGLE_GENERATIVE_AI_TASK_TYPE",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -18,18 +18,18 @@ class VertexAIProvider(BaseEmbeddingsProvider[GoogleVertexEmbeddingFunction]):
|
|||||||
model_name: str = Field(
|
model_name: str = Field(
|
||||||
default="textembedding-gecko",
|
default="textembedding-gecko",
|
||||||
description="Model name to use for embeddings",
|
description="Model name to use for embeddings",
|
||||||
validation_alias="GOOGLE_VERTEX_MODEL_NAME",
|
validation_alias="EMBEDDINGS_GOOGLE_VERTEX_MODEL_NAME",
|
||||||
)
|
)
|
||||||
api_key: str = Field(
|
api_key: str = Field(
|
||||||
description="Google API key", validation_alias="GOOGLE_CLOUD_API_KEY"
|
description="Google API key", validation_alias="EMBEDDINGS_GOOGLE_CLOUD_API_KEY"
|
||||||
)
|
)
|
||||||
project_id: str = Field(
|
project_id: str = Field(
|
||||||
default="cloud-large-language-models",
|
default="cloud-large-language-models",
|
||||||
description="GCP project ID",
|
description="GCP project ID",
|
||||||
validation_alias="GOOGLE_CLOUD_PROJECT",
|
validation_alias="EMBEDDINGS_GOOGLE_CLOUD_PROJECT",
|
||||||
)
|
)
|
||||||
region: str = Field(
|
region: str = Field(
|
||||||
default="us-central1",
|
default="us-central1",
|
||||||
description="GCP region",
|
description="GCP region",
|
||||||
validation_alias="GOOGLE_CLOUD_REGION",
|
validation_alias="EMBEDDINGS_GOOGLE_CLOUD_REGION",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -16,5 +16,5 @@ class HuggingFaceProvider(BaseEmbeddingsProvider[HuggingFaceEmbeddingServer]):
|
|||||||
description="HuggingFace embedding function class",
|
description="HuggingFace embedding function class",
|
||||||
)
|
)
|
||||||
url: str = Field(
|
url: str = Field(
|
||||||
description="HuggingFace API URL", validation_alias="HUGGINGFACE_URL"
|
description="HuggingFace API URL", validation_alias="EMBEDDINGS_HUGGINGFACE_URL"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ class WatsonProvider(BaseEmbeddingsProvider[WatsonEmbeddingFunction]):
|
|||||||
default=WatsonEmbeddingFunction, description="Watson embedding function class"
|
default=WatsonEmbeddingFunction, description="Watson embedding function class"
|
||||||
)
|
)
|
||||||
model_id: str = Field(
|
model_id: str = Field(
|
||||||
description="Watson model ID", validation_alias="WATSON_MODEL_ID"
|
description="Watson model ID", validation_alias="EMBEDDINGS_WATSON_MODEL_ID"
|
||||||
)
|
)
|
||||||
params: dict[str, str | dict[str, str]] | None = Field(
|
params: dict[str, str | dict[str, str]] | None = Field(
|
||||||
default=None, description="Additional parameters"
|
default=None, description="Additional parameters"
|
||||||
@@ -30,87 +30,107 @@ class WatsonProvider(BaseEmbeddingsProvider[WatsonEmbeddingFunction]):
|
|||||||
project_id: str | None = Field(
|
project_id: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Watson project ID",
|
description="Watson project ID",
|
||||||
validation_alias="WATSON_PROJECT_ID",
|
validation_alias="EMBEDDINGS_WATSON_PROJECT_ID",
|
||||||
)
|
)
|
||||||
space_id: str | None = Field(
|
space_id: str | None = Field(
|
||||||
default=None, description="Watson space ID", validation_alias="WATSON_SPACE_ID"
|
default=None,
|
||||||
|
description="Watson space ID",
|
||||||
|
validation_alias="EMBEDDINGS_WATSON_SPACE_ID",
|
||||||
)
|
)
|
||||||
api_client: Any | None = Field(default=None, description="Watson API client")
|
api_client: Any | None = Field(default=None, description="Watson API client")
|
||||||
verify: bool | str | None = Field(
|
verify: bool | str | None = Field(
|
||||||
default=None, description="SSL verification", validation_alias="WATSON_VERIFY"
|
default=None,
|
||||||
|
description="SSL verification",
|
||||||
|
validation_alias="EMBEDDINGS_WATSON_VERIFY",
|
||||||
)
|
)
|
||||||
persistent_connection: bool = Field(
|
persistent_connection: bool = Field(
|
||||||
default=True,
|
default=True,
|
||||||
description="Use persistent connection",
|
description="Use persistent connection",
|
||||||
validation_alias="WATSON_PERSISTENT_CONNECTION",
|
validation_alias="EMBEDDINGS_WATSON_PERSISTENT_CONNECTION",
|
||||||
)
|
)
|
||||||
batch_size: int = Field(
|
batch_size: int = Field(
|
||||||
default=100,
|
default=100,
|
||||||
description="Batch size for processing",
|
description="Batch size for processing",
|
||||||
validation_alias="WATSON_BATCH_SIZE",
|
validation_alias="EMBEDDINGS_WATSON_BATCH_SIZE",
|
||||||
)
|
)
|
||||||
concurrency_limit: int = Field(
|
concurrency_limit: int = Field(
|
||||||
default=10,
|
default=10,
|
||||||
description="Concurrency limit",
|
description="Concurrency limit",
|
||||||
validation_alias="WATSON_CONCURRENCY_LIMIT",
|
validation_alias="EMBEDDINGS_WATSON_CONCURRENCY_LIMIT",
|
||||||
)
|
)
|
||||||
max_retries: int | None = Field(
|
max_retries: int | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Maximum retries",
|
description="Maximum retries",
|
||||||
validation_alias="WATSON_MAX_RETRIES",
|
validation_alias="EMBEDDINGS_WATSON_MAX_RETRIES",
|
||||||
)
|
)
|
||||||
delay_time: float | None = Field(
|
delay_time: float | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Delay time between retries",
|
description="Delay time between retries",
|
||||||
validation_alias="WATSON_DELAY_TIME",
|
validation_alias="EMBEDDINGS_WATSON_DELAY_TIME",
|
||||||
)
|
)
|
||||||
retry_status_codes: list[int] | None = Field(
|
retry_status_codes: list[int] | None = Field(
|
||||||
default=None, description="HTTP status codes to retry on"
|
default=None, description="HTTP status codes to retry on"
|
||||||
)
|
)
|
||||||
url: str = Field(description="Watson API URL", validation_alias="WATSON_URL")
|
url: str = Field(
|
||||||
|
description="Watson API URL", validation_alias="EMBEDDINGS_WATSON_URL"
|
||||||
|
)
|
||||||
api_key: str = Field(
|
api_key: str = Field(
|
||||||
description="Watson API key", validation_alias="WATSON_API_KEY"
|
description="Watson API key", validation_alias="EMBEDDINGS_WATSON_API_KEY"
|
||||||
)
|
)
|
||||||
name: str | None = Field(
|
name: str | None = Field(
|
||||||
default=None, description="Service name", validation_alias="WATSON_NAME"
|
default=None,
|
||||||
|
description="Service name",
|
||||||
|
validation_alias="EMBEDDINGS_WATSON_NAME",
|
||||||
)
|
)
|
||||||
iam_serviceid_crn: str | None = Field(
|
iam_serviceid_crn: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="IAM service ID CRN",
|
description="IAM service ID CRN",
|
||||||
validation_alias="WATSON_IAM_SERVICEID_CRN",
|
validation_alias="EMBEDDINGS_WATSON_IAM_SERVICEID_CRN",
|
||||||
)
|
)
|
||||||
trusted_profile_id: str | None = Field(
|
trusted_profile_id: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Trusted profile ID",
|
description="Trusted profile ID",
|
||||||
validation_alias="WATSON_TRUSTED_PROFILE_ID",
|
validation_alias="EMBEDDINGS_WATSON_TRUSTED_PROFILE_ID",
|
||||||
)
|
)
|
||||||
token: str | None = Field(
|
token: str | None = Field(
|
||||||
default=None, description="Bearer token", validation_alias="WATSON_TOKEN"
|
default=None,
|
||||||
|
description="Bearer token",
|
||||||
|
validation_alias="EMBEDDINGS_WATSON_TOKEN",
|
||||||
)
|
)
|
||||||
projects_token: str | None = Field(
|
projects_token: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Projects token",
|
description="Projects token",
|
||||||
validation_alias="WATSON_PROJECTS_TOKEN",
|
validation_alias="EMBEDDINGS_WATSON_PROJECTS_TOKEN",
|
||||||
)
|
)
|
||||||
username: str | None = Field(
|
username: str | None = Field(
|
||||||
default=None, description="Username", validation_alias="WATSON_USERNAME"
|
default=None,
|
||||||
|
description="Username",
|
||||||
|
validation_alias="EMBEDDINGS_WATSON_USERNAME",
|
||||||
)
|
)
|
||||||
password: str | None = Field(
|
password: str | None = Field(
|
||||||
default=None, description="Password", validation_alias="WATSON_PASSWORD"
|
default=None,
|
||||||
|
description="Password",
|
||||||
|
validation_alias="EMBEDDINGS_WATSON_PASSWORD",
|
||||||
)
|
)
|
||||||
instance_id: str | None = Field(
|
instance_id: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Service instance ID",
|
description="Service instance ID",
|
||||||
validation_alias="WATSON_INSTANCE_ID",
|
validation_alias="EMBEDDINGS_WATSON_INSTANCE_ID",
|
||||||
)
|
)
|
||||||
version: str | None = Field(
|
version: str | None = Field(
|
||||||
default=None, description="API version", validation_alias="WATSON_VERSION"
|
default=None,
|
||||||
|
description="API version",
|
||||||
|
validation_alias="EMBEDDINGS_WATSON_VERSION",
|
||||||
)
|
)
|
||||||
bedrock_url: str | None = Field(
|
bedrock_url: str | None = Field(
|
||||||
default=None, description="Bedrock URL", validation_alias="WATSON_BEDROCK_URL"
|
default=None,
|
||||||
|
description="Bedrock URL",
|
||||||
|
validation_alias="EMBEDDINGS_WATSON_BEDROCK_URL",
|
||||||
)
|
)
|
||||||
platform_url: str | None = Field(
|
platform_url: str | None = Field(
|
||||||
default=None, description="Platform URL", validation_alias="WATSON_PLATFORM_URL"
|
default=None,
|
||||||
|
description="Platform URL",
|
||||||
|
validation_alias="EMBEDDINGS_WATSON_PLATFORM_URL",
|
||||||
)
|
)
|
||||||
proxies: dict | None = Field(default=None, description="Proxy configuration")
|
proxies: dict | None = Field(default=None, description="Proxy configuration")
|
||||||
|
|
||||||
|
|||||||
@@ -18,15 +18,15 @@ class InstructorProvider(BaseEmbeddingsProvider[InstructorEmbeddingFunction]):
|
|||||||
model_name: str = Field(
|
model_name: str = Field(
|
||||||
default="hkunlp/instructor-base",
|
default="hkunlp/instructor-base",
|
||||||
description="Model name to use",
|
description="Model name to use",
|
||||||
validation_alias="INSTRUCTOR_MODEL_NAME",
|
validation_alias="EMBEDDINGS_INSTRUCTOR_MODEL_NAME",
|
||||||
)
|
)
|
||||||
device: str = Field(
|
device: str = Field(
|
||||||
default="cpu",
|
default="cpu",
|
||||||
description="Device to run model on (cpu or cuda)",
|
description="Device to run model on (cpu or cuda)",
|
||||||
validation_alias="INSTRUCTOR_DEVICE",
|
validation_alias="EMBEDDINGS_INSTRUCTOR_DEVICE",
|
||||||
)
|
)
|
||||||
instruction: str | None = Field(
|
instruction: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Instruction for embeddings",
|
description="Instruction for embeddings",
|
||||||
validation_alias="INSTRUCTOR_INSTRUCTION",
|
validation_alias="EMBEDDINGS_INSTRUCTOR_INSTRUCTION",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -14,9 +14,11 @@ class JinaProvider(BaseEmbeddingsProvider[JinaEmbeddingFunction]):
|
|||||||
embedding_callable: type[JinaEmbeddingFunction] = Field(
|
embedding_callable: type[JinaEmbeddingFunction] = Field(
|
||||||
default=JinaEmbeddingFunction, description="Jina embedding function class"
|
default=JinaEmbeddingFunction, description="Jina embedding function class"
|
||||||
)
|
)
|
||||||
api_key: str = Field(description="Jina API key", validation_alias="JINA_API_KEY")
|
api_key: str = Field(
|
||||||
|
description="Jina API key", validation_alias="EMBEDDINGS_JINA_API_KEY"
|
||||||
|
)
|
||||||
model_name: str = Field(
|
model_name: str = Field(
|
||||||
default="jina-embeddings-v2-base-en",
|
default="jina-embeddings-v2-base-en",
|
||||||
description="Model name to use for embeddings",
|
description="Model name to use for embeddings",
|
||||||
validation_alias="JINA_MODEL_NAME",
|
validation_alias="EMBEDDINGS_JINA_MODEL_NAME",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -17,26 +17,28 @@ class AzureProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
|
|||||||
default=OpenAIEmbeddingFunction,
|
default=OpenAIEmbeddingFunction,
|
||||||
description="Azure OpenAI embedding function class",
|
description="Azure OpenAI embedding function class",
|
||||||
)
|
)
|
||||||
api_key: str = Field(description="Azure API key", validation_alias="OPENAI_API_KEY")
|
api_key: str = Field(
|
||||||
|
description="Azure API key", validation_alias="EMBEDDINGS_OPENAI_API_KEY"
|
||||||
|
)
|
||||||
api_base: str | None = Field(
|
api_base: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Azure endpoint URL",
|
description="Azure endpoint URL",
|
||||||
validation_alias="OPENAI_API_BASE",
|
validation_alias="EMBEDDINGS_OPENAI_API_BASE",
|
||||||
)
|
)
|
||||||
api_type: str = Field(
|
api_type: str = Field(
|
||||||
default="azure",
|
default="azure",
|
||||||
description="API type for Azure",
|
description="API type for Azure",
|
||||||
validation_alias="OPENAI_API_TYPE",
|
validation_alias="EMBEDDINGS_OPENAI_API_TYPE",
|
||||||
)
|
)
|
||||||
api_version: str | None = Field(
|
api_version: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Azure API version",
|
description="Azure API version",
|
||||||
validation_alias="OPENAI_API_VERSION",
|
validation_alias="EMBEDDINGS_OPENAI_API_VERSION",
|
||||||
)
|
)
|
||||||
model_name: str = Field(
|
model_name: str = Field(
|
||||||
default="text-embedding-ada-002",
|
default="text-embedding-ada-002",
|
||||||
description="Model name to use for embeddings",
|
description="Model name to use for embeddings",
|
||||||
validation_alias="OPENAI_MODEL_NAME",
|
validation_alias="EMBEDDINGS_OPENAI_MODEL_NAME",
|
||||||
)
|
)
|
||||||
default_headers: dict[str, Any] | None = Field(
|
default_headers: dict[str, Any] | None = Field(
|
||||||
default=None, description="Default headers for API requests"
|
default=None, description="Default headers for API requests"
|
||||||
@@ -44,15 +46,15 @@ class AzureProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
|
|||||||
dimensions: int | None = Field(
|
dimensions: int | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Embedding dimensions",
|
description="Embedding dimensions",
|
||||||
validation_alias="OPENAI_DIMENSIONS",
|
validation_alias="EMBEDDINGS_OPENAI_DIMENSIONS",
|
||||||
)
|
)
|
||||||
deployment_id: str | None = Field(
|
deployment_id: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Azure deployment ID",
|
description="Azure deployment ID",
|
||||||
validation_alias="OPENAI_DEPLOYMENT_ID",
|
validation_alias="EMBEDDINGS_OPENAI_DEPLOYMENT_ID",
|
||||||
)
|
)
|
||||||
organization_id: str | None = Field(
|
organization_id: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Organization ID",
|
description="Organization ID",
|
||||||
validation_alias="OPENAI_ORGANIZATION_ID",
|
validation_alias="EMBEDDINGS_OPENAI_ORGANIZATION_ID",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -17,9 +17,9 @@ class OllamaProvider(BaseEmbeddingsProvider[OllamaEmbeddingFunction]):
|
|||||||
url: str = Field(
|
url: str = Field(
|
||||||
default="http://localhost:11434/api/embeddings",
|
default="http://localhost:11434/api/embeddings",
|
||||||
description="Ollama API endpoint URL",
|
description="Ollama API endpoint URL",
|
||||||
validation_alias="OLLAMA_URL",
|
validation_alias="EMBEDDINGS_OLLAMA_URL",
|
||||||
)
|
)
|
||||||
model_name: str = Field(
|
model_name: str = Field(
|
||||||
description="Model name to use for embeddings",
|
description="Model name to use for embeddings",
|
||||||
validation_alias="OLLAMA_MODEL_NAME",
|
validation_alias="EMBEDDINGS_OLLAMA_MODEL_NAME",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -15,5 +15,5 @@ class ONNXProvider(BaseEmbeddingsProvider[ONNXMiniLM_L6_V2]):
|
|||||||
preferred_providers: list[str] | None = Field(
|
preferred_providers: list[str] | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Preferred ONNX execution providers",
|
description="Preferred ONNX execution providers",
|
||||||
validation_alias="ONNX_PREFERRED_PROVIDERS",
|
validation_alias="EMBEDDINGS_ONNX_PREFERRED_PROVIDERS",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -18,25 +18,29 @@ class OpenAIProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
|
|||||||
description="OpenAI embedding function class",
|
description="OpenAI embedding function class",
|
||||||
)
|
)
|
||||||
api_key: str | None = Field(
|
api_key: str | None = Field(
|
||||||
default=None, description="OpenAI API key", validation_alias="OPENAI_API_KEY"
|
default=None,
|
||||||
|
description="OpenAI API key",
|
||||||
|
validation_alias="EMBEDDINGS_OPENAI_API_KEY",
|
||||||
)
|
)
|
||||||
model_name: str = Field(
|
model_name: str = Field(
|
||||||
default="text-embedding-ada-002",
|
default="text-embedding-ada-002",
|
||||||
description="Model name to use for embeddings",
|
description="Model name to use for embeddings",
|
||||||
validation_alias="OPENAI_MODEL_NAME",
|
validation_alias="EMBEDDINGS_OPENAI_MODEL_NAME",
|
||||||
)
|
)
|
||||||
api_base: str | None = Field(
|
api_base: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Base URL for API requests",
|
description="Base URL for API requests",
|
||||||
validation_alias="OPENAI_API_BASE",
|
validation_alias="EMBEDDINGS_OPENAI_API_BASE",
|
||||||
)
|
)
|
||||||
api_type: str | None = Field(
|
api_type: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="API type (e.g., 'azure')",
|
description="API type (e.g., 'azure')",
|
||||||
validation_alias="OPENAI_API_TYPE",
|
validation_alias="EMBEDDINGS_OPENAI_API_TYPE",
|
||||||
)
|
)
|
||||||
api_version: str | None = Field(
|
api_version: str | None = Field(
|
||||||
default=None, description="API version", validation_alias="OPENAI_API_VERSION"
|
default=None,
|
||||||
|
description="API version",
|
||||||
|
validation_alias="EMBEDDINGS_OPENAI_API_VERSION",
|
||||||
)
|
)
|
||||||
default_headers: dict[str, Any] | None = Field(
|
default_headers: dict[str, Any] | None = Field(
|
||||||
default=None, description="Default headers for API requests"
|
default=None, description="Default headers for API requests"
|
||||||
@@ -44,15 +48,15 @@ class OpenAIProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
|
|||||||
dimensions: int | None = Field(
|
dimensions: int | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Embedding dimensions",
|
description="Embedding dimensions",
|
||||||
validation_alias="OPENAI_DIMENSIONS",
|
validation_alias="EMBEDDINGS_OPENAI_DIMENSIONS",
|
||||||
)
|
)
|
||||||
deployment_id: str | None = Field(
|
deployment_id: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Azure deployment ID",
|
description="Azure deployment ID",
|
||||||
validation_alias="OPENAI_DEPLOYMENT_ID",
|
validation_alias="EMBEDDINGS_OPENAI_DEPLOYMENT_ID",
|
||||||
)
|
)
|
||||||
organization_id: str | None = Field(
|
organization_id: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="OpenAI organization ID",
|
description="OpenAI organization ID",
|
||||||
validation_alias="OPENAI_ORGANIZATION_ID",
|
validation_alias="EMBEDDINGS_OPENAI_ORGANIZATION_ID",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -18,15 +18,15 @@ class OpenCLIPProvider(BaseEmbeddingsProvider[OpenCLIPEmbeddingFunction]):
|
|||||||
model_name: str = Field(
|
model_name: str = Field(
|
||||||
default="ViT-B-32",
|
default="ViT-B-32",
|
||||||
description="Model name to use",
|
description="Model name to use",
|
||||||
validation_alias="OPENCLIP_MODEL_NAME",
|
validation_alias="EMBEDDINGS_OPENCLIP_MODEL_NAME",
|
||||||
)
|
)
|
||||||
checkpoint: str = Field(
|
checkpoint: str = Field(
|
||||||
default="laion2b_s34b_b79k",
|
default="laion2b_s34b_b79k",
|
||||||
description="Model checkpoint",
|
description="Model checkpoint",
|
||||||
validation_alias="OPENCLIP_CHECKPOINT",
|
validation_alias="EMBEDDINGS_OPENCLIP_CHECKPOINT",
|
||||||
)
|
)
|
||||||
device: str | None = Field(
|
device: str | None = Field(
|
||||||
default="cpu",
|
default="cpu",
|
||||||
description="Device to run model on",
|
description="Device to run model on",
|
||||||
validation_alias="OPENCLIP_DEVICE",
|
validation_alias="EMBEDDINGS_OPENCLIP_DEVICE",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -16,10 +16,12 @@ class RoboflowProvider(BaseEmbeddingsProvider[RoboflowEmbeddingFunction]):
|
|||||||
description="Roboflow embedding function class",
|
description="Roboflow embedding function class",
|
||||||
)
|
)
|
||||||
api_key: str = Field(
|
api_key: str = Field(
|
||||||
default="", description="Roboflow API key", validation_alias="ROBOFLOW_API_KEY"
|
default="",
|
||||||
|
description="Roboflow API key",
|
||||||
|
validation_alias="EMBEDDINGS_ROBOFLOW_API_KEY",
|
||||||
)
|
)
|
||||||
api_url: str = Field(
|
api_url: str = Field(
|
||||||
default="https://infer.roboflow.com",
|
default="https://infer.roboflow.com",
|
||||||
description="Roboflow API URL",
|
description="Roboflow API URL",
|
||||||
validation_alias="ROBOFLOW_API_URL",
|
validation_alias="EMBEDDINGS_ROBOFLOW_API_URL",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -20,15 +20,15 @@ class SentenceTransformerProvider(
|
|||||||
model_name: str = Field(
|
model_name: str = Field(
|
||||||
default="all-MiniLM-L6-v2",
|
default="all-MiniLM-L6-v2",
|
||||||
description="Model name to use",
|
description="Model name to use",
|
||||||
validation_alias="SENTENCE_TRANSFORMER_MODEL_NAME",
|
validation_alias="EMBEDDINGS_SENTENCE_TRANSFORMER_MODEL_NAME",
|
||||||
)
|
)
|
||||||
device: str = Field(
|
device: str = Field(
|
||||||
default="cpu",
|
default="cpu",
|
||||||
description="Device to run model on (cpu or cuda)",
|
description="Device to run model on (cpu or cuda)",
|
||||||
validation_alias="SENTENCE_TRANSFORMER_DEVICE",
|
validation_alias="EMBEDDINGS_SENTENCE_TRANSFORMER_DEVICE",
|
||||||
)
|
)
|
||||||
normalize_embeddings: bool = Field(
|
normalize_embeddings: bool = Field(
|
||||||
default=False,
|
default=False,
|
||||||
description="Whether to normalize embeddings",
|
description="Whether to normalize embeddings",
|
||||||
validation_alias="SENTENCE_TRANSFORMER_NORMALIZE_EMBEDDINGS",
|
validation_alias="EMBEDDINGS_SENTENCE_TRANSFORMER_NORMALIZE_EMBEDDINGS",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -18,5 +18,5 @@ class Text2VecProvider(BaseEmbeddingsProvider[Text2VecEmbeddingFunction]):
|
|||||||
model_name: str = Field(
|
model_name: str = Field(
|
||||||
default="shibing624/text2vec-base-chinese",
|
default="shibing624/text2vec-base-chinese",
|
||||||
description="Model name to use",
|
description="Model name to use",
|
||||||
validation_alias="TEXT2VEC_MODEL_NAME",
|
validation_alias="EMBEDDINGS_TEXT2VEC_MODEL_NAME",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -18,38 +18,38 @@ class VoyageAIProvider(BaseEmbeddingsProvider[VoyageAIEmbeddingFunction]):
|
|||||||
model: str = Field(
|
model: str = Field(
|
||||||
default="voyage-2",
|
default="voyage-2",
|
||||||
description="Model to use for embeddings",
|
description="Model to use for embeddings",
|
||||||
validation_alias="VOYAGEAI_MODEL",
|
validation_alias="EMBEDDINGS_VOYAGEAI_MODEL",
|
||||||
)
|
)
|
||||||
api_key: str = Field(
|
api_key: str = Field(
|
||||||
description="Voyage AI API key", validation_alias="VOYAGEAI_API_KEY"
|
description="Voyage AI API key", validation_alias="EMBEDDINGS_VOYAGEAI_API_KEY"
|
||||||
)
|
)
|
||||||
input_type: str | None = Field(
|
input_type: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Input type for embeddings",
|
description="Input type for embeddings",
|
||||||
validation_alias="VOYAGEAI_INPUT_TYPE",
|
validation_alias="EMBEDDINGS_VOYAGEAI_INPUT_TYPE",
|
||||||
)
|
)
|
||||||
truncation: bool = Field(
|
truncation: bool = Field(
|
||||||
default=True,
|
default=True,
|
||||||
description="Whether to truncate inputs",
|
description="Whether to truncate inputs",
|
||||||
validation_alias="VOYAGEAI_TRUNCATION",
|
validation_alias="EMBEDDINGS_VOYAGEAI_TRUNCATION",
|
||||||
)
|
)
|
||||||
output_dtype: str | None = Field(
|
output_dtype: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Output data type",
|
description="Output data type",
|
||||||
validation_alias="VOYAGEAI_OUTPUT_DTYPE",
|
validation_alias="EMBEDDINGS_VOYAGEAI_OUTPUT_DTYPE",
|
||||||
)
|
)
|
||||||
output_dimension: int | None = Field(
|
output_dimension: int | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Output dimension",
|
description="Output dimension",
|
||||||
validation_alias="VOYAGEAI_OUTPUT_DIMENSION",
|
validation_alias="EMBEDDINGS_VOYAGEAI_OUTPUT_DIMENSION",
|
||||||
)
|
)
|
||||||
max_retries: int = Field(
|
max_retries: int = Field(
|
||||||
default=0,
|
default=0,
|
||||||
description="Maximum retries for API calls",
|
description="Maximum retries for API calls",
|
||||||
validation_alias="VOYAGEAI_MAX_RETRIES",
|
validation_alias="EMBEDDINGS_VOYAGEAI_MAX_RETRIES",
|
||||||
)
|
)
|
||||||
timeout: float | None = Field(
|
timeout: float | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Timeout for API calls",
|
description="Timeout for API calls",
|
||||||
validation_alias="VOYAGEAI_TIMEOUT",
|
validation_alias="EMBEDDINGS_VOYAGEAI_TIMEOUT",
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user