fix: rag tool embeddings config

* fix: ensure config is not flattened, add tests

* chore: refactor inits to model_validator

* chore: refactor rag tool config parsing

* chore: add initial docs

* chore: add additional validation aliases for provider env vars

* chore: add solid docs

* chore: move imports to top

* fix: revert circular import

* fix: lazy import qdrant-client

* fix: allow collection name config

* chore: narrow model names for google

* chore: update additional docs

* chore: add backward compat on model name aliases

* chore: add tests for config changes
This commit is contained in:
Greyson LaLonde
2025-11-24 16:51:28 -05:00
committed by GitHub
parent 9c84475691
commit a928cde6ee
46 changed files with 1850 additions and 291 deletions

View File

@@ -91,6 +91,7 @@ PROVIDER_PATHS = {
"cohere": "crewai.rag.embeddings.providers.cohere.cohere_provider.CohereProvider",
"custom": "crewai.rag.embeddings.providers.custom.custom_provider.CustomProvider",
"google-generativeai": "crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider",
"google": "crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider",
"google-vertex": "crewai.rag.embeddings.providers.google.vertex.VertexAIProvider",
"huggingface": "crewai.rag.embeddings.providers.huggingface.huggingface_provider.HuggingFaceProvider",
"instructor": "crewai.rag.embeddings.providers.instructor.instructor_provider.InstructorProvider",

View File

@@ -5,7 +5,7 @@ from typing import Any
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
AmazonBedrockEmbeddingFunction,
)
from pydantic import Field
from pydantic import AliasChoices, Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
@@ -21,7 +21,7 @@ def create_aws_session() -> Any:
ValueError: If AWS session creation fails
"""
try:
import boto3 # type: ignore[import]
import boto3
return boto3.Session()
except ImportError as e:
@@ -46,7 +46,12 @@ class BedrockProvider(BaseEmbeddingsProvider[AmazonBedrockEmbeddingFunction]):
model_name: str = Field(
default="amazon.titan-embed-text-v1",
description="Model name to use for embeddings",
validation_alias="EMBEDDINGS_BEDROCK_MODEL_NAME",
validation_alias=AliasChoices(
"EMBEDDINGS_BEDROCK_MODEL_NAME",
"BEDROCK_MODEL_NAME",
"AWS_BEDROCK_MODEL_NAME",
"model",
),
)
session: Any = Field(
default_factory=create_aws_session, description="AWS session object"

View File

@@ -3,7 +3,7 @@
from chromadb.utils.embedding_functions.cohere_embedding_function import (
CohereEmbeddingFunction,
)
from pydantic import Field
from pydantic import AliasChoices, Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
@@ -15,10 +15,14 @@ class CohereProvider(BaseEmbeddingsProvider[CohereEmbeddingFunction]):
default=CohereEmbeddingFunction, description="Cohere embedding function class"
)
api_key: str = Field(
description="Cohere API key", validation_alias="EMBEDDINGS_COHERE_API_KEY"
description="Cohere API key",
validation_alias=AliasChoices("EMBEDDINGS_COHERE_API_KEY", "COHERE_API_KEY"),
)
model_name: str = Field(
default="large",
description="Model name to use for embeddings",
validation_alias="EMBEDDINGS_COHERE_MODEL_NAME",
validation_alias=AliasChoices(
"EMBEDDINGS_COHERE_MODEL_NAME",
"model",
),
)

View File

@@ -1,9 +1,11 @@
"""Google Generative AI embeddings provider."""
from typing import Literal
from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleGenerativeAiEmbeddingFunction,
)
from pydantic import Field
from pydantic import AliasChoices, Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
@@ -15,16 +17,27 @@ class GenerativeAiProvider(BaseEmbeddingsProvider[GoogleGenerativeAiEmbeddingFun
default=GoogleGenerativeAiEmbeddingFunction,
description="Google Generative AI embedding function class",
)
model_name: str = Field(
default="models/embedding-001",
model_name: Literal[
"gemini-embedding-001", "text-embedding-005", "text-multilingual-embedding-002"
] = Field(
default="gemini-embedding-001",
description="Model name to use for embeddings",
validation_alias="EMBEDDINGS_GOOGLE_GENERATIVE_AI_MODEL_NAME",
validation_alias=AliasChoices(
"EMBEDDINGS_GOOGLE_GENERATIVE_AI_MODEL_NAME", "model"
),
)
api_key: str = Field(
description="Google API key", validation_alias="EMBEDDINGS_GOOGLE_API_KEY"
description="Google API key",
validation_alias=AliasChoices(
"EMBEDDINGS_GOOGLE_API_KEY", "GOOGLE_API_KEY", "GEMINI_API_KEY"
),
)
task_type: str = Field(
default="RETRIEVAL_DOCUMENT",
description="Task type for embeddings",
validation_alias="EMBEDDINGS_GOOGLE_GENERATIVE_AI_TASK_TYPE",
validation_alias=AliasChoices(
"EMBEDDINGS_GOOGLE_GENERATIVE_AI_TASK_TYPE",
"GOOGLE_GENERATIVE_AI_TASK_TYPE",
"GEMINI_TASK_TYPE",
),
)

View File

@@ -6,10 +6,23 @@ from typing_extensions import Required, TypedDict
class GenerativeAiProviderConfig(TypedDict, total=False):
"""Configuration for Google Generative AI provider."""
"""Configuration for Google Generative AI provider.
Attributes:
api_key: Google API key for authentication.
model_name: Embedding model name.
task_type: Task type for embeddings. Default is "RETRIEVAL_DOCUMENT".
"""
api_key: str
model_name: Annotated[str, "models/embedding-001"]
model_name: Annotated[
Literal[
"gemini-embedding-001",
"text-embedding-005",
"text-multilingual-embedding-002",
],
"gemini-embedding-001",
]
task_type: Annotated[str, "RETRIEVAL_DOCUMENT"]

View File

@@ -3,7 +3,7 @@
from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleVertexEmbeddingFunction,
)
from pydantic import Field
from pydantic import AliasChoices, Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
@@ -18,18 +18,29 @@ class VertexAIProvider(BaseEmbeddingsProvider[GoogleVertexEmbeddingFunction]):
model_name: str = Field(
default="textembedding-gecko",
description="Model name to use for embeddings",
validation_alias="EMBEDDINGS_GOOGLE_VERTEX_MODEL_NAME",
validation_alias=AliasChoices(
"EMBEDDINGS_GOOGLE_VERTEX_MODEL_NAME",
"GOOGLE_VERTEX_MODEL_NAME",
"model",
),
)
api_key: str = Field(
description="Google API key", validation_alias="EMBEDDINGS_GOOGLE_CLOUD_API_KEY"
description="Google API key",
validation_alias=AliasChoices(
"EMBEDDINGS_GOOGLE_CLOUD_API_KEY", "GOOGLE_CLOUD_API_KEY"
),
)
project_id: str = Field(
default="cloud-large-language-models",
description="GCP project ID",
validation_alias="EMBEDDINGS_GOOGLE_CLOUD_PROJECT",
validation_alias=AliasChoices(
"EMBEDDINGS_GOOGLE_CLOUD_PROJECT", "GOOGLE_CLOUD_PROJECT"
),
)
region: str = Field(
default="us-central1",
description="GCP region",
validation_alias="EMBEDDINGS_GOOGLE_CLOUD_REGION",
validation_alias=AliasChoices(
"EMBEDDINGS_GOOGLE_CLOUD_REGION", "GOOGLE_CLOUD_REGION"
),
)

View File

@@ -3,7 +3,7 @@
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
HuggingFaceEmbeddingServer,
)
from pydantic import Field
from pydantic import AliasChoices, Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
@@ -16,5 +16,6 @@ class HuggingFaceProvider(BaseEmbeddingsProvider[HuggingFaceEmbeddingServer]):
description="HuggingFace embedding function class",
)
url: str = Field(
description="HuggingFace API URL", validation_alias="EMBEDDINGS_HUGGINGFACE_URL"
description="HuggingFace API URL",
validation_alias=AliasChoices("EMBEDDINGS_HUGGINGFACE_URL", "HUGGINGFACE_URL"),
)

View File

@@ -2,7 +2,7 @@
from typing import Any
from pydantic import Field, model_validator
from pydantic import AliasChoices, Field, model_validator
from typing_extensions import Self
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
@@ -21,7 +21,10 @@ class WatsonXProvider(BaseEmbeddingsProvider[WatsonXEmbeddingFunction]):
default=WatsonXEmbeddingFunction, description="WatsonX embedding function class"
)
model_id: str = Field(
description="WatsonX model ID", validation_alias="EMBEDDINGS_WATSONX_MODEL_ID"
description="WatsonX model ID",
validation_alias=AliasChoices(
"EMBEDDINGS_WATSONX_MODEL_ID", "WATSONX_MODEL_ID"
),
)
params: dict[str, str | dict[str, str]] | None = Field(
default=None, description="Additional parameters"
@@ -30,109 +33,143 @@ class WatsonXProvider(BaseEmbeddingsProvider[WatsonXEmbeddingFunction]):
project_id: str | None = Field(
default=None,
description="WatsonX project ID",
validation_alias="EMBEDDINGS_WATSONX_PROJECT_ID",
validation_alias=AliasChoices(
"EMBEDDINGS_WATSONX_PROJECT_ID", "WATSONX_PROJECT_ID"
),
)
space_id: str | None = Field(
default=None,
description="WatsonX space ID",
validation_alias="EMBEDDINGS_WATSONX_SPACE_ID",
validation_alias=AliasChoices(
"EMBEDDINGS_WATSONX_SPACE_ID", "WATSONX_SPACE_ID"
),
)
api_client: Any | None = Field(default=None, description="WatsonX API client")
verify: bool | str | None = Field(
default=None,
description="SSL verification",
validation_alias="EMBEDDINGS_WATSONX_VERIFY",
validation_alias=AliasChoices("EMBEDDINGS_WATSONX_VERIFY", "WATSONX_VERIFY"),
)
persistent_connection: bool = Field(
default=True,
description="Use persistent connection",
validation_alias="EMBEDDINGS_WATSONX_PERSISTENT_CONNECTION",
validation_alias=AliasChoices(
"EMBEDDINGS_WATSONX_PERSISTENT_CONNECTION", "WATSONX_PERSISTENT_CONNECTION"
),
)
batch_size: int = Field(
default=100,
description="Batch size for processing",
validation_alias="EMBEDDINGS_WATSONX_BATCH_SIZE",
validation_alias=AliasChoices(
"EMBEDDINGS_WATSONX_BATCH_SIZE", "WATSONX_BATCH_SIZE"
),
)
concurrency_limit: int = Field(
default=10,
description="Concurrency limit",
validation_alias="EMBEDDINGS_WATSONX_CONCURRENCY_LIMIT",
validation_alias=AliasChoices(
"EMBEDDINGS_WATSONX_CONCURRENCY_LIMIT", "WATSONX_CONCURRENCY_LIMIT"
),
)
max_retries: int | None = Field(
default=None,
description="Maximum retries",
validation_alias="EMBEDDINGS_WATSONX_MAX_RETRIES",
validation_alias=AliasChoices(
"EMBEDDINGS_WATSONX_MAX_RETRIES", "WATSONX_MAX_RETRIES"
),
)
delay_time: float | None = Field(
default=None,
description="Delay time between retries",
validation_alias="EMBEDDINGS_WATSONX_DELAY_TIME",
validation_alias=AliasChoices(
"EMBEDDINGS_WATSONX_DELAY_TIME", "WATSONX_DELAY_TIME"
),
)
retry_status_codes: list[int] | None = Field(
default=None, description="HTTP status codes to retry on"
)
url: str = Field(
description="WatsonX API URL", validation_alias="EMBEDDINGS_WATSONX_URL"
description="WatsonX API URL",
validation_alias=AliasChoices("EMBEDDINGS_WATSONX_URL", "WATSONX_URL"),
)
api_key: str = Field(
description="WatsonX API key", validation_alias="EMBEDDINGS_WATSONX_API_KEY"
description="WatsonX API key",
validation_alias=AliasChoices("EMBEDDINGS_WATSONX_API_KEY", "WATSONX_API_KEY"),
)
name: str | None = Field(
default=None,
description="Service name",
validation_alias="EMBEDDINGS_WATSONX_NAME",
validation_alias=AliasChoices("EMBEDDINGS_WATSONX_NAME", "WATSONX_NAME"),
)
iam_serviceid_crn: str | None = Field(
default=None,
description="IAM service ID CRN",
validation_alias="EMBEDDINGS_WATSONX_IAM_SERVICEID_CRN",
validation_alias=AliasChoices(
"EMBEDDINGS_WATSONX_IAM_SERVICEID_CRN", "WATSONX_IAM_SERVICEID_CRN"
),
)
trusted_profile_id: str | None = Field(
default=None,
description="Trusted profile ID",
validation_alias="EMBEDDINGS_WATSONX_TRUSTED_PROFILE_ID",
validation_alias=AliasChoices(
"EMBEDDINGS_WATSONX_TRUSTED_PROFILE_ID", "WATSONX_TRUSTED_PROFILE_ID"
),
)
token: str | None = Field(
default=None,
description="Bearer token",
validation_alias="EMBEDDINGS_WATSONX_TOKEN",
validation_alias=AliasChoices("EMBEDDINGS_WATSONX_TOKEN", "WATSONX_TOKEN"),
)
projects_token: str | None = Field(
default=None,
description="Projects token",
validation_alias="EMBEDDINGS_WATSONX_PROJECTS_TOKEN",
validation_alias=AliasChoices(
"EMBEDDINGS_WATSONX_PROJECTS_TOKEN", "WATSONX_PROJECTS_TOKEN"
),
)
username: str | None = Field(
default=None,
description="Username",
validation_alias="EMBEDDINGS_WATSONX_USERNAME",
validation_alias=AliasChoices(
"EMBEDDINGS_WATSONX_USERNAME", "WATSONX_USERNAME"
),
)
password: str | None = Field(
default=None,
description="Password",
validation_alias="EMBEDDINGS_WATSONX_PASSWORD",
validation_alias=AliasChoices(
"EMBEDDINGS_WATSONX_PASSWORD", "WATSONX_PASSWORD"
),
)
instance_id: str | None = Field(
default=None,
description="Service instance ID",
validation_alias="EMBEDDINGS_WATSONX_INSTANCE_ID",
validation_alias=AliasChoices(
"EMBEDDINGS_WATSONX_INSTANCE_ID", "WATSONX_INSTANCE_ID"
),
)
version: str | None = Field(
default=None,
description="API version",
validation_alias="EMBEDDINGS_WATSONX_VERSION",
validation_alias=AliasChoices("EMBEDDINGS_WATSONX_VERSION", "WATSONX_VERSION"),
)
bedrock_url: str | None = Field(
default=None,
description="Bedrock URL",
validation_alias="EMBEDDINGS_WATSONX_BEDROCK_URL",
validation_alias=AliasChoices(
"EMBEDDINGS_WATSONX_BEDROCK_URL", "WATSONX_BEDROCK_URL"
),
)
platform_url: str | None = Field(
default=None,
description="Platform URL",
validation_alias="EMBEDDINGS_WATSONX_PLATFORM_URL",
validation_alias=AliasChoices(
"EMBEDDINGS_WATSONX_PLATFORM_URL", "WATSONX_PLATFORM_URL"
),
)
proxies: dict[str, Any] | None = Field(
default=None, description="Proxy configuration"
)
proxies: dict | None = Field(default=None, description="Proxy configuration")
@model_validator(mode="after")
def validate_space_or_project(self) -> Self:

View File

@@ -3,7 +3,7 @@
from chromadb.utils.embedding_functions.instructor_embedding_function import (
InstructorEmbeddingFunction,
)
from pydantic import Field
from pydantic import AliasChoices, Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
@@ -18,15 +18,23 @@ class InstructorProvider(BaseEmbeddingsProvider[InstructorEmbeddingFunction]):
model_name: str = Field(
default="hkunlp/instructor-base",
description="Model name to use",
validation_alias="EMBEDDINGS_INSTRUCTOR_MODEL_NAME",
validation_alias=AliasChoices(
"EMBEDDINGS_INSTRUCTOR_MODEL_NAME",
"INSTRUCTOR_MODEL_NAME",
"model",
),
)
device: str = Field(
default="cpu",
description="Device to run model on (cpu or cuda)",
validation_alias="EMBEDDINGS_INSTRUCTOR_DEVICE",
validation_alias=AliasChoices(
"EMBEDDINGS_INSTRUCTOR_DEVICE", "INSTRUCTOR_DEVICE"
),
)
instruction: str | None = Field(
default=None,
description="Instruction for embeddings",
validation_alias="EMBEDDINGS_INSTRUCTOR_INSTRUCTION",
validation_alias=AliasChoices(
"EMBEDDINGS_INSTRUCTOR_INSTRUCTION", "INSTRUCTOR_INSTRUCTION"
),
)

View File

@@ -3,7 +3,7 @@
from chromadb.utils.embedding_functions.jina_embedding_function import (
JinaEmbeddingFunction,
)
from pydantic import Field
from pydantic import AliasChoices, Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
@@ -15,10 +15,15 @@ class JinaProvider(BaseEmbeddingsProvider[JinaEmbeddingFunction]):
default=JinaEmbeddingFunction, description="Jina embedding function class"
)
api_key: str = Field(
description="Jina API key", validation_alias="EMBEDDINGS_JINA_API_KEY"
description="Jina API key",
validation_alias=AliasChoices("EMBEDDINGS_JINA_API_KEY", "JINA_API_KEY"),
)
model_name: str = Field(
default="jina-embeddings-v2-base-en",
description="Model name to use for embeddings",
validation_alias="EMBEDDINGS_JINA_MODEL_NAME",
validation_alias=AliasChoices(
"EMBEDDINGS_JINA_MODEL_NAME",
"JINA_MODEL_NAME",
"model",
),
)

View File

@@ -5,7 +5,7 @@ from typing import Any
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
from pydantic import Field
from pydantic import AliasChoices, Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
@@ -18,27 +18,39 @@ class AzureProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
description="Azure OpenAI embedding function class",
)
api_key: str = Field(
description="Azure API key", validation_alias="EMBEDDINGS_OPENAI_API_KEY"
description="Azure API key",
validation_alias=AliasChoices("EMBEDDINGS_OPENAI_API_KEY", "OPENAI_API_KEY"),
)
api_base: str | None = Field(
default=None,
description="Azure endpoint URL",
validation_alias="EMBEDDINGS_OPENAI_API_BASE",
validation_alias=AliasChoices("EMBEDDINGS_OPENAI_API_BASE", "OPENAI_API_BASE"),
)
api_type: str = Field(
default="azure",
description="API type for Azure",
validation_alias="EMBEDDINGS_OPENAI_API_TYPE",
validation_alias=AliasChoices(
"EMBEDDINGS_OPENAI_API_TYPE", "OPENAI_API_TYPE", "AZURE_OPENAI_API_TYPE"
),
)
api_version: str | None = Field(
default=None,
default="2024-02-01",
description="Azure API version",
validation_alias="EMBEDDINGS_OPENAI_API_VERSION",
validation_alias=AliasChoices(
"EMBEDDINGS_OPENAI_API_VERSION",
"OPENAI_API_VERSION",
"AZURE_OPENAI_API_VERSION",
),
)
model_name: str = Field(
default="text-embedding-ada-002",
description="Model name to use for embeddings",
validation_alias="EMBEDDINGS_OPENAI_MODEL_NAME",
validation_alias=AliasChoices(
"EMBEDDINGS_OPENAI_MODEL_NAME",
"OPENAI_MODEL_NAME",
"AZURE_OPENAI_MODEL_NAME",
"model",
),
)
default_headers: dict[str, Any] | None = Field(
default=None, description="Default headers for API requests"
@@ -46,15 +58,26 @@ class AzureProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
dimensions: int | None = Field(
default=None,
description="Embedding dimensions",
validation_alias="EMBEDDINGS_OPENAI_DIMENSIONS",
validation_alias=AliasChoices(
"EMBEDDINGS_OPENAI_DIMENSIONS",
"OPENAI_DIMENSIONS",
"AZURE_OPENAI_DIMENSIONS",
),
)
deployment_id: str | None = Field(
default=None,
deployment_id: str = Field(
description="Azure deployment ID",
validation_alias="EMBEDDINGS_OPENAI_DEPLOYMENT_ID",
validation_alias=AliasChoices(
"EMBEDDINGS_OPENAI_DEPLOYMENT_ID",
"AZURE_OPENAI_DEPLOYMENT",
"AZURE_DEPLOYMENT_ID",
),
)
organization_id: str | None = Field(
default=None,
description="Organization ID",
validation_alias="EMBEDDINGS_OPENAI_ORGANIZATION_ID",
validation_alias=AliasChoices(
"EMBEDDINGS_OPENAI_ORGANIZATION_ID",
"OPENAI_ORGANIZATION_ID",
"AZURE_OPENAI_ORGANIZATION_ID",
),
)

View File

@@ -15,7 +15,7 @@ class AzureProviderConfig(TypedDict, total=False):
model_name: Annotated[str, "text-embedding-ada-002"]
default_headers: dict[str, Any]
dimensions: int
deployment_id: str
deployment_id: Required[str]
organization_id: str

View File

@@ -3,7 +3,7 @@
from chromadb.utils.embedding_functions.ollama_embedding_function import (
OllamaEmbeddingFunction,
)
from pydantic import Field
from pydantic import AliasChoices, Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
@@ -17,9 +17,14 @@ class OllamaProvider(BaseEmbeddingsProvider[OllamaEmbeddingFunction]):
url: str = Field(
default="http://localhost:11434/api/embeddings",
description="Ollama API endpoint URL",
validation_alias="EMBEDDINGS_OLLAMA_URL",
validation_alias=AliasChoices("EMBEDDINGS_OLLAMA_URL", "OLLAMA_URL"),
)
model_name: str = Field(
description="Model name to use for embeddings",
validation_alias="EMBEDDINGS_OLLAMA_MODEL_NAME",
validation_alias=AliasChoices(
"EMBEDDINGS_OLLAMA_MODEL_NAME",
"OLLAMA_MODEL_NAME",
"OLLAMA_MODEL",
"model",
),
)

View File

@@ -1,7 +1,7 @@
"""ONNX embeddings provider."""
from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ONNXMiniLM_L6_V2
from pydantic import Field
from pydantic import AliasChoices, Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
@@ -15,5 +15,7 @@ class ONNXProvider(BaseEmbeddingsProvider[ONNXMiniLM_L6_V2]):
preferred_providers: list[str] | None = Field(
default=None,
description="Preferred ONNX execution providers",
validation_alias="EMBEDDINGS_ONNX_PREFERRED_PROVIDERS",
validation_alias=AliasChoices(
"EMBEDDINGS_ONNX_PREFERRED_PROVIDERS", "ONNX_PREFERRED_PROVIDERS"
),
)

View File

@@ -5,7 +5,7 @@ from typing import Any
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
from pydantic import Field
from pydantic import AliasChoices, Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
@@ -20,27 +20,33 @@ class OpenAIProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
api_key: str | None = Field(
default=None,
description="OpenAI API key",
validation_alias="EMBEDDINGS_OPENAI_API_KEY",
validation_alias=AliasChoices("EMBEDDINGS_OPENAI_API_KEY", "OPENAI_API_KEY"),
)
model_name: str = Field(
default="text-embedding-ada-002",
description="Model name to use for embeddings",
validation_alias="EMBEDDINGS_OPENAI_MODEL_NAME",
validation_alias=AliasChoices(
"EMBEDDINGS_OPENAI_MODEL_NAME",
"OPENAI_MODEL_NAME",
"model",
),
)
api_base: str | None = Field(
default=None,
description="Base URL for API requests",
validation_alias="EMBEDDINGS_OPENAI_API_BASE",
validation_alias=AliasChoices("EMBEDDINGS_OPENAI_API_BASE", "OPENAI_API_BASE"),
)
api_type: str | None = Field(
default=None,
description="API type (e.g., 'azure')",
validation_alias="EMBEDDINGS_OPENAI_API_TYPE",
validation_alias=AliasChoices("EMBEDDINGS_OPENAI_API_TYPE", "OPENAI_API_TYPE"),
)
api_version: str | None = Field(
default=None,
description="API version",
validation_alias="EMBEDDINGS_OPENAI_API_VERSION",
validation_alias=AliasChoices(
"EMBEDDINGS_OPENAI_API_VERSION", "OPENAI_API_VERSION"
),
)
default_headers: dict[str, Any] | None = Field(
default=None, description="Default headers for API requests"
@@ -48,15 +54,21 @@ class OpenAIProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
dimensions: int | None = Field(
default=None,
description="Embedding dimensions",
validation_alias="EMBEDDINGS_OPENAI_DIMENSIONS",
validation_alias=AliasChoices(
"EMBEDDINGS_OPENAI_DIMENSIONS", "OPENAI_DIMENSIONS"
),
)
deployment_id: str | None = Field(
default=None,
description="Azure deployment ID",
validation_alias="EMBEDDINGS_OPENAI_DEPLOYMENT_ID",
validation_alias=AliasChoices(
"EMBEDDINGS_OPENAI_DEPLOYMENT_ID", "OPENAI_DEPLOYMENT_ID"
),
)
organization_id: str | None = Field(
default=None,
description="OpenAI organization ID",
validation_alias="EMBEDDINGS_OPENAI_ORGANIZATION_ID",
validation_alias=AliasChoices(
"EMBEDDINGS_OPENAI_ORGANIZATION_ID", "OPENAI_ORGANIZATION_ID"
),
)

View File

@@ -3,7 +3,7 @@
from chromadb.utils.embedding_functions.open_clip_embedding_function import (
OpenCLIPEmbeddingFunction,
)
from pydantic import Field
from pydantic import AliasChoices, Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
@@ -18,15 +18,21 @@ class OpenCLIPProvider(BaseEmbeddingsProvider[OpenCLIPEmbeddingFunction]):
model_name: str = Field(
default="ViT-B-32",
description="Model name to use",
validation_alias="EMBEDDINGS_OPENCLIP_MODEL_NAME",
validation_alias=AliasChoices(
"EMBEDDINGS_OPENCLIP_MODEL_NAME",
"OPENCLIP_MODEL_NAME",
"model",
),
)
checkpoint: str = Field(
default="laion2b_s34b_b79k",
description="Model checkpoint",
validation_alias="EMBEDDINGS_OPENCLIP_CHECKPOINT",
validation_alias=AliasChoices(
"EMBEDDINGS_OPENCLIP_CHECKPOINT", "OPENCLIP_CHECKPOINT"
),
)
device: str | None = Field(
default="cpu",
description="Device to run model on",
validation_alias="EMBEDDINGS_OPENCLIP_DEVICE",
validation_alias=AliasChoices("EMBEDDINGS_OPENCLIP_DEVICE", "OPENCLIP_DEVICE"),
)

View File

@@ -3,7 +3,7 @@
from chromadb.utils.embedding_functions.roboflow_embedding_function import (
RoboflowEmbeddingFunction,
)
from pydantic import Field
from pydantic import AliasChoices, Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
@@ -18,10 +18,14 @@ class RoboflowProvider(BaseEmbeddingsProvider[RoboflowEmbeddingFunction]):
api_key: str = Field(
default="",
description="Roboflow API key",
validation_alias="EMBEDDINGS_ROBOFLOW_API_KEY",
validation_alias=AliasChoices(
"EMBEDDINGS_ROBOFLOW_API_KEY", "ROBOFLOW_API_KEY"
),
)
api_url: str = Field(
default="https://infer.roboflow.com",
description="Roboflow API URL",
validation_alias="EMBEDDINGS_ROBOFLOW_API_URL",
validation_alias=AliasChoices(
"EMBEDDINGS_ROBOFLOW_API_URL", "ROBOFLOW_API_URL"
),
)

View File

@@ -3,7 +3,7 @@
from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import (
SentenceTransformerEmbeddingFunction,
)
from pydantic import Field
from pydantic import AliasChoices, Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
@@ -20,15 +20,24 @@ class SentenceTransformerProvider(
model_name: str = Field(
default="all-MiniLM-L6-v2",
description="Model name to use",
validation_alias="EMBEDDINGS_SENTENCE_TRANSFORMER_MODEL_NAME",
validation_alias=AliasChoices(
"EMBEDDINGS_SENTENCE_TRANSFORMER_MODEL_NAME",
"SENTENCE_TRANSFORMER_MODEL_NAME",
"model",
),
)
device: str = Field(
default="cpu",
description="Device to run model on (cpu or cuda)",
validation_alias="EMBEDDINGS_SENTENCE_TRANSFORMER_DEVICE",
validation_alias=AliasChoices(
"EMBEDDINGS_SENTENCE_TRANSFORMER_DEVICE", "SENTENCE_TRANSFORMER_DEVICE"
),
)
normalize_embeddings: bool = Field(
default=False,
description="Whether to normalize embeddings",
validation_alias="EMBEDDINGS_SENTENCE_TRANSFORMER_NORMALIZE_EMBEDDINGS",
validation_alias=AliasChoices(
"EMBEDDINGS_SENTENCE_TRANSFORMER_NORMALIZE_EMBEDDINGS",
"SENTENCE_TRANSFORMER_NORMALIZE_EMBEDDINGS",
),
)

View File

@@ -3,7 +3,7 @@
from chromadb.utils.embedding_functions.text2vec_embedding_function import (
Text2VecEmbeddingFunction,
)
from pydantic import Field
from pydantic import AliasChoices, Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
@@ -18,5 +18,9 @@ class Text2VecProvider(BaseEmbeddingsProvider[Text2VecEmbeddingFunction]):
model_name: str = Field(
default="shibing624/text2vec-base-chinese",
description="Model name to use",
validation_alias="EMBEDDINGS_TEXT2VEC_MODEL_NAME",
validation_alias=AliasChoices(
"EMBEDDINGS_TEXT2VEC_MODEL_NAME",
"TEXT2VEC_MODEL_NAME",
"model",
),
)

View File

@@ -1,6 +1,6 @@
"""Voyage AI embeddings provider."""
from pydantic import Field
from pydantic import AliasChoices, Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
from crewai.rag.embeddings.providers.voyageai.embedding_callable import (
@@ -18,38 +18,53 @@ class VoyageAIProvider(BaseEmbeddingsProvider[VoyageAIEmbeddingFunction]):
model: str = Field(
default="voyage-2",
description="Model to use for embeddings",
validation_alias="EMBEDDINGS_VOYAGEAI_MODEL",
validation_alias=AliasChoices("EMBEDDINGS_VOYAGEAI_MODEL", "VOYAGEAI_MODEL"),
)
api_key: str = Field(
description="Voyage AI API key", validation_alias="EMBEDDINGS_VOYAGEAI_API_KEY"
description="Voyage AI API key",
validation_alias=AliasChoices(
"EMBEDDINGS_VOYAGEAI_API_KEY", "VOYAGEAI_API_KEY"
),
)
input_type: str | None = Field(
default=None,
description="Input type for embeddings",
validation_alias="EMBEDDINGS_VOYAGEAI_INPUT_TYPE",
validation_alias=AliasChoices(
"EMBEDDINGS_VOYAGEAI_INPUT_TYPE", "VOYAGEAI_INPUT_TYPE"
),
)
truncation: bool = Field(
default=True,
description="Whether to truncate inputs",
validation_alias="EMBEDDINGS_VOYAGEAI_TRUNCATION",
validation_alias=AliasChoices(
"EMBEDDINGS_VOYAGEAI_TRUNCATION", "VOYAGEAI_TRUNCATION"
),
)
output_dtype: str | None = Field(
default=None,
description="Output data type",
validation_alias="EMBEDDINGS_VOYAGEAI_OUTPUT_DTYPE",
validation_alias=AliasChoices(
"EMBEDDINGS_VOYAGEAI_OUTPUT_DTYPE", "VOYAGEAI_OUTPUT_DTYPE"
),
)
output_dimension: int | None = Field(
default=None,
description="Output dimension",
validation_alias="EMBEDDINGS_VOYAGEAI_OUTPUT_DIMENSION",
validation_alias=AliasChoices(
"EMBEDDINGS_VOYAGEAI_OUTPUT_DIMENSION", "VOYAGEAI_OUTPUT_DIMENSION"
),
)
max_retries: int = Field(
default=0,
description="Maximum retries for API calls",
validation_alias="EMBEDDINGS_VOYAGEAI_MAX_RETRIES",
validation_alias=AliasChoices(
"EMBEDDINGS_VOYAGEAI_MAX_RETRIES", "VOYAGEAI_MAX_RETRIES"
),
)
timeout: float | None = Field(
default=None,
description="Timeout for API calls",
validation_alias="EMBEDDINGS_VOYAGEAI_TIMEOUT",
validation_alias=AliasChoices(
"EMBEDDINGS_VOYAGEAI_TIMEOUT", "VOYAGEAI_TIMEOUT"
),
)

View File

@@ -29,7 +29,7 @@ from crewai.rag.embeddings.providers.text2vec.types import Text2VecProviderSpec
from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderSpec
ProviderSpec = (
ProviderSpec: TypeAlias = (
AzureProviderSpec
| BedrockProviderSpec
| CohereProviderSpec

View File

@@ -1,16 +1,23 @@
"""Qdrant configuration model."""
from __future__ import annotations
from dataclasses import field
from typing import Literal, cast
from typing import TYPE_CHECKING, Any, Literal, cast
from pydantic.dataclasses import dataclass as pyd_dataclass
from qdrant_client.models import VectorParams
from crewai.rag.config.base import BaseRagConfig
from crewai.rag.qdrant.constants import DEFAULT_EMBEDDING_MODEL, DEFAULT_STORAGE_PATH
from crewai.rag.qdrant.types import QdrantClientParams, QdrantEmbeddingFunctionWrapper
if TYPE_CHECKING:
from qdrant_client.models import VectorParams
else:
VectorParams = Any
def _default_options() -> QdrantClientParams:
"""Create default Qdrant client options.
@@ -26,7 +33,7 @@ def _default_embedding_function() -> QdrantEmbeddingFunctionWrapper:
Returns:
Default embedding function using fastembed with all-MiniLM-L6-v2.
"""
from fastembed import TextEmbedding # type: ignore[import-not-found]
from fastembed import TextEmbedding
model = TextEmbedding(model_name=DEFAULT_EMBEDDING_MODEL)

View File

@@ -0,0 +1,364 @@
"""Tests for backward compatibility of embedding provider configurations."""
from crewai.rag.embeddings.factory import build_embedder, PROVIDER_PATHS
from crewai.rag.embeddings.providers.openai.openai_provider import OpenAIProvider
from crewai.rag.embeddings.providers.cohere.cohere_provider import CohereProvider
from crewai.rag.embeddings.providers.google.generative_ai import GenerativeAiProvider
from crewai.rag.embeddings.providers.google.vertex import VertexAIProvider
from crewai.rag.embeddings.providers.microsoft.azure import AzureProvider
from crewai.rag.embeddings.providers.jina.jina_provider import JinaProvider
from crewai.rag.embeddings.providers.ollama.ollama_provider import OllamaProvider
from crewai.rag.embeddings.providers.aws.bedrock import BedrockProvider
from crewai.rag.embeddings.providers.text2vec.text2vec_provider import Text2VecProvider
from crewai.rag.embeddings.providers.sentence_transformer.sentence_transformer_provider import (
SentenceTransformerProvider,
)
from crewai.rag.embeddings.providers.instructor.instructor_provider import InstructorProvider
from crewai.rag.embeddings.providers.openclip.openclip_provider import OpenCLIPProvider
class TestGoogleProviderAlias:
    """Backward-compatibility checks for the 'google' provider name alias."""

    def test_google_alias_in_provider_paths(self):
        """The legacy 'google' key must resolve to the google-generativeai provider."""
        alias_key = "google"
        canonical_key = "google-generativeai"
        # Both the legacy alias and the canonical name must be registered,
        # and they must point at the exact same provider import path.
        assert alias_key in PROVIDER_PATHS
        assert canonical_key in PROVIDER_PATHS
        assert PROVIDER_PATHS[alias_key] == PROVIDER_PATHS[canonical_key]
class TestModelKeyBackwardCompatibility:
    """Test that 'model' config key works as alias for 'model_name'.

    Each provider lists ``"model"`` in its ``validation_alias`` AliasChoices,
    so legacy configs using the ``model`` key keep working. Where possible,
    each test uses a value different from the provider's declared default so
    a silently-ignored alias cannot be masked by the default being returned.
    """

    def test_openai_provider_accepts_model_key(self):
        """Test OpenAI provider accepts 'model' as alias for 'model_name'."""
        provider = OpenAIProvider(
            api_key="test-key",
            model="text-embedding-3-small",
        )
        assert provider.model_name == "text-embedding-3-small"

    def test_openai_provider_model_name_takes_precedence(self):
        """Test that 'model_name' is still accepted directly.

        NOTE(review): despite the method name, only ``model_name`` is
        supplied here — this verifies the field name itself still populates
        the field. Supplying both ``model`` and ``model_name`` would instead
        exercise pydantic's alias-precedence rules (AliasChoices are tried
        in declaration order), which is not what this test pins.
        """
        provider = OpenAIProvider(
            api_key="test-key",
            model_name="text-embedding-3-large",
        )
        assert provider.model_name == "text-embedding-3-large"

    def test_cohere_provider_accepts_model_key(self):
        """Test Cohere provider accepts 'model' as alias for 'model_name'."""
        provider = CohereProvider(
            api_key="test-key",
            model="embed-english-v3.0",
        )
        assert provider.model_name == "embed-english-v3.0"

    def test_google_generativeai_provider_accepts_model_key(self):
        """Test Google Generative AI provider accepts 'model' as alias.

        Uses a non-default model ("text-embedding-005") so the assertion
        fails if the alias is ignored — the field default is
        "gemini-embedding-001", which would mask a broken alias.
        """
        provider = GenerativeAiProvider(
            api_key="test-key",
            model="text-embedding-005",
        )
        assert provider.model_name == "text-embedding-005"

    def test_google_vertex_provider_accepts_model_key(self):
        """Test Google Vertex AI provider accepts 'model' as alias."""
        provider = VertexAIProvider(
            api_key="test-key",
            model="text-embedding-004",
        )
        assert provider.model_name == "text-embedding-004"

    def test_azure_provider_accepts_model_key(self):
        """Test Azure provider accepts 'model' as alias for 'model_name'.

        Uses a non-default model so a broken alias is not masked by the
        provider's "text-embedding-ada-002" default.
        """
        provider = AzureProvider(
            api_key="test-key",
            deployment_id="test-deployment",
            model="text-embedding-3-small",
        )
        assert provider.model_name == "text-embedding-3-small"

    def test_jina_provider_accepts_model_key(self):
        """Test Jina provider accepts 'model' as alias for 'model_name'."""
        provider = JinaProvider(
            api_key="test-key",
            model="jina-embeddings-v3",
        )
        assert provider.model_name == "jina-embeddings-v3"

    def test_ollama_provider_accepts_model_key(self):
        """Test Ollama provider accepts 'model' as alias for 'model_name'."""
        provider = OllamaProvider(
            model="nomic-embed-text",
        )
        assert provider.model_name == "nomic-embed-text"

    def test_text2vec_provider_accepts_model_key(self):
        """Test Text2Vec provider accepts 'model' as alias for 'model_name'."""
        provider = Text2VecProvider(
            model="shibing624/text2vec-base-multilingual",
        )
        assert provider.model_name == "shibing624/text2vec-base-multilingual"

    def test_sentence_transformer_provider_accepts_model_key(self):
        """Test SentenceTransformer provider accepts 'model' as alias."""
        provider = SentenceTransformerProvider(
            model="all-mpnet-base-v2",
        )
        assert provider.model_name == "all-mpnet-base-v2"

    def test_instructor_provider_accepts_model_key(self):
        """Test Instructor provider accepts 'model' as alias for 'model_name'."""
        provider = InstructorProvider(
            model="hkunlp/instructor-xl",
        )
        assert provider.model_name == "hkunlp/instructor-xl"

    def test_openclip_provider_accepts_model_key(self):
        """Test OpenCLIP provider accepts 'model' as alias for 'model_name'."""
        provider = OpenCLIPProvider(
            model="ViT-B-16",
        )
        assert provider.model_name == "ViT-B-16"
class TestTaskTypeConfiguration:
    """Validate task_type handling on the Google Generative AI provider."""

    def test_google_provider_accepts_lowercase_task_type(self):
        """A lowercase task_type value is stored verbatim."""
        provider = GenerativeAiProvider(
            task_type="retrieval_document",
            api_key="test-key",
        )
        assert provider.task_type == "retrieval_document"

    def test_google_provider_accepts_uppercase_task_type(self):
        """An uppercase task_type value is stored verbatim."""
        provider = GenerativeAiProvider(
            task_type="RETRIEVAL_QUERY",
            api_key="test-key",
        )
        assert provider.task_type == "RETRIEVAL_QUERY"

    def test_google_provider_default_task_type(self):
        """Omitting task_type falls back to RETRIEVAL_DOCUMENT."""
        provider = GenerativeAiProvider(api_key="test-key")
        assert provider.task_type == "RETRIEVAL_DOCUMENT"
class TestFactoryBackwardCompatibility:
    """Test factory function with backward compatible configurations.

    The factory's dynamic provider import is patched out so these tests do
    not require the optional provider packages to be installed.
    """

    def test_factory_with_google_alias(self):
        """Test factory resolves 'google' to google-generativeai provider."""
        config = {
            "provider": "google",
            "config": {
                "api_key": "test-key",
                "model": "gemini-embedding-001",
            },
        }
        with patch(
            "crewai.rag.embeddings.factory.import_and_validate_definition"
        ) as mock_import:
            mock_provider_class = MagicMock()
            mock_import.return_value = mock_provider_class
            mock_provider_class.return_value = MagicMock()
            build_embedder(config)
            # The legacy alias must map onto the google-generativeai path.
            mock_import.assert_called_once_with(
                "crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider"
            )

    def test_factory_with_model_key_openai(self):
        """Test factory passes 'model' config to OpenAI provider."""
        config = {
            "provider": "openai",
            "config": {
                "api_key": "test-key",
                "model": "text-embedding-3-small",
            },
        }
        with patch(
            "crewai.rag.embeddings.factory.import_and_validate_definition"
        ) as mock_import:
            mock_provider_class = MagicMock()
            mock_import.return_value = mock_provider_class
            mock_provider_class.return_value = MagicMock()
            build_embedder(config)
            # The raw 'model' key is forwarded untouched; the provider's own
            # validation alias is what maps it onto model_name.
            call_kwargs = mock_provider_class.call_args.kwargs
            assert call_kwargs["model"] == "text-embedding-3-small"
class TestDocumentationCodeSnippets:
    """Verify the embedder configs shown in the docs instantiate correctly."""

    def test_memory_openai_config(self):
        """Minimal OpenAI config from memory.mdx."""
        provider = OpenAIProvider(model_name="text-embedding-3-small")
        assert provider.model_name == "text-embedding-3-small"

    def test_memory_openai_config_with_options(self):
        """OpenAI config with every documented option from memory.mdx."""
        provider = OpenAIProvider(
            model_name="text-embedding-3-large",
            dimensions=1536,
            api_key="your-openai-api-key",
            organization_id="your-org-id",
        )
        assert provider.model_name == "text-embedding-3-large"
        assert provider.dimensions == 1536

    def test_memory_azure_config(self):
        """Azure OpenAI config from memory.mdx."""
        provider = AzureProvider(
            model_name="text-embedding-3-small",
            deployment_id="your-deployment-name",
            api_key="your-azure-key",
            api_base="https://your-resource.openai.azure.com/",
            api_type="azure",
            api_version="2023-05-15",
        )
        assert provider.model_name == "text-embedding-3-small"
        assert provider.api_type == "azure"

    def test_memory_google_generativeai_config(self):
        """Google Generative AI config from memory.mdx."""
        provider = GenerativeAiProvider(
            model_name="gemini-embedding-001",
            api_key="your-google-api-key",
        )
        assert provider.model_name == "gemini-embedding-001"

    def test_memory_cohere_config(self):
        """Cohere config from memory.mdx."""
        provider = CohereProvider(
            model_name="embed-english-v3.0",
            api_key="your-cohere-api-key",
        )
        assert provider.model_name == "embed-english-v3.0"

    def test_knowledge_agent_embedder_config(self):
        """Agent embedder config from knowledge.mdx."""
        provider = GenerativeAiProvider(
            api_key="your-google-key",
            model_name="gemini-embedding-001",
        )
        assert provider.model_name == "gemini-embedding-001"

    def test_ragtool_openai_config(self):
        """RagTool OpenAI config from ragtool.mdx."""
        provider = OpenAIProvider(model_name="text-embedding-3-small")
        assert provider.model_name == "text-embedding-3-small"

    def test_ragtool_cohere_config(self):
        """RagTool Cohere config from ragtool.mdx."""
        provider = CohereProvider(
            model_name="embed-english-v3.0",
            api_key="your-api-key",
        )
        assert provider.model_name == "embed-english-v3.0"

    def test_ragtool_ollama_config(self):
        """RagTool Ollama config from ragtool.mdx."""
        provider = OllamaProvider(
            url="http://localhost:11434/api/embeddings",
            model_name="llama2",
        )
        assert provider.model_name == "llama2"

    def test_ragtool_azure_config(self):
        """RagTool Azure config from ragtool.mdx."""
        provider = AzureProvider(
            model_name="text-embedding-ada-002",
            deployment_id="your-deployment-id",
            api_type="azure",
            api_key="your-api-key",
            api_base="https://your-resource.openai.azure.com",
            api_version="2024-02-01",
        )
        assert provider.deployment_id == "your-deployment-id"
        assert provider.model_name == "text-embedding-ada-002"

    def test_ragtool_google_generativeai_config(self):
        """RagTool Google Generative AI config from ragtool.mdx."""
        provider = GenerativeAiProvider(
            model_name="gemini-embedding-001",
            task_type="RETRIEVAL_DOCUMENT",
            api_key="your-api-key",
        )
        assert provider.task_type == "RETRIEVAL_DOCUMENT"
        assert provider.model_name == "gemini-embedding-001"

    def test_ragtool_jina_config(self):
        """RagTool Jina config from ragtool.mdx."""
        provider = JinaProvider(
            model_name="jina-embeddings-v3",
            api_key="your-api-key",
        )
        assert provider.model_name == "jina-embeddings-v3"

    def test_ragtool_sentence_transformer_config(self):
        """RagTool SentenceTransformer config from ragtool.mdx."""
        provider = SentenceTransformerProvider(
            normalize_embeddings=True,
            model_name="all-mpnet-base-v2",
            device="cuda",
        )
        assert provider.normalize_embeddings is True
        assert provider.device == "cuda"
        assert provider.model_name == "all-mpnet-base-v2"
class TestLegacyConfigurationFormats:
    """Legacy configuration formats that must keep working."""

    def test_legacy_google_with_model_key(self):
        """Legacy Google config using 'model' rather than 'model_name'."""
        provider = GenerativeAiProvider(
            task_type="retrieval_document",
            model="text-embedding-005",
            api_key="test-key",
        )
        assert provider.task_type == "retrieval_document"
        assert provider.model_name == "text-embedding-005"

    def test_legacy_openai_with_model_key(self):
        """Legacy OpenAI config using 'model' rather than 'model_name'."""
        provider = OpenAIProvider(
            model="text-embedding-ada-002",
            api_key="test-key",
        )
        assert provider.model_name == "text-embedding-ada-002"

    def test_legacy_cohere_with_model_key(self):
        """Legacy Cohere config using 'model' rather than 'model_name'."""
        provider = CohereProvider(
            model="embed-multilingual-v3.0",
            api_key="test-key",
        )
        assert provider.model_name == "embed-multilingual-v3.0"

    def test_legacy_azure_with_model_key(self):
        """Legacy Azure config using 'model' rather than 'model_name'."""
        provider = AzureProvider(
            model="text-embedding-3-large",
            deployment_id="test-deployment",
            api_key="test-key",
        )
        assert provider.model_name == "text-embedding-3-large"