diff --git a/lib/crewai/src/crewai/memory/unified_memory.py b/lib/crewai/src/crewai/memory/unified_memory.py index 531a91208..9d921fa1f 100644 --- a/lib/crewai/src/crewai/memory/unified_memory.py +++ b/lib/crewai/src/crewai/memory/unified_memory.py @@ -6,7 +6,7 @@ from concurrent.futures import Future, ThreadPoolExecutor from datetime import datetime import threading import time -from typing import Any, Literal +from typing import TYPE_CHECKING, Any, Literal from crewai.events.event_bus import crewai_event_bus from crewai.events.types.memory_events import ( @@ -30,13 +30,20 @@ from crewai.memory.types import ( compute_composite_score, embed_text, ) +from crewai.rag.embeddings.factory import build_embedder +from crewai.rag.embeddings.providers.openai.types import OpenAIProviderSpec -def _default_embedder() -> Any: +if TYPE_CHECKING: + from chromadb.utils.embedding_functions.openai_embedding_function import ( + OpenAIEmbeddingFunction, + ) + + +def _default_embedder() -> OpenAIEmbeddingFunction: """Build default OpenAI embedder for memory.""" - from crewai.rag.embeddings.factory import build_embedder - - return build_embedder({"provider": "openai", "config": {}}) + spec: OpenAIProviderSpec = {"provider": "openai", "config": {}} + return build_embedder(spec) class Memory: @@ -194,9 +201,7 @@ class Memory: if self._embedder_instance is None: try: if isinstance(self._embedder_config, dict): - from crewai.rag.embeddings.factory import build_embedder - - self._embedder_instance = build_embedder(self._embedder_config) # type: ignore[call-overload] + self._embedder_instance = build_embedder(self._embedder_config) else: self._embedder_instance = _default_embedder() except Exception as e: diff --git a/lib/crewai/src/crewai/rag/embeddings/factory.py b/lib/crewai/src/crewai/rag/embeddings/factory.py index 41a9233da..802779320 100644 --- a/lib/crewai/src/crewai/rag/embeddings/factory.py +++ b/lib/crewai/src/crewai/rag/embeddings/factory.py @@ -216,6 +216,10 @@ def build_embedder_from_dict( def build_embedder_from_dict(spec: ONNXProviderSpec) -> ONNXMiniLM_L6_V2: ... +@overload +def build_embedder_from_dict(spec: dict[str, Any]) -> EmbeddingFunction[Any]: ... + + def build_embedder_from_dict(spec): # type: ignore[no-untyped-def] """Build an embedding function instance from a dictionary specification. @@ -341,6 +345,10 @@ def build_embedder(spec: Text2VecProviderSpec) -> Text2VecEmbeddingFunction: ... def build_embedder(spec: ONNXProviderSpec) -> ONNXMiniLM_L6_V2: ... +@overload +def build_embedder(spec: dict[str, Any]) -> EmbeddingFunction[Any]: ... + + def build_embedder(spec): # type: ignore[no-untyped-def] """Build an embedding function from either a provider spec or a provider instance.