Files
crewAI/src/crewai/rag/embeddings/providers/openai/openai_provider.py
Greyson LaLonde 2485ed93d6 feat: upgrade chromadb to v1.1.0, improve types
- update imports and include handling for chromadb v1.1.0  
- fix mypy and typing_compat issues (required, typeddict, voyageai)  
- refine embedderconfig typing and allow base provider instances  
- handle mem0 as special case for external memory storage  
- bump tools and clean up redundant deps
2025-09-25 20:48:37 -04:00

59 lines
1.9 KiB
Python

"""OpenAI embeddings provider."""
from typing import Any
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
class OpenAIProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
    """Embeddings provider configuration backed by OpenAI.

    By default the embedding callable is chromadb's
    ``OpenAIEmbeddingFunction`` class; a different class may be supplied.
    Most fields declare a ``validation_alias`` spelled like an environment
    variable (for example ``OPENAI_API_KEY``).

    NOTE(review): whether those aliases are read from the process
    environment, and whether the plain field names are also accepted as
    input, is decided by ``BaseEmbeddingsProvider``'s model configuration,
    which is not visible here. Confirm there.
    """

    # The embedding-function *class* (a type, not an instance). Presumably
    # the base provider builds it from this model's fields; TODO confirm.
    embedding_callable: type[OpenAIEmbeddingFunction] = Field(
        description="OpenAI embedding function class",
        default=OpenAIEmbeddingFunction,
    )

    # --- Credentials and model selection -----------------------------------
    api_key: str | None = Field(
        description="OpenAI API key",
        validation_alias="OPENAI_API_KEY",
        default=None,
    )
    # NOTE(review): "text-embedding-ada-002" is OpenAI's older embedding
    # model; changing this default would change which model callers get.
    model_name: str = Field(
        description="Model name to use for embeddings",
        validation_alias="OPENAI_MODEL_NAME",
        default="text-embedding-ada-002",
    )

    # --- Endpoint and API flavor -------------------------------------------
    api_base: str | None = Field(
        description="Base URL for API requests",
        validation_alias="OPENAI_API_BASE",
        default=None,
    )
    api_type: str | None = Field(
        description="API type (e.g., 'azure')",
        validation_alias="OPENAI_API_TYPE",
        default=None,
    )
    api_version: str | None = Field(
        description="API version",
        validation_alias="OPENAI_API_VERSION",
        default=None,
    )
    # No validation_alias is declared for this field.
    default_headers: dict[str, Any] | None = Field(
        description="Default headers for API requests",
        default=None,
    )

    # --- Request shaping and account/deployment identifiers ----------------
    # NOTE(review): per OpenAI's docs, only some embedding models honor a
    # custom output size; confirm against the chosen ``model_name``.
    dimensions: int | None = Field(
        description="Embedding dimensions",
        validation_alias="OPENAI_DIMENSIONS",
        default=None,
    )
    deployment_id: str | None = Field(
        description="Azure deployment ID",
        validation_alias="OPENAI_DEPLOYMENT_ID",
        default=None,
    )
    organization_id: str | None = Field(
        description="OpenAI organization ID",
        validation_alias="OPENAI_ORGANIZATION_ID",
        default=None,
    )