mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 23:58:34 +00:00
- update imports and include handling for chromadb v1.1.0 - fix mypy and typing_compat issues (required, typeddict, voyageai) - refine embedderconfig typing and allow base provider instances - handle mem0 as special case for external memory storage - bump tools and clean up redundant deps
59 lines
1.9 KiB
Python
59 lines
1.9 KiB
Python
"""OpenAI embeddings provider."""
|
|
|
|
from typing import Any
|
|
|
|
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
|
OpenAIEmbeddingFunction,
|
|
)
|
|
from pydantic import Field
|
|
|
|
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
|
|
|
|
|
class OpenAIProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
    """OpenAI embeddings provider.

    Declarative settings model: every attribute below is a pydantic
    ``Field`` carrying a default, a human-readable description and, for
    all options except ``embedding_callable`` and ``default_headers``,
    a ``validation_alias``.

    NOTE(review): how these values are loaded and handed to the embedding
    callable is decided by ``BaseEmbeddingsProvider``, which is not
    visible from this module -- confirm there before relying on it.
    """

    # --- Embedding implementation -------------------------------------
    embedding_callable: type[OpenAIEmbeddingFunction] = Field(
        description="OpenAI embedding function class",
        default=OpenAIEmbeddingFunction,
    )

    # --- Credentials and model selection ------------------------------
    api_key: str | None = Field(
        validation_alias="OPENAI_API_KEY",
        description="OpenAI API key",
        default=None,
    )
    model_name: str = Field(
        validation_alias="OPENAI_MODEL_NAME",
        description="Model name to use for embeddings",
        default="text-embedding-ada-002",
    )

    # --- Endpoint / API flavour ---------------------------------------
    api_base: str | None = Field(
        validation_alias="OPENAI_API_BASE",
        description="Base URL for API requests",
        default=None,
    )
    api_type: str | None = Field(
        validation_alias="OPENAI_API_TYPE",
        description="API type (e.g., 'azure')",
        default=None,
    )
    api_version: str | None = Field(
        validation_alias="OPENAI_API_VERSION",
        description="API version",
        default=None,
    )

    # --- Request tuning -----------------------------------------------
    # The only optional setting declared without a validation alias.
    default_headers: dict[str, Any] | None = Field(
        description="Default headers for API requests",
        default=None,
    )
    dimensions: int | None = Field(
        validation_alias="OPENAI_DIMENSIONS",
        description="Embedding dimensions",
        default=None,
    )

    # --- Account / deployment scoping ---------------------------------
    deployment_id: str | None = Field(
        validation_alias="OPENAI_DEPLOYMENT_ID",
        description="Azure deployment ID",
        default=None,
    )
    organization_id: str | None = Field(
        validation_alias="OPENAI_ORGANIZATION_ID",
        description="OpenAI organization ID",
        default=None,
    )
|