mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-04 00:32:36 +00:00
fix: rag tool embeddings config
* fix: ensure config is not flattened, add tests * chore: refactor inits to model_validator * chore: refactor rag tool config parsing * chore: add initial docs * chore: add additional validation aliases for provider env vars * chore: add solid docs * chore: move imports to top * fix: revert circular import * fix: lazy import qdrant-client * fix: allow collection name config * chore: narrow model names for google * chore: update additional docs * chore: add backward compat on model name aliases * chore: add tests for config changes
This commit is contained in:
@@ -1,28 +1,51 @@
|
||||
"""Adapter for CrewAI's native RAG system."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
from typing import Any, TypeAlias, TypedDict
|
||||
from typing import TYPE_CHECKING, Any, TypeAlias, TypedDict, cast
|
||||
import uuid
|
||||
|
||||
from crewai.rag.config.types import RagConfigType
|
||||
from crewai.rag.config.utils import get_rag_client
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.rag.factory import create_client
|
||||
from crewai.rag.qdrant.config import QdrantConfig
|
||||
from crewai.rag.types import BaseRecord, SearchResult
|
||||
from pydantic import PrivateAttr
|
||||
from qdrant_client.models import VectorParams
|
||||
from typing_extensions import Unpack
|
||||
from pydantic.dataclasses import is_pydantic_dataclass
|
||||
from typing_extensions import TypeIs, Unpack
|
||||
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
from crewai_tools.rag.misc import sanitize_metadata_for_chromadb
|
||||
from crewai_tools.tools.rag.rag_tool import Adapter
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.rag.qdrant.config import QdrantConfig
|
||||
|
||||
|
||||
ContentItem: TypeAlias = str | Path | dict[str, Any]
|
||||
|
||||
|
||||
def _is_qdrant_config(config: Any) -> TypeIs[QdrantConfig]:
|
||||
"""Check if config is a QdrantConfig using safe duck typing.
|
||||
|
||||
Args:
|
||||
config: RAG configuration to check.
|
||||
|
||||
Returns:
|
||||
True if config is a QdrantConfig instance.
|
||||
"""
|
||||
if not is_pydantic_dataclass(config):
|
||||
return False
|
||||
|
||||
try:
|
||||
return cast(bool, config.provider == "qdrant") # type: ignore[attr-defined]
|
||||
except (AttributeError, ImportError):
|
||||
return False
|
||||
|
||||
|
||||
class AddDocumentParams(TypedDict, total=False):
|
||||
"""Parameters for adding documents to the RAG system."""
|
||||
|
||||
@@ -56,8 +79,9 @@ class CrewAIRagAdapter(Adapter):
|
||||
else:
|
||||
self._client = get_rag_client()
|
||||
collection_params: dict[str, Any] = {"collection_name": self.collection_name}
|
||||
if isinstance(self.config, QdrantConfig) and self.config.vectors_config:
|
||||
if isinstance(self.config.vectors_config, VectorParams):
|
||||
|
||||
if self.config is not None and _is_qdrant_config(self.config):
|
||||
if self.config.vectors_config is not None:
|
||||
collection_params["vectors_config"] = self.config.vectors_config
|
||||
self._client.get_or_create_collection(**collection_params)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user