feat: upgrade chromadb to v1.1.0, improve types

- update imports and include handling for chromadb v1.1.0  
- fix mypy and typing_compat issues (required, typeddict, voyageai)  
- refine embedderconfig typing and allow base provider instances  
- handle mem0 as special case for external memory storage  
- bump tools and clean up redundant deps
This commit is contained in:
Greyson LaLonde
2025-09-25 20:48:37 -04:00
committed by GitHub
parent ce5ea9be6f
commit 2485ed93d6
35 changed files with 383 additions and 316 deletions

View File

@@ -10,7 +10,6 @@ from chromadb.api.models.AsyncCollection import AsyncCollection
from chromadb.api.models.Collection import Collection
from chromadb.api.types import (
Include,
IncludeEnum,
QueryResult,
)
@@ -142,9 +141,12 @@ def _extract_search_params(
score_threshold=kwargs.get("score_threshold"),
where=kwargs.get("where"),
where_document=kwargs.get("where_document"),
include=kwargs.get(
"include",
[IncludeEnum.metadatas, IncludeEnum.documents, IncludeEnum.distances],
include=cast(
Include,
kwargs.get(
"include",
["metadatas", "documents", "distances"],
),
),
)
@@ -193,7 +195,7 @@ def _convert_chromadb_results_to_search_results(
"""
search_results: list[SearchResult] = []
include_strings = [item.value for item in include] if include else []
include_strings = list(include) if include else []
ids = results["ids"][0] if results.get("ids") else []

View File

@@ -1,5 +1,7 @@
"""Amazon Bedrock embeddings provider."""
from typing import Any
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
AmazonBedrockEmbeddingFunction,
)
@@ -7,15 +9,8 @@ from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
try:
from boto3.session import Session # type: ignore[import-untyped]
except ImportError as exc:
raise ImportError(
"boto3 is required for amazon-bedrock embeddings. Install it with: uv add boto3"
) from exc
def create_aws_session() -> Session:
def create_aws_session() -> Any:
"""Create an AWS session for Bedrock.
Returns:
@@ -53,6 +48,6 @@ class BedrockProvider(BaseEmbeddingsProvider[AmazonBedrockEmbeddingFunction]):
description="Model name to use for embeddings",
validation_alias="BEDROCK_MODEL_NAME",
)
session: Session = Field(
session: Any = Field(
default_factory=create_aws_session, description="AWS session object"
)

View File

@@ -1,6 +1,8 @@
"""Type definitions for AWS embedding providers."""
from typing import Annotated, Any, Literal, TypedDict
from typing import Annotated, Any, Literal
from typing_extensions import Required, TypedDict
class BedrockProviderConfig(TypedDict, total=False):
@@ -10,8 +12,8 @@ class BedrockProviderConfig(TypedDict, total=False):
session: Any
class BedrockProviderSpec(TypedDict):
class BedrockProviderSpec(TypedDict, total=False):
"""Bedrock provider specification."""
provider: Literal["amazon-bedrock"]
provider: Required[Literal["amazon-bedrock"]]
config: BedrockProviderConfig

View File

@@ -1,6 +1,8 @@
"""Type definitions for Cohere embedding providers."""
from typing import Annotated, Literal, TypedDict
from typing import Annotated, Literal
from typing_extensions import Required, TypedDict
class CohereProviderConfig(TypedDict, total=False):
@@ -10,8 +12,8 @@ class CohereProviderConfig(TypedDict, total=False):
model_name: Annotated[str, "large"]
class CohereProviderSpec(TypedDict):
class CohereProviderSpec(TypedDict, total=False):
"""Cohere provider specification."""
provider: Literal["cohere"]
provider: Required[Literal["cohere"]]
config: CohereProviderConfig

View File

@@ -1,8 +1,9 @@
"""Type definitions for custom embedding providers."""
from typing import Literal, TypedDict
from typing import Literal
from chromadb.api.types import EmbeddingFunction
from typing_extensions import Required, TypedDict
class CustomProviderConfig(TypedDict, total=False):
@@ -11,8 +12,8 @@ class CustomProviderConfig(TypedDict, total=False):
embedding_callable: type[EmbeddingFunction]
class CustomProviderSpec(TypedDict):
class CustomProviderSpec(TypedDict, total=False):
"""Custom provider specification."""
provider: Literal["custom"]
provider: Required[Literal["custom"]]
config: CustomProviderConfig

View File

@@ -1,6 +1,8 @@
"""Type definitions for Google embedding providers."""
from typing import Annotated, Literal, TypedDict
from typing import Annotated, Literal
from typing_extensions import Required, TypedDict
class GenerativeAiProviderConfig(TypedDict, total=False):
@@ -27,8 +29,8 @@ class VertexAIProviderConfig(TypedDict, total=False):
region: Annotated[str, "us-central1"]
class VertexAIProviderSpec(TypedDict):
class VertexAIProviderSpec(TypedDict, total=False):
"""Vertex AI provider specification."""
provider: Literal["google-vertex"]
provider: Required[Literal["google-vertex"]]
config: VertexAIProviderConfig

View File

@@ -1,6 +1,8 @@
"""Type definitions for HuggingFace embedding providers."""
from typing import Literal, TypedDict
from typing import Literal
from typing_extensions import Required, TypedDict
class HuggingFaceProviderConfig(TypedDict, total=False):
@@ -9,8 +11,8 @@ class HuggingFaceProviderConfig(TypedDict, total=False):
url: str
class HuggingFaceProviderSpec(TypedDict):
class HuggingFaceProviderSpec(TypedDict, total=False):
"""HuggingFace provider specification."""
provider: Literal["huggingface"]
provider: Required[Literal["huggingface"]]
config: HuggingFaceProviderConfig

View File

@@ -2,11 +2,6 @@
from typing import cast
import ibm_watsonx_ai.foundation_models as watson_models # type: ignore[import-not-found, import-untyped]
from ibm_watsonx_ai import Credentials # type: ignore[import-not-found, import-untyped]
from ibm_watsonx_ai.metanames import ( # type: ignore[import-not-found, import-untyped]
EmbedTextParamsMetaNames as EmbedParams,
)
from typing_extensions import Unpack
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
@@ -34,6 +29,21 @@ class WatsonEmbeddingFunction(EmbeddingFunction[Documents]):
Returns:
List of embedding vectors.
"""
try:
import ibm_watsonx_ai.foundation_models as watson_models # type: ignore[import-not-found, import-untyped]
from ibm_watsonx_ai import (
Credentials, # type: ignore[import-not-found, import-untyped]
)
from ibm_watsonx_ai.metanames import ( # type: ignore[import-not-found, import-untyped]
EmbedTextParamsMetaNames as EmbedParams,
)
except ImportError as e:
raise ImportError(
"ibm-watsonx-ai is required for watson embeddings. "
"Install it with: uv add ibm-watsonx-ai"
) from e
if isinstance(input, str):
input = [input]

View File

@@ -1,6 +1,8 @@
"""Type definitions for IBM Watson embedding providers."""
from typing import Annotated, Any, Literal, TypedDict
from typing import Annotated, Any, Literal
from typing_extensions import Required, TypedDict
class WatsonProviderConfig(TypedDict, total=False):
@@ -35,8 +37,8 @@ class WatsonProviderConfig(TypedDict, total=False):
proxies: dict
class WatsonProviderSpec(TypedDict):
class WatsonProviderSpec(TypedDict, total=False):
"""Watson provider specification."""
provider: Literal["watson"]
provider: Required[Literal["watson"]]
config: WatsonProviderConfig

View File

@@ -1,9 +1,7 @@
"""IBM Watson embeddings provider."""
from ibm_watsonx_ai import ( # type: ignore[import-not-found,import-untyped]
APIClient,
Credentials,
)
from typing import Any
from pydantic import Field, model_validator
from typing_extensions import Self
@@ -28,9 +26,7 @@ class WatsonProvider(BaseEmbeddingsProvider[WatsonEmbeddingFunction]):
params: dict[str, str | dict[str, str]] | None = Field(
default=None, description="Additional parameters"
)
credentials: Credentials | None = Field(
default=None, description="Watson credentials"
)
credentials: Any | None = Field(default=None, description="Watson credentials")
project_id: str | None = Field(
default=None,
description="Watson project ID",
@@ -39,7 +35,7 @@ class WatsonProvider(BaseEmbeddingsProvider[WatsonEmbeddingFunction]):
space_id: str | None = Field(
default=None, description="Watson space ID", validation_alias="WATSON_SPACE_ID"
)
api_client: APIClient | None = Field(default=None, description="Watson API client")
api_client: Any | None = Field(default=None, description="Watson API client")
verify: bool | str | None = Field(
default=None, description="SSL verification", validation_alias="WATSON_VERIFY"
)

View File

@@ -1,6 +1,8 @@
"""Type definitions for Instructor embedding providers."""
from typing import Annotated, Literal, TypedDict
from typing import Annotated, Literal
from typing_extensions import Required, TypedDict
class InstructorProviderConfig(TypedDict, total=False):
@@ -11,8 +13,8 @@ class InstructorProviderConfig(TypedDict, total=False):
instruction: str
class InstructorProviderSpec(TypedDict):
class InstructorProviderSpec(TypedDict, total=False):
"""Instructor provider specification."""
provider: Literal["instructor"]
provider: Required[Literal["instructor"]]
config: InstructorProviderConfig

View File

@@ -1,6 +1,8 @@
"""Type definitions for Jina embedding providers."""
from typing import Annotated, Literal, TypedDict
from typing import Annotated, Literal
from typing_extensions import Required, TypedDict
class JinaProviderConfig(TypedDict, total=False):
@@ -10,8 +12,8 @@ class JinaProviderConfig(TypedDict, total=False):
model_name: Annotated[str, "jina-embeddings-v2-base-en"]
class JinaProviderSpec(TypedDict):
class JinaProviderSpec(TypedDict, total=False):
"""Jina provider specification."""
provider: Literal["jina"]
provider: Required[Literal["jina"]]
config: JinaProviderConfig

View File

@@ -1,6 +1,8 @@
"""Type definitions for Microsoft Azure embedding providers."""
from typing import Annotated, Any, Literal, TypedDict
from typing import Annotated, Any, Literal
from typing_extensions import Required, TypedDict
class AzureProviderConfig(TypedDict, total=False):
@@ -17,8 +19,8 @@ class AzureProviderConfig(TypedDict, total=False):
organization_id: str
class AzureProviderSpec(TypedDict):
class AzureProviderSpec(TypedDict, total=False):
"""Azure provider specification."""
provider: Literal["azure"]
provider: Required[Literal["azure"]]
config: AzureProviderConfig

View File

@@ -1,6 +1,8 @@
"""Type definitions for Ollama embedding providers."""
from typing import Annotated, Literal, TypedDict
from typing import Annotated, Literal
from typing_extensions import Required, TypedDict
class OllamaProviderConfig(TypedDict, total=False):
@@ -10,8 +12,8 @@ class OllamaProviderConfig(TypedDict, total=False):
model_name: str
class OllamaProviderSpec(TypedDict):
class OllamaProviderSpec(TypedDict, total=False):
"""Ollama provider specification."""
provider: Literal["ollama"]
provider: Required[Literal["ollama"]]
config: OllamaProviderConfig

View File

@@ -1,6 +1,8 @@
"""Type definitions for ONNX embedding providers."""
from typing import Literal, TypedDict
from typing import Literal
from typing_extensions import Required, TypedDict
class ONNXProviderConfig(TypedDict, total=False):
@@ -9,8 +11,8 @@ class ONNXProviderConfig(TypedDict, total=False):
preferred_providers: list[str]
class ONNXProviderSpec(TypedDict):
class ONNXProviderSpec(TypedDict, total=False):
"""ONNX provider specification."""
provider: Literal["onnx"]
provider: Required[Literal["onnx"]]
config: ONNXProviderConfig

View File

@@ -17,8 +17,8 @@ class OpenAIProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
default=OpenAIEmbeddingFunction,
description="OpenAI embedding function class",
)
api_key: str = Field(
description="OpenAI API key", validation_alias="OPENAI_API_KEY"
api_key: str | None = Field(
default=None, description="OpenAI API key", validation_alias="OPENAI_API_KEY"
)
model_name: str = Field(
default="text-embedding-ada-002",

View File

@@ -1,6 +1,8 @@
"""Type definitions for OpenAI embedding providers."""
from typing import Annotated, Any, Literal, TypedDict
from typing import Annotated, Any, Literal
from typing_extensions import Required, TypedDict
class OpenAIProviderConfig(TypedDict, total=False):
@@ -17,8 +19,8 @@ class OpenAIProviderConfig(TypedDict, total=False):
organization_id: str
class OpenAIProviderSpec(TypedDict):
class OpenAIProviderSpec(TypedDict, total=False):
"""OpenAI provider specification."""
provider: Literal["openai"]
provider: Required[Literal["openai"]]
config: OpenAIProviderConfig

View File

@@ -1,6 +1,8 @@
"""Type definitions for OpenCLIP embedding providers."""
from typing import Annotated, Literal, TypedDict
from typing import Annotated, Literal
from typing_extensions import Required, TypedDict
class OpenCLIPProviderConfig(TypedDict, total=False):
@@ -14,5 +16,5 @@ class OpenCLIPProviderConfig(TypedDict, total=False):
class OpenCLIPProviderSpec(TypedDict):
"""OpenCLIP provider specification."""
provider: Literal["openclip"]
provider: Required[Literal["openclip"]]
config: OpenCLIPProviderConfig

View File

@@ -1,6 +1,8 @@
"""Type definitions for Roboflow embedding providers."""
from typing import Annotated, Literal, TypedDict
from typing import Annotated, Literal
from typing_extensions import Required, TypedDict
class RoboflowProviderConfig(TypedDict, total=False):
@@ -13,5 +15,5 @@ class RoboflowProviderConfig(TypedDict, total=False):
class RoboflowProviderSpec(TypedDict):
"""Roboflow provider specification."""
provider: Literal["roboflow"]
provider: Required[Literal["roboflow"]]
config: RoboflowProviderConfig

View File

@@ -1,6 +1,8 @@
"""Type definitions for SentenceTransformer embedding providers."""
from typing import Annotated, Literal, TypedDict
from typing import Annotated, Literal
from typing_extensions import Required, TypedDict
class SentenceTransformerProviderConfig(TypedDict, total=False):
@@ -14,5 +16,5 @@ class SentenceTransformerProviderConfig(TypedDict, total=False):
class SentenceTransformerProviderSpec(TypedDict):
"""SentenceTransformer provider specification."""
provider: Literal["sentence-transformer"]
provider: Required[Literal["sentence-transformer"]]
config: SentenceTransformerProviderConfig

View File

@@ -1,6 +1,8 @@
"""Type definitions for Text2Vec embedding providers."""
from typing import Annotated, Literal, TypedDict
from typing import Annotated, Literal
from typing_extensions import Required, TypedDict
class Text2VecProviderConfig(TypedDict, total=False):
@@ -12,5 +14,5 @@ class Text2VecProviderConfig(TypedDict, total=False):
class Text2VecProviderSpec(TypedDict):
"""Text2Vec provider specification."""
provider: Literal["text2vec"]
provider: Required[Literal["text2vec"]]
config: Text2VecProviderConfig

View File

@@ -2,7 +2,6 @@
from typing import cast
import voyageai
from typing_extensions import Unpack
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
@@ -19,6 +18,14 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]):
Args:
**kwargs: Configuration parameters for VoyageAI.
"""
try:
import voyageai # type: ignore[import-not-found]
except ImportError as e:
raise ImportError(
"voyageai is required for voyageai embeddings. "
"Install it with: uv add voyageai"
) from e
self._config = kwargs
self._client = voyageai.Client(
api_key=kwargs["api_key"],
@@ -35,6 +42,7 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]):
Returns:
List of embedding vectors.
"""
if isinstance(input, str):
input = [input]

View File

@@ -1,6 +1,8 @@
"""Type definitions for VoyageAI embedding providers."""
from typing import Annotated, Literal, TypedDict
from typing import Annotated, Literal
from typing_extensions import Required, TypedDict
class VoyageAIProviderConfig(TypedDict, total=False):
@@ -19,5 +21,5 @@ class VoyageAIProviderConfig(TypedDict, total=False):
class VoyageAIProviderSpec(TypedDict):
"""VoyageAI provider specification."""
provider: Literal["voyageai"]
provider: Required[Literal["voyageai"]]
config: VoyageAIProviderConfig

View File

@@ -1,7 +1,8 @@
"""Type definitions for the embeddings module."""
from typing import Literal
from typing import Literal, TypeAlias
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
from crewai.rag.embeddings.providers.aws.types import BedrockProviderSpec
from crewai.rag.embeddings.providers.cohere.types import CohereProviderSpec
from crewai.rag.embeddings.providers.custom.types import CustomProviderSpec
@@ -66,3 +67,7 @@ AllowedEmbeddingProviders = Literal[
"voyageai",
"watson",
]
EmbedderConfig: TypeAlias = (
ProviderSpec | BaseEmbeddingsProvider | type[BaseEmbeddingsProvider]
)