mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-15 02:58:30 +00:00
- introduce BaseEmbeddingsProvider and a helper for embedding functions; add core embedding types and migrate the providers, factory, and storage modules; remove unused type aliases and fix a Pydantic schema error; update providers with environment-variable support and related fixes
59 lines
1.8 KiB
Python
59 lines
1.8 KiB
Python
"""Azure OpenAI embeddings provider."""
|
|
|
|
from typing import Any
|
|
|
|
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
|
OpenAIEmbeddingFunction,
|
|
)
|
|
from pydantic import Field
|
|
|
|
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
|
|
|
|
|
class AzureProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
    """Azure OpenAI embeddings provider."""

    # Declarative configuration for building an OpenAIEmbeddingFunction.
    # Most fields carry an upper-case ``OPENAI_*`` validation alias.
    # NOTE(review): presumably those aliases are resolved from environment
    # variables by BaseEmbeddingsProvider -- confirm against the base class.

    # Class that is used as the embedding function; defaults to the
    # ChromaDB OpenAI implementation imported by this module.
    embedding_callable: type[OpenAIEmbeddingFunction] = Field(
        description="Azure OpenAI embedding function class",
        default=OpenAIEmbeddingFunction,
    )

    # --- Credentials and endpoint ------------------------------------------
    # ``api_key`` is the only field declared without a default value.
    api_key: str = Field(
        validation_alias="OPENAI_API_KEY",
        description="Azure API key",
    )
    api_base: str | None = Field(
        validation_alias="OPENAI_API_BASE",
        description="Azure endpoint URL",
        default=None,
    )
    # Defaults to "azure", unlike the other optional settings which
    # default to None.
    api_type: str = Field(
        validation_alias="OPENAI_API_TYPE",
        description="API type for Azure",
        default="azure",
    )
    api_version: str | None = Field(
        validation_alias="OPENAI_API_VERSION",
        description="Azure API version",
        default=None,
    )

    # --- Model selection ---------------------------------------------------
    model_name: str = Field(
        validation_alias="OPENAI_MODEL_NAME",
        description="Model name to use for embeddings",
        default="text-embedding-ada-002",
    )

    # --- Request tuning ----------------------------------------------------
    # No validation alias is declared for the headers mapping.
    default_headers: dict[str, Any] | None = Field(
        description="Default headers for API requests",
        default=None,
    )
    dimensions: int | None = Field(
        validation_alias="OPENAI_DIMENSIONS",
        description="Embedding dimensions",
        default=None,
    )

    # --- Azure deployment / organization identifiers -----------------------
    deployment_id: str | None = Field(
        validation_alias="OPENAI_DEPLOYMENT_ID",
        description="Azure deployment ID",
        default=None,
    )
    organization_id: str | None = Field(
        validation_alias="OPENAI_ORGANIZATION_ID",
        description="Organization ID",
        default=None,
    )
|