mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-02 15:52:34 +00:00
fix: rag tool embeddings config
* fix: ensure config is not flattened, add tests * chore: refactor inits to model_validator * chore: refactor rag tool config parsing * chore: add initial docs * chore: add additional validation aliases for provider env vars * chore: add solid docs * chore: move imports to top * fix: revert circular import * fix: lazy import qdrant-client * fix: allow collection name config * chore: narrow model names for google * chore: update additional docs * chore: add backward compat on model name aliases * chore: add tests for config changes
This commit is contained in:
@@ -91,6 +91,7 @@ PROVIDER_PATHS = {
|
||||
"cohere": "crewai.rag.embeddings.providers.cohere.cohere_provider.CohereProvider",
|
||||
"custom": "crewai.rag.embeddings.providers.custom.custom_provider.CustomProvider",
|
||||
"google-generativeai": "crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider",
|
||||
"google": "crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider",
|
||||
"google-vertex": "crewai.rag.embeddings.providers.google.vertex.VertexAIProvider",
|
||||
"huggingface": "crewai.rag.embeddings.providers.huggingface.huggingface_provider.HuggingFaceProvider",
|
||||
"instructor": "crewai.rag.embeddings.providers.instructor.instructor_provider.InstructorProvider",
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Any
|
||||
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
|
||||
AmazonBedrockEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -21,7 +21,7 @@ def create_aws_session() -> Any:
|
||||
ValueError: If AWS session creation fails
|
||||
"""
|
||||
try:
|
||||
import boto3 # type: ignore[import]
|
||||
import boto3
|
||||
|
||||
return boto3.Session()
|
||||
except ImportError as e:
|
||||
@@ -46,7 +46,12 @@ class BedrockProvider(BaseEmbeddingsProvider[AmazonBedrockEmbeddingFunction]):
|
||||
model_name: str = Field(
|
||||
default="amazon.titan-embed-text-v1",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias="EMBEDDINGS_BEDROCK_MODEL_NAME",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_BEDROCK_MODEL_NAME",
|
||||
"BEDROCK_MODEL_NAME",
|
||||
"AWS_BEDROCK_MODEL_NAME",
|
||||
"model",
|
||||
),
|
||||
)
|
||||
session: Any = Field(
|
||||
default_factory=create_aws_session, description="AWS session object"
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from chromadb.utils.embedding_functions.cohere_embedding_function import (
|
||||
CohereEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -15,10 +15,14 @@ class CohereProvider(BaseEmbeddingsProvider[CohereEmbeddingFunction]):
|
||||
default=CohereEmbeddingFunction, description="Cohere embedding function class"
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="Cohere API key", validation_alias="EMBEDDINGS_COHERE_API_KEY"
|
||||
description="Cohere API key",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_COHERE_API_KEY", "COHERE_API_KEY"),
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="large",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias="EMBEDDINGS_COHERE_MODEL_NAME",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_COHERE_MODEL_NAME",
|
||||
"model",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
"""Google Generative AI embeddings provider."""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||
GoogleGenerativeAiEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -15,16 +17,27 @@ class GenerativeAiProvider(BaseEmbeddingsProvider[GoogleGenerativeAiEmbeddingFun
|
||||
default=GoogleGenerativeAiEmbeddingFunction,
|
||||
description="Google Generative AI embedding function class",
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="models/embedding-001",
|
||||
model_name: Literal[
|
||||
"gemini-embedding-001", "text-embedding-005", "text-multilingual-embedding-002"
|
||||
] = Field(
|
||||
default="gemini-embedding-001",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias="EMBEDDINGS_GOOGLE_GENERATIVE_AI_MODEL_NAME",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_GENERATIVE_AI_MODEL_NAME", "model"
|
||||
),
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="Google API key", validation_alias="EMBEDDINGS_GOOGLE_API_KEY"
|
||||
description="Google API key",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_API_KEY", "GOOGLE_API_KEY", "GEMINI_API_KEY"
|
||||
),
|
||||
)
|
||||
task_type: str = Field(
|
||||
default="RETRIEVAL_DOCUMENT",
|
||||
description="Task type for embeddings",
|
||||
validation_alias="EMBEDDINGS_GOOGLE_GENERATIVE_AI_TASK_TYPE",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_GENERATIVE_AI_TASK_TYPE",
|
||||
"GOOGLE_GENERATIVE_AI_TASK_TYPE",
|
||||
"GEMINI_TASK_TYPE",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -6,10 +6,23 @@ from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class GenerativeAiProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for Google Generative AI provider."""
|
||||
"""Configuration for Google Generative AI provider.
|
||||
|
||||
Attributes:
|
||||
api_key: Google API key for authentication.
|
||||
model_name: Embedding model name.
|
||||
task_type: Task type for embeddings. Default is "RETRIEVAL_DOCUMENT".
|
||||
"""
|
||||
|
||||
api_key: str
|
||||
model_name: Annotated[str, "models/embedding-001"]
|
||||
model_name: Annotated[
|
||||
Literal[
|
||||
"gemini-embedding-001",
|
||||
"text-embedding-005",
|
||||
"text-multilingual-embedding-002",
|
||||
],
|
||||
"gemini-embedding-001",
|
||||
]
|
||||
task_type: Annotated[str, "RETRIEVAL_DOCUMENT"]
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||
GoogleVertexEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -18,18 +18,29 @@ class VertexAIProvider(BaseEmbeddingsProvider[GoogleVertexEmbeddingFunction]):
|
||||
model_name: str = Field(
|
||||
default="textembedding-gecko",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias="EMBEDDINGS_GOOGLE_VERTEX_MODEL_NAME",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_VERTEX_MODEL_NAME",
|
||||
"GOOGLE_VERTEX_MODEL_NAME",
|
||||
"model",
|
||||
),
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="Google API key", validation_alias="EMBEDDINGS_GOOGLE_CLOUD_API_KEY"
|
||||
description="Google API key",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_CLOUD_API_KEY", "GOOGLE_CLOUD_API_KEY"
|
||||
),
|
||||
)
|
||||
project_id: str = Field(
|
||||
default="cloud-large-language-models",
|
||||
description="GCP project ID",
|
||||
validation_alias="EMBEDDINGS_GOOGLE_CLOUD_PROJECT",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_CLOUD_PROJECT", "GOOGLE_CLOUD_PROJECT"
|
||||
),
|
||||
)
|
||||
region: str = Field(
|
||||
default="us-central1",
|
||||
description="GCP region",
|
||||
validation_alias="EMBEDDINGS_GOOGLE_CLOUD_REGION",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_CLOUD_REGION", "GOOGLE_CLOUD_REGION"
|
||||
),
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
|
||||
HuggingFaceEmbeddingServer,
|
||||
)
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -16,5 +16,6 @@ class HuggingFaceProvider(BaseEmbeddingsProvider[HuggingFaceEmbeddingServer]):
|
||||
description="HuggingFace embedding function class",
|
||||
)
|
||||
url: str = Field(
|
||||
description="HuggingFace API URL", validation_alias="EMBEDDINGS_HUGGINGFACE_URL"
|
||||
description="HuggingFace API URL",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_HUGGINGFACE_URL", "HUGGINGFACE_URL"),
|
||||
)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, model_validator
|
||||
from pydantic import AliasChoices, Field, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
@@ -21,7 +21,10 @@ class WatsonXProvider(BaseEmbeddingsProvider[WatsonXEmbeddingFunction]):
|
||||
default=WatsonXEmbeddingFunction, description="WatsonX embedding function class"
|
||||
)
|
||||
model_id: str = Field(
|
||||
description="WatsonX model ID", validation_alias="EMBEDDINGS_WATSONX_MODEL_ID"
|
||||
description="WatsonX model ID",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_MODEL_ID", "WATSONX_MODEL_ID"
|
||||
),
|
||||
)
|
||||
params: dict[str, str | dict[str, str]] | None = Field(
|
||||
default=None, description="Additional parameters"
|
||||
@@ -30,109 +33,143 @@ class WatsonXProvider(BaseEmbeddingsProvider[WatsonXEmbeddingFunction]):
|
||||
project_id: str | None = Field(
|
||||
default=None,
|
||||
description="WatsonX project ID",
|
||||
validation_alias="EMBEDDINGS_WATSONX_PROJECT_ID",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_PROJECT_ID", "WATSONX_PROJECT_ID"
|
||||
),
|
||||
)
|
||||
space_id: str | None = Field(
|
||||
default=None,
|
||||
description="WatsonX space ID",
|
||||
validation_alias="EMBEDDINGS_WATSONX_SPACE_ID",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_SPACE_ID", "WATSONX_SPACE_ID"
|
||||
),
|
||||
)
|
||||
api_client: Any | None = Field(default=None, description="WatsonX API client")
|
||||
verify: bool | str | None = Field(
|
||||
default=None,
|
||||
description="SSL verification",
|
||||
validation_alias="EMBEDDINGS_WATSONX_VERIFY",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_WATSONX_VERIFY", "WATSONX_VERIFY"),
|
||||
)
|
||||
persistent_connection: bool = Field(
|
||||
default=True,
|
||||
description="Use persistent connection",
|
||||
validation_alias="EMBEDDINGS_WATSONX_PERSISTENT_CONNECTION",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_PERSISTENT_CONNECTION", "WATSONX_PERSISTENT_CONNECTION"
|
||||
),
|
||||
)
|
||||
batch_size: int = Field(
|
||||
default=100,
|
||||
description="Batch size for processing",
|
||||
validation_alias="EMBEDDINGS_WATSONX_BATCH_SIZE",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_BATCH_SIZE", "WATSONX_BATCH_SIZE"
|
||||
),
|
||||
)
|
||||
concurrency_limit: int = Field(
|
||||
default=10,
|
||||
description="Concurrency limit",
|
||||
validation_alias="EMBEDDINGS_WATSONX_CONCURRENCY_LIMIT",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_CONCURRENCY_LIMIT", "WATSONX_CONCURRENCY_LIMIT"
|
||||
),
|
||||
)
|
||||
max_retries: int | None = Field(
|
||||
default=None,
|
||||
description="Maximum retries",
|
||||
validation_alias="EMBEDDINGS_WATSONX_MAX_RETRIES",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_MAX_RETRIES", "WATSONX_MAX_RETRIES"
|
||||
),
|
||||
)
|
||||
delay_time: float | None = Field(
|
||||
default=None,
|
||||
description="Delay time between retries",
|
||||
validation_alias="EMBEDDINGS_WATSONX_DELAY_TIME",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_DELAY_TIME", "WATSONX_DELAY_TIME"
|
||||
),
|
||||
)
|
||||
retry_status_codes: list[int] | None = Field(
|
||||
default=None, description="HTTP status codes to retry on"
|
||||
)
|
||||
url: str = Field(
|
||||
description="WatsonX API URL", validation_alias="EMBEDDINGS_WATSONX_URL"
|
||||
description="WatsonX API URL",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_WATSONX_URL", "WATSONX_URL"),
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="WatsonX API key", validation_alias="EMBEDDINGS_WATSONX_API_KEY"
|
||||
description="WatsonX API key",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_WATSONX_API_KEY", "WATSONX_API_KEY"),
|
||||
)
|
||||
name: str | None = Field(
|
||||
default=None,
|
||||
description="Service name",
|
||||
validation_alias="EMBEDDINGS_WATSONX_NAME",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_WATSONX_NAME", "WATSONX_NAME"),
|
||||
)
|
||||
iam_serviceid_crn: str | None = Field(
|
||||
default=None,
|
||||
description="IAM service ID CRN",
|
||||
validation_alias="EMBEDDINGS_WATSONX_IAM_SERVICEID_CRN",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_IAM_SERVICEID_CRN", "WATSONX_IAM_SERVICEID_CRN"
|
||||
),
|
||||
)
|
||||
trusted_profile_id: str | None = Field(
|
||||
default=None,
|
||||
description="Trusted profile ID",
|
||||
validation_alias="EMBEDDINGS_WATSONX_TRUSTED_PROFILE_ID",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_TRUSTED_PROFILE_ID", "WATSONX_TRUSTED_PROFILE_ID"
|
||||
),
|
||||
)
|
||||
token: str | None = Field(
|
||||
default=None,
|
||||
description="Bearer token",
|
||||
validation_alias="EMBEDDINGS_WATSONX_TOKEN",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_WATSONX_TOKEN", "WATSONX_TOKEN"),
|
||||
)
|
||||
projects_token: str | None = Field(
|
||||
default=None,
|
||||
description="Projects token",
|
||||
validation_alias="EMBEDDINGS_WATSONX_PROJECTS_TOKEN",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_PROJECTS_TOKEN", "WATSONX_PROJECTS_TOKEN"
|
||||
),
|
||||
)
|
||||
username: str | None = Field(
|
||||
default=None,
|
||||
description="Username",
|
||||
validation_alias="EMBEDDINGS_WATSONX_USERNAME",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_USERNAME", "WATSONX_USERNAME"
|
||||
),
|
||||
)
|
||||
password: str | None = Field(
|
||||
default=None,
|
||||
description="Password",
|
||||
validation_alias="EMBEDDINGS_WATSONX_PASSWORD",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_PASSWORD", "WATSONX_PASSWORD"
|
||||
),
|
||||
)
|
||||
instance_id: str | None = Field(
|
||||
default=None,
|
||||
description="Service instance ID",
|
||||
validation_alias="EMBEDDINGS_WATSONX_INSTANCE_ID",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_INSTANCE_ID", "WATSONX_INSTANCE_ID"
|
||||
),
|
||||
)
|
||||
version: str | None = Field(
|
||||
default=None,
|
||||
description="API version",
|
||||
validation_alias="EMBEDDINGS_WATSONX_VERSION",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_WATSONX_VERSION", "WATSONX_VERSION"),
|
||||
)
|
||||
bedrock_url: str | None = Field(
|
||||
default=None,
|
||||
description="Bedrock URL",
|
||||
validation_alias="EMBEDDINGS_WATSONX_BEDROCK_URL",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_BEDROCK_URL", "WATSONX_BEDROCK_URL"
|
||||
),
|
||||
)
|
||||
platform_url: str | None = Field(
|
||||
default=None,
|
||||
description="Platform URL",
|
||||
validation_alias="EMBEDDINGS_WATSONX_PLATFORM_URL",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_PLATFORM_URL", "WATSONX_PLATFORM_URL"
|
||||
),
|
||||
)
|
||||
proxies: dict[str, Any] | None = Field(
|
||||
default=None, description="Proxy configuration"
|
||||
)
|
||||
proxies: dict | None = Field(default=None, description="Proxy configuration")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_space_or_project(self) -> Self:
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from chromadb.utils.embedding_functions.instructor_embedding_function import (
|
||||
InstructorEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -18,15 +18,23 @@ class InstructorProvider(BaseEmbeddingsProvider[InstructorEmbeddingFunction]):
|
||||
model_name: str = Field(
|
||||
default="hkunlp/instructor-base",
|
||||
description="Model name to use",
|
||||
validation_alias="EMBEDDINGS_INSTRUCTOR_MODEL_NAME",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_INSTRUCTOR_MODEL_NAME",
|
||||
"INSTRUCTOR_MODEL_NAME",
|
||||
"model",
|
||||
),
|
||||
)
|
||||
device: str = Field(
|
||||
default="cpu",
|
||||
description="Device to run model on (cpu or cuda)",
|
||||
validation_alias="EMBEDDINGS_INSTRUCTOR_DEVICE",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_INSTRUCTOR_DEVICE", "INSTRUCTOR_DEVICE"
|
||||
),
|
||||
)
|
||||
instruction: str | None = Field(
|
||||
default=None,
|
||||
description="Instruction for embeddings",
|
||||
validation_alias="EMBEDDINGS_INSTRUCTOR_INSTRUCTION",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_INSTRUCTOR_INSTRUCTION", "INSTRUCTOR_INSTRUCTION"
|
||||
),
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from chromadb.utils.embedding_functions.jina_embedding_function import (
|
||||
JinaEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -15,10 +15,15 @@ class JinaProvider(BaseEmbeddingsProvider[JinaEmbeddingFunction]):
|
||||
default=JinaEmbeddingFunction, description="Jina embedding function class"
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="Jina API key", validation_alias="EMBEDDINGS_JINA_API_KEY"
|
||||
description="Jina API key",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_JINA_API_KEY", "JINA_API_KEY"),
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="jina-embeddings-v2-base-en",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias="EMBEDDINGS_JINA_MODEL_NAME",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_JINA_MODEL_NAME",
|
||||
"JINA_MODEL_NAME",
|
||||
"model",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Any
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -18,27 +18,39 @@ class AzureProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
|
||||
description="Azure OpenAI embedding function class",
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="Azure API key", validation_alias="EMBEDDINGS_OPENAI_API_KEY"
|
||||
description="Azure API key",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_OPENAI_API_KEY", "OPENAI_API_KEY"),
|
||||
)
|
||||
api_base: str | None = Field(
|
||||
default=None,
|
||||
description="Azure endpoint URL",
|
||||
validation_alias="EMBEDDINGS_OPENAI_API_BASE",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_OPENAI_API_BASE", "OPENAI_API_BASE"),
|
||||
)
|
||||
api_type: str = Field(
|
||||
default="azure",
|
||||
description="API type for Azure",
|
||||
validation_alias="EMBEDDINGS_OPENAI_API_TYPE",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENAI_API_TYPE", "OPENAI_API_TYPE", "AZURE_OPENAI_API_TYPE"
|
||||
),
|
||||
)
|
||||
api_version: str | None = Field(
|
||||
default=None,
|
||||
default="2024-02-01",
|
||||
description="Azure API version",
|
||||
validation_alias="EMBEDDINGS_OPENAI_API_VERSION",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENAI_API_VERSION",
|
||||
"OPENAI_API_VERSION",
|
||||
"AZURE_OPENAI_API_VERSION",
|
||||
),
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="text-embedding-ada-002",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias="EMBEDDINGS_OPENAI_MODEL_NAME",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENAI_MODEL_NAME",
|
||||
"OPENAI_MODEL_NAME",
|
||||
"AZURE_OPENAI_MODEL_NAME",
|
||||
"model",
|
||||
),
|
||||
)
|
||||
default_headers: dict[str, Any] | None = Field(
|
||||
default=None, description="Default headers for API requests"
|
||||
@@ -46,15 +58,26 @@ class AzureProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
|
||||
dimensions: int | None = Field(
|
||||
default=None,
|
||||
description="Embedding dimensions",
|
||||
validation_alias="EMBEDDINGS_OPENAI_DIMENSIONS",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENAI_DIMENSIONS",
|
||||
"OPENAI_DIMENSIONS",
|
||||
"AZURE_OPENAI_DIMENSIONS",
|
||||
),
|
||||
)
|
||||
deployment_id: str | None = Field(
|
||||
default=None,
|
||||
deployment_id: str = Field(
|
||||
description="Azure deployment ID",
|
||||
validation_alias="EMBEDDINGS_OPENAI_DEPLOYMENT_ID",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENAI_DEPLOYMENT_ID",
|
||||
"AZURE_OPENAI_DEPLOYMENT",
|
||||
"AZURE_DEPLOYMENT_ID",
|
||||
),
|
||||
)
|
||||
organization_id: str | None = Field(
|
||||
default=None,
|
||||
description="Organization ID",
|
||||
validation_alias="EMBEDDINGS_OPENAI_ORGANIZATION_ID",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENAI_ORGANIZATION_ID",
|
||||
"OPENAI_ORGANIZATION_ID",
|
||||
"AZURE_OPENAI_ORGANIZATION_ID",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -15,7 +15,7 @@ class AzureProviderConfig(TypedDict, total=False):
|
||||
model_name: Annotated[str, "text-embedding-ada-002"]
|
||||
default_headers: dict[str, Any]
|
||||
dimensions: int
|
||||
deployment_id: str
|
||||
deployment_id: Required[str]
|
||||
organization_id: str
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from chromadb.utils.embedding_functions.ollama_embedding_function import (
|
||||
OllamaEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -17,9 +17,14 @@ class OllamaProvider(BaseEmbeddingsProvider[OllamaEmbeddingFunction]):
|
||||
url: str = Field(
|
||||
default="http://localhost:11434/api/embeddings",
|
||||
description="Ollama API endpoint URL",
|
||||
validation_alias="EMBEDDINGS_OLLAMA_URL",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_OLLAMA_URL", "OLLAMA_URL"),
|
||||
)
|
||||
model_name: str = Field(
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias="EMBEDDINGS_OLLAMA_MODEL_NAME",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OLLAMA_MODEL_NAME",
|
||||
"OLLAMA_MODEL_NAME",
|
||||
"OLLAMA_MODEL",
|
||||
"model",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""ONNX embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ONNXMiniLM_L6_V2
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -15,5 +15,7 @@ class ONNXProvider(BaseEmbeddingsProvider[ONNXMiniLM_L6_V2]):
|
||||
preferred_providers: list[str] | None = Field(
|
||||
default=None,
|
||||
description="Preferred ONNX execution providers",
|
||||
validation_alias="EMBEDDINGS_ONNX_PREFERRED_PROVIDERS",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_ONNX_PREFERRED_PROVIDERS", "ONNX_PREFERRED_PROVIDERS"
|
||||
),
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Any
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -20,27 +20,33 @@ class OpenAIProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="OpenAI API key",
|
||||
validation_alias="EMBEDDINGS_OPENAI_API_KEY",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_OPENAI_API_KEY", "OPENAI_API_KEY"),
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="text-embedding-ada-002",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias="EMBEDDINGS_OPENAI_MODEL_NAME",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENAI_MODEL_NAME",
|
||||
"OPENAI_MODEL_NAME",
|
||||
"model",
|
||||
),
|
||||
)
|
||||
api_base: str | None = Field(
|
||||
default=None,
|
||||
description="Base URL for API requests",
|
||||
validation_alias="EMBEDDINGS_OPENAI_API_BASE",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_OPENAI_API_BASE", "OPENAI_API_BASE"),
|
||||
)
|
||||
api_type: str | None = Field(
|
||||
default=None,
|
||||
description="API type (e.g., 'azure')",
|
||||
validation_alias="EMBEDDINGS_OPENAI_API_TYPE",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_OPENAI_API_TYPE", "OPENAI_API_TYPE"),
|
||||
)
|
||||
api_version: str | None = Field(
|
||||
default=None,
|
||||
description="API version",
|
||||
validation_alias="EMBEDDINGS_OPENAI_API_VERSION",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENAI_API_VERSION", "OPENAI_API_VERSION"
|
||||
),
|
||||
)
|
||||
default_headers: dict[str, Any] | None = Field(
|
||||
default=None, description="Default headers for API requests"
|
||||
@@ -48,15 +54,21 @@ class OpenAIProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
|
||||
dimensions: int | None = Field(
|
||||
default=None,
|
||||
description="Embedding dimensions",
|
||||
validation_alias="EMBEDDINGS_OPENAI_DIMENSIONS",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENAI_DIMENSIONS", "OPENAI_DIMENSIONS"
|
||||
),
|
||||
)
|
||||
deployment_id: str | None = Field(
|
||||
default=None,
|
||||
description="Azure deployment ID",
|
||||
validation_alias="EMBEDDINGS_OPENAI_DEPLOYMENT_ID",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENAI_DEPLOYMENT_ID", "OPENAI_DEPLOYMENT_ID"
|
||||
),
|
||||
)
|
||||
organization_id: str | None = Field(
|
||||
default=None,
|
||||
description="OpenAI organization ID",
|
||||
validation_alias="EMBEDDINGS_OPENAI_ORGANIZATION_ID",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENAI_ORGANIZATION_ID", "OPENAI_ORGANIZATION_ID"
|
||||
),
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from chromadb.utils.embedding_functions.open_clip_embedding_function import (
|
||||
OpenCLIPEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -18,15 +18,21 @@ class OpenCLIPProvider(BaseEmbeddingsProvider[OpenCLIPEmbeddingFunction]):
|
||||
model_name: str = Field(
|
||||
default="ViT-B-32",
|
||||
description="Model name to use",
|
||||
validation_alias="EMBEDDINGS_OPENCLIP_MODEL_NAME",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENCLIP_MODEL_NAME",
|
||||
"OPENCLIP_MODEL_NAME",
|
||||
"model",
|
||||
),
|
||||
)
|
||||
checkpoint: str = Field(
|
||||
default="laion2b_s34b_b79k",
|
||||
description="Model checkpoint",
|
||||
validation_alias="EMBEDDINGS_OPENCLIP_CHECKPOINT",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENCLIP_CHECKPOINT", "OPENCLIP_CHECKPOINT"
|
||||
),
|
||||
)
|
||||
device: str | None = Field(
|
||||
default="cpu",
|
||||
description="Device to run model on",
|
||||
validation_alias="EMBEDDINGS_OPENCLIP_DEVICE",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_OPENCLIP_DEVICE", "OPENCLIP_DEVICE"),
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from chromadb.utils.embedding_functions.roboflow_embedding_function import (
|
||||
RoboflowEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -18,10 +18,14 @@ class RoboflowProvider(BaseEmbeddingsProvider[RoboflowEmbeddingFunction]):
|
||||
api_key: str = Field(
|
||||
default="",
|
||||
description="Roboflow API key",
|
||||
validation_alias="EMBEDDINGS_ROBOFLOW_API_KEY",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_ROBOFLOW_API_KEY", "ROBOFLOW_API_KEY"
|
||||
),
|
||||
)
|
||||
api_url: str = Field(
|
||||
default="https://infer.roboflow.com",
|
||||
description="Roboflow API URL",
|
||||
validation_alias="EMBEDDINGS_ROBOFLOW_API_URL",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_ROBOFLOW_API_URL", "ROBOFLOW_API_URL"
|
||||
),
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import (
|
||||
SentenceTransformerEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -20,15 +20,24 @@ class SentenceTransformerProvider(
|
||||
model_name: str = Field(
|
||||
default="all-MiniLM-L6-v2",
|
||||
description="Model name to use",
|
||||
validation_alias="EMBEDDINGS_SENTENCE_TRANSFORMER_MODEL_NAME",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_SENTENCE_TRANSFORMER_MODEL_NAME",
|
||||
"SENTENCE_TRANSFORMER_MODEL_NAME",
|
||||
"model",
|
||||
),
|
||||
)
|
||||
device: str = Field(
|
||||
default="cpu",
|
||||
description="Device to run model on (cpu or cuda)",
|
||||
validation_alias="EMBEDDINGS_SENTENCE_TRANSFORMER_DEVICE",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_SENTENCE_TRANSFORMER_DEVICE", "SENTENCE_TRANSFORMER_DEVICE"
|
||||
),
|
||||
)
|
||||
normalize_embeddings: bool = Field(
|
||||
default=False,
|
||||
description="Whether to normalize embeddings",
|
||||
validation_alias="EMBEDDINGS_SENTENCE_TRANSFORMER_NORMALIZE_EMBEDDINGS",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_SENTENCE_TRANSFORMER_NORMALIZE_EMBEDDINGS",
|
||||
"SENTENCE_TRANSFORMER_NORMALIZE_EMBEDDINGS",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from chromadb.utils.embedding_functions.text2vec_embedding_function import (
|
||||
Text2VecEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -18,5 +18,9 @@ class Text2VecProvider(BaseEmbeddingsProvider[Text2VecEmbeddingFunction]):
|
||||
model_name: str = Field(
|
||||
default="shibing624/text2vec-base-chinese",
|
||||
description="Model name to use",
|
||||
validation_alias="EMBEDDINGS_TEXT2VEC_MODEL_NAME",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_TEXT2VEC_MODEL_NAME",
|
||||
"TEXT2VEC_MODEL_NAME",
|
||||
"model",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Voyage AI embeddings provider."""
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.rag.embeddings.providers.voyageai.embedding_callable import (
|
||||
@@ -18,38 +18,53 @@ class VoyageAIProvider(BaseEmbeddingsProvider[VoyageAIEmbeddingFunction]):
|
||||
model: str = Field(
|
||||
default="voyage-2",
|
||||
description="Model to use for embeddings",
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_MODEL",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_VOYAGEAI_MODEL", "VOYAGEAI_MODEL"),
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="Voyage AI API key", validation_alias="EMBEDDINGS_VOYAGEAI_API_KEY"
|
||||
description="Voyage AI API key",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_VOYAGEAI_API_KEY", "VOYAGEAI_API_KEY"
|
||||
),
|
||||
)
|
||||
input_type: str | None = Field(
|
||||
default=None,
|
||||
description="Input type for embeddings",
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_INPUT_TYPE",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_VOYAGEAI_INPUT_TYPE", "VOYAGEAI_INPUT_TYPE"
|
||||
),
|
||||
)
|
||||
truncation: bool = Field(
|
||||
default=True,
|
||||
description="Whether to truncate inputs",
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_TRUNCATION",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_VOYAGEAI_TRUNCATION", "VOYAGEAI_TRUNCATION"
|
||||
),
|
||||
)
|
||||
output_dtype: str | None = Field(
|
||||
default=None,
|
||||
description="Output data type",
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_OUTPUT_DTYPE",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_VOYAGEAI_OUTPUT_DTYPE", "VOYAGEAI_OUTPUT_DTYPE"
|
||||
),
|
||||
)
|
||||
output_dimension: int | None = Field(
|
||||
default=None,
|
||||
description="Output dimension",
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_OUTPUT_DIMENSION",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_VOYAGEAI_OUTPUT_DIMENSION", "VOYAGEAI_OUTPUT_DIMENSION"
|
||||
),
|
||||
)
|
||||
max_retries: int = Field(
|
||||
default=0,
|
||||
description="Maximum retries for API calls",
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_MAX_RETRIES",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_VOYAGEAI_MAX_RETRIES", "VOYAGEAI_MAX_RETRIES"
|
||||
),
|
||||
)
|
||||
timeout: float | None = Field(
|
||||
default=None,
|
||||
description="Timeout for API calls",
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_TIMEOUT",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_VOYAGEAI_TIMEOUT", "VOYAGEAI_TIMEOUT"
|
||||
),
|
||||
)
|
||||
|
||||
@@ -29,7 +29,7 @@ from crewai.rag.embeddings.providers.text2vec.types import Text2VecProviderSpec
|
||||
from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderSpec
|
||||
|
||||
|
||||
ProviderSpec = (
|
||||
ProviderSpec: TypeAlias = (
|
||||
AzureProviderSpec
|
||||
| BedrockProviderSpec
|
||||
| CohereProviderSpec
|
||||
|
||||
@@ -1,16 +1,23 @@
|
||||
"""Qdrant configuration model."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import field
|
||||
from typing import Literal, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
from pydantic.dataclasses import dataclass as pyd_dataclass
|
||||
from qdrant_client.models import VectorParams
|
||||
|
||||
from crewai.rag.config.base import BaseRagConfig
|
||||
from crewai.rag.qdrant.constants import DEFAULT_EMBEDDING_MODEL, DEFAULT_STORAGE_PATH
|
||||
from crewai.rag.qdrant.types import QdrantClientParams, QdrantEmbeddingFunctionWrapper
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qdrant_client.models import VectorParams
|
||||
else:
|
||||
VectorParams = Any
|
||||
|
||||
|
||||
def _default_options() -> QdrantClientParams:
|
||||
"""Create default Qdrant client options.
|
||||
|
||||
@@ -26,7 +33,7 @@ def _default_embedding_function() -> QdrantEmbeddingFunctionWrapper:
|
||||
Returns:
|
||||
Default embedding function using fastembed with all-MiniLM-L6-v2.
|
||||
"""
|
||||
from fastembed import TextEmbedding # type: ignore[import-not-found]
|
||||
from fastembed import TextEmbedding
|
||||
|
||||
model = TextEmbedding(model_name=DEFAULT_EMBEDDING_MODEL)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user