mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-01 15:22:37 +00:00
feat: upgrade chromadb to v1.1.0, improve types
- update imports and include handling for chromadb v1.1.0 - fix mypy and typing_compat issues (required, typeddict, voyageai) - refine embedderconfig typing and allow base provider instances - handle mem0 as special case for external memory storage - bump tools and clean up redundant deps
This commit is contained in:
@@ -10,7 +10,6 @@ from chromadb.api.models.AsyncCollection import AsyncCollection
|
||||
from chromadb.api.models.Collection import Collection
|
||||
from chromadb.api.types import (
|
||||
Include,
|
||||
IncludeEnum,
|
||||
QueryResult,
|
||||
)
|
||||
|
||||
@@ -142,9 +141,12 @@ def _extract_search_params(
|
||||
score_threshold=kwargs.get("score_threshold"),
|
||||
where=kwargs.get("where"),
|
||||
where_document=kwargs.get("where_document"),
|
||||
include=kwargs.get(
|
||||
"include",
|
||||
[IncludeEnum.metadatas, IncludeEnum.documents, IncludeEnum.distances],
|
||||
include=cast(
|
||||
Include,
|
||||
kwargs.get(
|
||||
"include",
|
||||
["metadatas", "documents", "distances"],
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -193,7 +195,7 @@ def _convert_chromadb_results_to_search_results(
|
||||
"""
|
||||
search_results: list[SearchResult] = []
|
||||
|
||||
include_strings = [item.value for item in include] if include else []
|
||||
include_strings = list(include) if include else []
|
||||
|
||||
ids = results["ids"][0] if results.get("ids") else []
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""Amazon Bedrock embeddings provider."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
|
||||
AmazonBedrockEmbeddingFunction,
|
||||
)
|
||||
@@ -7,15 +9,8 @@ from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
try:
|
||||
from boto3.session import Session # type: ignore[import-untyped]
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"boto3 is required for amazon-bedrock embeddings. Install it with: uv add boto3"
|
||||
) from exc
|
||||
|
||||
|
||||
def create_aws_session() -> Session:
|
||||
def create_aws_session() -> Any:
|
||||
"""Create an AWS session for Bedrock.
|
||||
|
||||
Returns:
|
||||
@@ -53,6 +48,6 @@ class BedrockProvider(BaseEmbeddingsProvider[AmazonBedrockEmbeddingFunction]):
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias="BEDROCK_MODEL_NAME",
|
||||
)
|
||||
session: Session = Field(
|
||||
session: Any = Field(
|
||||
default_factory=create_aws_session, description="AWS session object"
|
||||
)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Type definitions for AWS embedding providers."""
|
||||
|
||||
from typing import Annotated, Any, Literal, TypedDict
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class BedrockProviderConfig(TypedDict, total=False):
|
||||
@@ -10,8 +12,8 @@ class BedrockProviderConfig(TypedDict, total=False):
|
||||
session: Any
|
||||
|
||||
|
||||
class BedrockProviderSpec(TypedDict):
|
||||
class BedrockProviderSpec(TypedDict, total=False):
|
||||
"""Bedrock provider specification."""
|
||||
|
||||
provider: Literal["amazon-bedrock"]
|
||||
provider: Required[Literal["amazon-bedrock"]]
|
||||
config: BedrockProviderConfig
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Type definitions for Cohere embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal, TypedDict
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class CohereProviderConfig(TypedDict, total=False):
|
||||
@@ -10,8 +12,8 @@ class CohereProviderConfig(TypedDict, total=False):
|
||||
model_name: Annotated[str, "large"]
|
||||
|
||||
|
||||
class CohereProviderSpec(TypedDict):
|
||||
class CohereProviderSpec(TypedDict, total=False):
|
||||
"""Cohere provider specification."""
|
||||
|
||||
provider: Literal["cohere"]
|
||||
provider: Required[Literal["cohere"]]
|
||||
config: CohereProviderConfig
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
"""Type definitions for custom embedding providers."""
|
||||
|
||||
from typing import Literal, TypedDict
|
||||
from typing import Literal
|
||||
|
||||
from chromadb.api.types import EmbeddingFunction
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class CustomProviderConfig(TypedDict, total=False):
|
||||
@@ -11,8 +12,8 @@ class CustomProviderConfig(TypedDict, total=False):
|
||||
embedding_callable: type[EmbeddingFunction]
|
||||
|
||||
|
||||
class CustomProviderSpec(TypedDict):
|
||||
class CustomProviderSpec(TypedDict, total=False):
|
||||
"""Custom provider specification."""
|
||||
|
||||
provider: Literal["custom"]
|
||||
provider: Required[Literal["custom"]]
|
||||
config: CustomProviderConfig
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Type definitions for Google embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal, TypedDict
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class GenerativeAiProviderConfig(TypedDict, total=False):
|
||||
@@ -27,8 +29,8 @@ class VertexAIProviderConfig(TypedDict, total=False):
|
||||
region: Annotated[str, "us-central1"]
|
||||
|
||||
|
||||
class VertexAIProviderSpec(TypedDict):
|
||||
class VertexAIProviderSpec(TypedDict, total=False):
|
||||
"""Vertex AI provider specification."""
|
||||
|
||||
provider: Literal["google-vertex"]
|
||||
provider: Required[Literal["google-vertex"]]
|
||||
config: VertexAIProviderConfig
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Type definitions for HuggingFace embedding providers."""
|
||||
|
||||
from typing import Literal, TypedDict
|
||||
from typing import Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class HuggingFaceProviderConfig(TypedDict, total=False):
|
||||
@@ -9,8 +11,8 @@ class HuggingFaceProviderConfig(TypedDict, total=False):
|
||||
url: str
|
||||
|
||||
|
||||
class HuggingFaceProviderSpec(TypedDict):
|
||||
class HuggingFaceProviderSpec(TypedDict, total=False):
|
||||
"""HuggingFace provider specification."""
|
||||
|
||||
provider: Literal["huggingface"]
|
||||
provider: Required[Literal["huggingface"]]
|
||||
config: HuggingFaceProviderConfig
|
||||
|
||||
@@ -2,11 +2,6 @@
|
||||
|
||||
from typing import cast
|
||||
|
||||
import ibm_watsonx_ai.foundation_models as watson_models # type: ignore[import-not-found, import-untyped]
|
||||
from ibm_watsonx_ai import Credentials # type: ignore[import-not-found, import-untyped]
|
||||
from ibm_watsonx_ai.metanames import ( # type: ignore[import-not-found, import-untyped]
|
||||
EmbedTextParamsMetaNames as EmbedParams,
|
||||
)
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
||||
@@ -34,6 +29,21 @@ class WatsonEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
Returns:
|
||||
List of embedding vectors.
|
||||
"""
|
||||
try:
|
||||
import ibm_watsonx_ai.foundation_models as watson_models # type: ignore[import-not-found, import-untyped]
|
||||
from ibm_watsonx_ai import (
|
||||
Credentials, # type: ignore[import-not-found, import-untyped]
|
||||
)
|
||||
from ibm_watsonx_ai.metanames import ( # type: ignore[import-not-found, import-untyped]
|
||||
EmbedTextParamsMetaNames as EmbedParams,
|
||||
)
|
||||
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"ibm-watsonx-ai is required for watson embeddings. "
|
||||
"Install it with: uv add ibm-watsonx-ai"
|
||||
) from e
|
||||
|
||||
if isinstance(input, str):
|
||||
input = [input]
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Type definitions for IBM Watson embedding providers."""
|
||||
|
||||
from typing import Annotated, Any, Literal, TypedDict
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class WatsonProviderConfig(TypedDict, total=False):
|
||||
@@ -35,8 +37,8 @@ class WatsonProviderConfig(TypedDict, total=False):
|
||||
proxies: dict
|
||||
|
||||
|
||||
class WatsonProviderSpec(TypedDict):
|
||||
class WatsonProviderSpec(TypedDict, total=False):
|
||||
"""Watson provider specification."""
|
||||
|
||||
provider: Literal["watson"]
|
||||
provider: Required[Literal["watson"]]
|
||||
config: WatsonProviderConfig
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
"""IBM Watson embeddings provider."""
|
||||
|
||||
from ibm_watsonx_ai import ( # type: ignore[import-not-found,import-untyped]
|
||||
APIClient,
|
||||
Credentials,
|
||||
)
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
@@ -28,9 +26,7 @@ class WatsonProvider(BaseEmbeddingsProvider[WatsonEmbeddingFunction]):
|
||||
params: dict[str, str | dict[str, str]] | None = Field(
|
||||
default=None, description="Additional parameters"
|
||||
)
|
||||
credentials: Credentials | None = Field(
|
||||
default=None, description="Watson credentials"
|
||||
)
|
||||
credentials: Any | None = Field(default=None, description="Watson credentials")
|
||||
project_id: str | None = Field(
|
||||
default=None,
|
||||
description="Watson project ID",
|
||||
@@ -39,7 +35,7 @@ class WatsonProvider(BaseEmbeddingsProvider[WatsonEmbeddingFunction]):
|
||||
space_id: str | None = Field(
|
||||
default=None, description="Watson space ID", validation_alias="WATSON_SPACE_ID"
|
||||
)
|
||||
api_client: APIClient | None = Field(default=None, description="Watson API client")
|
||||
api_client: Any | None = Field(default=None, description="Watson API client")
|
||||
verify: bool | str | None = Field(
|
||||
default=None, description="SSL verification", validation_alias="WATSON_VERIFY"
|
||||
)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Type definitions for Instructor embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal, TypedDict
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class InstructorProviderConfig(TypedDict, total=False):
|
||||
@@ -11,8 +13,8 @@ class InstructorProviderConfig(TypedDict, total=False):
|
||||
instruction: str
|
||||
|
||||
|
||||
class InstructorProviderSpec(TypedDict):
|
||||
class InstructorProviderSpec(TypedDict, total=False):
|
||||
"""Instructor provider specification."""
|
||||
|
||||
provider: Literal["instructor"]
|
||||
provider: Required[Literal["instructor"]]
|
||||
config: InstructorProviderConfig
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Type definitions for Jina embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal, TypedDict
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class JinaProviderConfig(TypedDict, total=False):
|
||||
@@ -10,8 +12,8 @@ class JinaProviderConfig(TypedDict, total=False):
|
||||
model_name: Annotated[str, "jina-embeddings-v2-base-en"]
|
||||
|
||||
|
||||
class JinaProviderSpec(TypedDict):
|
||||
class JinaProviderSpec(TypedDict, total=False):
|
||||
"""Jina provider specification."""
|
||||
|
||||
provider: Literal["jina"]
|
||||
provider: Required[Literal["jina"]]
|
||||
config: JinaProviderConfig
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Type definitions for Microsoft Azure embedding providers."""
|
||||
|
||||
from typing import Annotated, Any, Literal, TypedDict
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class AzureProviderConfig(TypedDict, total=False):
|
||||
@@ -17,8 +19,8 @@ class AzureProviderConfig(TypedDict, total=False):
|
||||
organization_id: str
|
||||
|
||||
|
||||
class AzureProviderSpec(TypedDict):
|
||||
class AzureProviderSpec(TypedDict, total=False):
|
||||
"""Azure provider specification."""
|
||||
|
||||
provider: Literal["azure"]
|
||||
provider: Required[Literal["azure"]]
|
||||
config: AzureProviderConfig
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Type definitions for Ollama embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal, TypedDict
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class OllamaProviderConfig(TypedDict, total=False):
|
||||
@@ -10,8 +12,8 @@ class OllamaProviderConfig(TypedDict, total=False):
|
||||
model_name: str
|
||||
|
||||
|
||||
class OllamaProviderSpec(TypedDict):
|
||||
class OllamaProviderSpec(TypedDict, total=False):
|
||||
"""Ollama provider specification."""
|
||||
|
||||
provider: Literal["ollama"]
|
||||
provider: Required[Literal["ollama"]]
|
||||
config: OllamaProviderConfig
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Type definitions for ONNX embedding providers."""
|
||||
|
||||
from typing import Literal, TypedDict
|
||||
from typing import Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class ONNXProviderConfig(TypedDict, total=False):
|
||||
@@ -9,8 +11,8 @@ class ONNXProviderConfig(TypedDict, total=False):
|
||||
preferred_providers: list[str]
|
||||
|
||||
|
||||
class ONNXProviderSpec(TypedDict):
|
||||
class ONNXProviderSpec(TypedDict, total=False):
|
||||
"""ONNX provider specification."""
|
||||
|
||||
provider: Literal["onnx"]
|
||||
provider: Required[Literal["onnx"]]
|
||||
config: ONNXProviderConfig
|
||||
|
||||
@@ -17,8 +17,8 @@ class OpenAIProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
|
||||
default=OpenAIEmbeddingFunction,
|
||||
description="OpenAI embedding function class",
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="OpenAI API key", validation_alias="OPENAI_API_KEY"
|
||||
api_key: str | None = Field(
|
||||
default=None, description="OpenAI API key", validation_alias="OPENAI_API_KEY"
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="text-embedding-ada-002",
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Type definitions for OpenAI embedding providers."""
|
||||
|
||||
from typing import Annotated, Any, Literal, TypedDict
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class OpenAIProviderConfig(TypedDict, total=False):
|
||||
@@ -17,8 +19,8 @@ class OpenAIProviderConfig(TypedDict, total=False):
|
||||
organization_id: str
|
||||
|
||||
|
||||
class OpenAIProviderSpec(TypedDict):
|
||||
class OpenAIProviderSpec(TypedDict, total=False):
|
||||
"""OpenAI provider specification."""
|
||||
|
||||
provider: Literal["openai"]
|
||||
provider: Required[Literal["openai"]]
|
||||
config: OpenAIProviderConfig
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Type definitions for OpenCLIP embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal, TypedDict
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class OpenCLIPProviderConfig(TypedDict, total=False):
|
||||
@@ -14,5 +16,5 @@ class OpenCLIPProviderConfig(TypedDict, total=False):
|
||||
class OpenCLIPProviderSpec(TypedDict):
|
||||
"""OpenCLIP provider specification."""
|
||||
|
||||
provider: Literal["openclip"]
|
||||
provider: Required[Literal["openclip"]]
|
||||
config: OpenCLIPProviderConfig
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Type definitions for Roboflow embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal, TypedDict
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class RoboflowProviderConfig(TypedDict, total=False):
|
||||
@@ -13,5 +15,5 @@ class RoboflowProviderConfig(TypedDict, total=False):
|
||||
class RoboflowProviderSpec(TypedDict):
|
||||
"""Roboflow provider specification."""
|
||||
|
||||
provider: Literal["roboflow"]
|
||||
provider: Required[Literal["roboflow"]]
|
||||
config: RoboflowProviderConfig
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Type definitions for SentenceTransformer embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal, TypedDict
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class SentenceTransformerProviderConfig(TypedDict, total=False):
|
||||
@@ -14,5 +16,5 @@ class SentenceTransformerProviderConfig(TypedDict, total=False):
|
||||
class SentenceTransformerProviderSpec(TypedDict):
|
||||
"""SentenceTransformer provider specification."""
|
||||
|
||||
provider: Literal["sentence-transformer"]
|
||||
provider: Required[Literal["sentence-transformer"]]
|
||||
config: SentenceTransformerProviderConfig
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Type definitions for Text2Vec embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal, TypedDict
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class Text2VecProviderConfig(TypedDict, total=False):
|
||||
@@ -12,5 +14,5 @@ class Text2VecProviderConfig(TypedDict, total=False):
|
||||
class Text2VecProviderSpec(TypedDict):
|
||||
"""Text2Vec provider specification."""
|
||||
|
||||
provider: Literal["text2vec"]
|
||||
provider: Required[Literal["text2vec"]]
|
||||
config: Text2VecProviderConfig
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
from typing import cast
|
||||
|
||||
import voyageai
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
||||
@@ -19,6 +18,14 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
Args:
|
||||
**kwargs: Configuration parameters for VoyageAI.
|
||||
"""
|
||||
try:
|
||||
import voyageai # type: ignore[import-not-found]
|
||||
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"voyageai is required for voyageai embeddings. "
|
||||
"Install it with: uv add voyageai"
|
||||
) from e
|
||||
self._config = kwargs
|
||||
self._client = voyageai.Client(
|
||||
api_key=kwargs["api_key"],
|
||||
@@ -35,6 +42,7 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
Returns:
|
||||
List of embedding vectors.
|
||||
"""
|
||||
|
||||
if isinstance(input, str):
|
||||
input = [input]
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Type definitions for VoyageAI embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal, TypedDict
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class VoyageAIProviderConfig(TypedDict, total=False):
|
||||
@@ -19,5 +21,5 @@ class VoyageAIProviderConfig(TypedDict, total=False):
|
||||
class VoyageAIProviderSpec(TypedDict):
|
||||
"""VoyageAI provider specification."""
|
||||
|
||||
provider: Literal["voyageai"]
|
||||
provider: Required[Literal["voyageai"]]
|
||||
config: VoyageAIProviderConfig
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
"""Type definitions for the embeddings module."""
|
||||
|
||||
from typing import Literal
|
||||
from typing import Literal, TypeAlias
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.rag.embeddings.providers.aws.types import BedrockProviderSpec
|
||||
from crewai.rag.embeddings.providers.cohere.types import CohereProviderSpec
|
||||
from crewai.rag.embeddings.providers.custom.types import CustomProviderSpec
|
||||
@@ -66,3 +67,7 @@ AllowedEmbeddingProviders = Literal[
|
||||
"voyageai",
|
||||
"watson",
|
||||
]
|
||||
|
||||
EmbedderConfig: TypeAlias = (
|
||||
ProviderSpec | BaseEmbeddingsProvider | type[BaseEmbeddingsProvider]
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user