fix: add dict overload to build_embedder and type default embedder

This commit is contained in:
Greyson LaLonde
2026-02-26 16:04:28 -05:00
committed by GitHub
parent 86d3ee022d
commit 373abbb6b7
2 changed files with 21 additions and 8 deletions

View File

@@ -6,7 +6,7 @@ from concurrent.futures import Future, ThreadPoolExecutor
from datetime import datetime
import threading
import time
from typing import Any, Literal
from typing import TYPE_CHECKING, Any, Literal
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
@@ -30,13 +30,20 @@ from crewai.memory.types import (
compute_composite_score,
embed_text,
)
from crewai.rag.embeddings.factory import build_embedder
from crewai.rag.embeddings.providers.openai.types import OpenAIProviderSpec
def _default_embedder() -> Any:
if TYPE_CHECKING:
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
def _default_embedder() -> OpenAIEmbeddingFunction:
"""Build default OpenAI embedder for memory."""
from crewai.rag.embeddings.factory import build_embedder
return build_embedder({"provider": "openai", "config": {}})
spec: OpenAIProviderSpec = {"provider": "openai", "config": {}}
return build_embedder(spec)
class Memory:
@@ -194,9 +201,7 @@ class Memory:
if self._embedder_instance is None:
try:
if isinstance(self._embedder_config, dict):
from crewai.rag.embeddings.factory import build_embedder
self._embedder_instance = build_embedder(self._embedder_config) # type: ignore[call-overload]
self._embedder_instance = build_embedder(self._embedder_config)
else:
self._embedder_instance = _default_embedder()
except Exception as e:

View File

@@ -216,6 +216,10 @@ def build_embedder_from_dict(
def build_embedder_from_dict(spec: ONNXProviderSpec) -> ONNXMiniLM_L6_V2: ...
@overload
def build_embedder_from_dict(spec: dict[str, Any]) -> EmbeddingFunction[Any]: ...
def build_embedder_from_dict(spec): # type: ignore[no-untyped-def]
"""Build an embedding function instance from a dictionary specification.
@@ -341,6 +345,10 @@ def build_embedder(spec: Text2VecProviderSpec) -> Text2VecEmbeddingFunction: ...
def build_embedder(spec: ONNXProviderSpec) -> ONNXMiniLM_L6_V2: ...
@overload
def build_embedder(spec: dict[str, Any]) -> EmbeddingFunction[Any]: ...
def build_embedder(spec): # type: ignore[no-untyped-def]
"""Build an embedding function from either a provider spec or a provider instance.