mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-30 23:02:50 +00:00
fix: add dict overload to build_embedder and type default embedder
This commit is contained in:
@@ -6,7 +6,7 @@ from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Literal
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
@@ -30,13 +30,20 @@ from crewai.memory.types import (
|
||||
compute_composite_score,
|
||||
embed_text,
|
||||
)
|
||||
from crewai.rag.embeddings.factory import build_embedder
|
||||
from crewai.rag.embeddings.providers.openai.types import OpenAIProviderSpec
|
||||
|
||||
|
||||
def _default_embedder() -> Any:
|
||||
if TYPE_CHECKING:
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
|
||||
|
||||
def _default_embedder() -> OpenAIEmbeddingFunction:
|
||||
"""Build default OpenAI embedder for memory."""
|
||||
from crewai.rag.embeddings.factory import build_embedder
|
||||
|
||||
return build_embedder({"provider": "openai", "config": {}})
|
||||
spec: OpenAIProviderSpec = {"provider": "openai", "config": {}}
|
||||
return build_embedder(spec)
|
||||
|
||||
|
||||
class Memory:
|
||||
@@ -194,9 +201,7 @@ class Memory:
|
||||
if self._embedder_instance is None:
|
||||
try:
|
||||
if isinstance(self._embedder_config, dict):
|
||||
from crewai.rag.embeddings.factory import build_embedder
|
||||
|
||||
self._embedder_instance = build_embedder(self._embedder_config) # type: ignore[call-overload]
|
||||
self._embedder_instance = build_embedder(self._embedder_config)
|
||||
else:
|
||||
self._embedder_instance = _default_embedder()
|
||||
except Exception as e:
|
||||
|
||||
@@ -216,6 +216,10 @@ def build_embedder_from_dict(
|
||||
def build_embedder_from_dict(spec: ONNXProviderSpec) -> ONNXMiniLM_L6_V2: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder_from_dict(spec: dict[str, Any]) -> EmbeddingFunction[Any]: ...
|
||||
|
||||
|
||||
def build_embedder_from_dict(spec): # type: ignore[no-untyped-def]
|
||||
"""Build an embedding function instance from a dictionary specification.
|
||||
|
||||
@@ -341,6 +345,10 @@ def build_embedder(spec: Text2VecProviderSpec) -> Text2VecEmbeddingFunction: ...
|
||||
def build_embedder(spec: ONNXProviderSpec) -> ONNXMiniLM_L6_V2: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder(spec: dict[str, Any]) -> EmbeddingFunction[Any]: ...
|
||||
|
||||
|
||||
def build_embedder(spec): # type: ignore[no-untyped-def]
|
||||
"""Build an embedding function from either a provider spec or a provider instance.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user