From 1cc251b4b8d2ddb9523210219d9454d3f476a9ab Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Wed, 25 Mar 2026 23:42:09 +0800 Subject: [PATCH] feat: add Qdrant Edge storage backend for memory system --- lib/crewai/pyproject.toml | 3 + .../memory/storage/qdrant_edge_storage.py | 872 ++++++++++++++++++ .../src/crewai/memory/unified_memory.py | 21 +- .../tests/memory/test_qdrant_edge_storage.py | 353 +++++++ uv.lock | 27 +- 5 files changed, 1268 insertions(+), 8 deletions(-) create mode 100644 lib/crewai/src/crewai/memory/storage/qdrant_edge_storage.py create mode 100644 lib/crewai/tests/memory/test_qdrant_edge_storage.py diff --git a/lib/crewai/pyproject.toml b/lib/crewai/pyproject.toml index 0b52b26bc..2a80698b5 100644 --- a/lib/crewai/pyproject.toml +++ b/lib/crewai/pyproject.toml @@ -106,6 +106,9 @@ a2a = [ file-processing = [ "crewai-files", ] +qdrant-edge = [ + "qdrant-edge-py>=0.6.0", +] [project.scripts] diff --git a/lib/crewai/src/crewai/memory/storage/qdrant_edge_storage.py b/lib/crewai/src/crewai/memory/storage/qdrant_edge_storage.py new file mode 100644 index 000000000..f20faa408 --- /dev/null +++ b/lib/crewai/src/crewai/memory/storage/qdrant_edge_storage.py @@ -0,0 +1,872 @@ +"""Qdrant Edge storage backend for the unified memory system. + +Uses a write-local/sync-central pattern for safe multi-process access. +Each worker process writes to its own local shard (keyed by PID). Reads +fan out to both local and central shards, merging results. On close, +local records are flushed to the shared central shard. 
def _uuid_to_point_id(uuid_str: str) -> int:
    """Convert a record ID string to a stable, non-negative Qdrant point ID.

    A well-formed UUID string maps to its integer value reduced modulo
    ``2**63 - 1``. Any other string is hashed over its *entire* UTF-8
    encoding, so IDs that merely share a prefix (``"record-0001"`` and
    ``"record-0002"``) no longer land on the same point and silently
    overwrite each other. The previous fallback only looked at the first
    eight bytes.

    Args:
        uuid_str: Record identifier; normally a UUID string.

    Returns:
        An integer in ``[0, 2**63 - 1)`` that is the same for the same
        input in every process and interpreter run.
    """
    # Function-scope import keeps this helper self-contained.
    import hashlib

    modulus = 2**63 - 1
    try:
        return uuid.UUID(uuid_str).int % modulus
    except ValueError:
        # blake2b is deterministic across processes, unlike the built-in
        # hash(), which is salted per interpreter run.
        digest = hashlib.blake2b(uuid_str.encode("utf-8"), digest_size=8).digest()
        return int.from_bytes(digest, "big") % modulus


def _build_scope_ancestors(scope: str) -> list[str]:
    """Return *scope* and every ancestor scope, root first.

    The result is stored next to each record so that a prefix query can be
    answered with an exact keyword match on any ancestor.

    For scope ``/crew/sales/agent`` this returns
    ``["/", "/crew", "/crew/sales", "/crew/sales/agent"]``. Empty path
    segments (leading, trailing or doubled slashes) are ignored, so ``""``
    and ``"/"`` both yield ``["/"]``.
    """
    ancestors: list[str] = ["/"]
    path = ""
    for segment in scope.strip("/").split("/"):
        if not segment:
            continue
        path = f"{path}/{segment}"
        ancestors.append(path)
    return ancestors
Defaults to + ``$CREWAI_STORAGE_DIR/memory/qdrant-edge`` or the + platform data directory. + vector_dim: Embedding vector dimensionality. Auto-detected + from the first saved embedding when ``None``. + """ + if path is None: + storage_dir = os.environ.get("CREWAI_STORAGE_DIR") + if storage_dir: + path = Path(storage_dir) / "memory" / "qdrant-edge" + else: + from crewai.utilities.paths import db_storage_path + + path = Path(db_storage_path()) / "memory" / "qdrant-edge" + + self._base_path = Path(path) + self._central_path = self._base_path / "central" + self._local_path = self._base_path / f"worker-{os.getpid()}" + self._vector_dim = vector_dim or 0 + self._config: EdgeConfig | None = None + self._local_has_data = self._local_path.exists() + self._closed = False + self._indexes_created = False + + if self._vector_dim > 0: + self._config = self._build_config(self._vector_dim) + + if self._config is None and self._central_path.exists(): + try: + shard = EdgeShard.load(str(self._central_path)) + if shard.count(CountRequest()) > 0: + pts, _ = shard.scroll( + ScrollRequest(limit=1, with_payload=False, with_vector=True) + ) + if pts and pts[0].vector: + vec = pts[0].vector + if isinstance(vec, dict) and VECTOR_NAME in vec: + vec_data = vec[VECTOR_NAME] + dim = len(vec_data) if isinstance(vec_data, list) else 0 + if dim > 0: + self._vector_dim = dim + self._config = self._build_config(dim) + shard.close() + except Exception: + _logger.debug("Failed to detect dim from central shard", exc_info=True) + + self._cleanup_orphaned_shards() + atexit.register(self.close) + + @staticmethod + def _build_config(dim: int) -> EdgeConfig: + """Build an EdgeConfig for the given vector dimensionality.""" + return EdgeConfig( + vectors={VECTOR_NAME: EdgeVectorParams(size=dim, distance=Distance.Cosine)}, + ) + + def _open_shard(self, path: Path) -> EdgeShard: + """Open an existing shard or create a new one at *path*.""" + path.mkdir(parents=True, exist_ok=True) + try: + return 
def _ensure_indexes(self, shard: EdgeShard) -> None:
    """Create the keyword payload indexes used for filtering on *shard*.

    Indexes are requested on every call instead of being gated by the
    instance-wide ``_indexes_created`` flag. That flag is shared by the
    local and the central shard, so once one of them had been indexed the
    other was skipped for the life of the process.

    Each index is attempted on its own, so a failure on one field does not
    prevent the others from being created.

    NOTE(review): this assumes re-requesting an existing index is harmless.
    ``list_categories`` in this file already re-creates the ``categories``
    index on every call inside a try/except - confirm against the
    qdrant-edge API before relying on it for large shards.

    Args:
        shard: An open shard; it is not closed here.
    """
    all_created = True
    for field_name in ("scope_ancestors", "categories", "record_id"):
        try:
            shard.update(
                UpdateOperation.create_field_index(
                    field_name, PayloadSchemaType.Keyword
                )
            )
        except Exception:
            all_created = False
            _logger.debug("Index creation failed (may already exist)", exc_info=True)
    if all_created:
        # Kept for code that still reads or resets the flag; it no longer
        # short-circuits index creation.
        self._indexes_created = True
categories=payload.get("categories", []), + metadata=payload.get("metadata", {}), + importance=float(payload.get("importance", 0.5)), + created_at=_parse_dt(payload.get("created_at")), + last_accessed=_parse_dt(payload.get("last_accessed")), + embedding=vector.get(VECTOR_NAME) if vector else None, + source=payload.get("source") or None, + private=bool(payload.get("private", False)), + ) + + @staticmethod + def _build_scope_filter(scope_prefix: str | None) -> Filter | None: + """Build a Qdrant Filter for scope prefix matching.""" + if scope_prefix is None or not scope_prefix.strip("/"): + return None + prefix = scope_prefix.rstrip("/") + if not prefix.startswith("/"): + prefix = "/" + prefix + return Filter( + must=[FieldCondition(key="scope_ancestors", match=MatchValue(value=prefix))] + ) + + @staticmethod + def _scroll_all( + shard: EdgeShard, + filt: Filter | None = None, + with_vector: bool = False, + ) -> list[Any]: + """Scroll all points matching a filter from a shard.""" + all_points: list[Any] = [] + offset = None + while True: + batch, next_offset = shard.scroll( + ScrollRequest( + limit=_SCROLL_BATCH, + offset=offset, + with_payload=True, + with_vector=with_vector, + filter=filt, + ) + ) + all_points.extend(batch) + if next_offset is None or not batch: + break + offset = next_offset + return all_points + + def save(self, records: list[MemoryRecord]) -> None: + """Save records to the worker-local shard.""" + if not records: + return + + if self._vector_dim == 0: + for r in records: + if r.embedding and len(r.embedding) > 0: + self._vector_dim = len(r.embedding) + break + if self._config is None and self._vector_dim > 0: + self._config = self._build_config(self._vector_dim) + if self._config is None: + self._config = self._build_config(DEFAULT_VECTOR_DIM) + self._vector_dim = DEFAULT_VECTOR_DIM + + points = [self._record_to_point(r) for r in records] + local = self._open_shard(self._local_path) + try: + self._ensure_indexes(local) + 
def search(
    self,
    query_embedding: list[float],
    scope_prefix: str | None = None,
    categories: list[str] | None = None,
    metadata_filter: dict[str, Any] | None = None,
    limit: int = 10,
    min_score: float = 0.0,
) -> list[tuple[MemoryRecord, float]]:
    """Run a nearest-neighbour search over the central and local shards.

    Results from both shards are merged by ``record_id``. When a record is
    present in both, the local copy wins, because writes and updates only
    go to the local shard until it is flushed.

    Args:
        query_embedding: Query vector.
        scope_prefix: Only return records at or below this scope.
        categories: Keep records that have at least one of these categories.
        metadata_filter: Keep records whose metadata matches every key/value.
        limit: Maximum number of results; non-positive values return ``[]``.
        min_score: Drop results scoring below this value.

    Returns:
        ``(record, score)`` pairs, best score first.
    """
    if limit <= 0:
        return []

    filt = self._build_scope_filter(scope_prefix)
    # Category and metadata filters are applied in Python after the vector
    # query, so over-fetch to leave room for rejected candidates.
    fetch_limit = limit * 3 if (categories or metadata_filter) else limit

    merged: dict[str, tuple[dict[str, Any], float]] = {}
    # Central first, local second, so local hits overwrite stale central ones.
    for shard_path in (self._central_path, self._local_path):
        if not shard_path.exists():
            continue
        try:
            shard = EdgeShard.load(str(shard_path))
            try:
                results = shard.query(
                    QueryRequest(
                        query=Query.Nearest(list(query_embedding), using=VECTOR_NAME),
                        filter=filt,
                        limit=fetch_limit,
                        with_payload=True,
                        with_vector=False,
                    )
                )
                # Copy out what is needed while the shard is still open.
                hits = [(sp.payload or {}, float(sp.score)) for sp in results]
            finally:
                # Always release the shard, even when the query raises.
                shard.close()
        except Exception:
            _logger.debug("Search failed on %s", shard_path, exc_info=True)
            continue

        is_local = shard_path == self._local_path
        for payload, score in hits:
            rid = payload.get("record_id")
            if rid is None:
                # A point without a record_id cannot be merged or rebuilt;
                # skip it rather than fail the whole search.
                continue
            previous = merged.get(rid)
            if is_local or previous is None or score > previous[1]:
                merged[rid] = (payload, score)

    out: list[tuple[MemoryRecord, float]] = []
    for payload, score in sorted(
        merged.values(), key=lambda hit: hit[1], reverse=True
    ):
        if score < min_score:
            # Sorted best-first: nothing after this can pass the threshold.
            break
        record = self._payload_to_record(payload)
        if categories and not any(c in record.categories for c in categories):
            continue
        if metadata_filter and not all(
            record.metadata.get(k) == v for k, v in metadata_filter.items()
        ):
            continue
        out.append((record, score))
        if len(out) >= limit:
            break
    return out
def update(self, record: MemoryRecord) -> None:
    """Update a record by upserting it under the same point ID.

    The point ID is derived from ``record.id``, so writing the record again
    replaces the previous version in the worker-local shard. This is the
    same operation as saving a one-element batch, so it delegates to
    ``save`` rather than repeating the dimension detection, config
    creation, index setup and upsert steps.

    Args:
        record: The record to write; its ``id`` selects the point to replace.
    """
    self.save([record])
def get_scope_info(self, scope: str) -> ScopeInfo:
    """Summarise the records stored at or below *scope*.

    Both shards are scanned and records are de-duplicated by ``record_id``.
    The local shard is read first so that, when a record exists in both,
    the local copy is the one summarised. ``update`` only writes locally,
    and ``search``, ``get_record``, ``list_records`` and ``count`` already
    prefer the local shard.

    Args:
        scope: Scope path; a trailing slash is ignored and ``""`` means ``/``.

    Returns:
        A ``ScopeInfo`` with the record count, sorted distinct categories,
        oldest and newest ``created_at``, and sorted immediate child scopes.
        All fields are empty/``None`` when nothing matches.
    """
    scope = scope.rstrip("/") or "/"
    filt = self._build_scope_filter(None if scope == "/" else scope)

    payloads: dict[str, dict[str, Any]] = {}
    for shard_path in (self._local_path, self._central_path):
        if not shard_path.exists():
            continue
        try:
            shard = EdgeShard.load(str(shard_path))
            try:
                points = self._scroll_all(shard, filt=filt)
            finally:
                # Release the shard even if scrolling fails.
                shard.close()
        except Exception:
            _logger.debug("get_scope_info failed on %s", shard_path, exc_info=True)
            continue
        for pt in points:
            payload = pt.payload or {}
            rid = payload.get("record_id")
            # First occurrence wins, i.e. the local copy.
            if rid is not None and rid not in payloads:
                payloads[rid] = payload

    child_prefix = "/" if scope == "/" else scope + "/"
    categories_set: set[str] = set()
    children: set[str] = set()
    oldest: datetime | None = None
    newest: datetime | None = None

    for payload in payloads.values():
        sc = str(payload.get("scope", ""))
        if sc.startswith(child_prefix):
            # Only the first path component below *scope* names a child.
            first_component = sc[len(child_prefix) :].split("/", 1)[0]
            if first_component:
                children.add(child_prefix + first_component)
        categories_set.update(payload.get("categories") or [])
        created = payload.get("created_at")
        if created:
            dt = datetime.fromisoformat(str(created).replace("Z", "+00:00"))
            if oldest is None or dt < oldest:
                oldest = dt
            if newest is None or dt > newest:
                newest = dt

    return ScopeInfo(
        path=scope,
        record_count=len(payloads),
        categories=sorted(categories_set),
        oldest_record=oldest,
        newest_record=newest,
        child_scopes=sorted(children),
    )
def count(self, scope_prefix: str | None = None) -> int:
    """Count records in a scope and all of its sub-scopes.

    When this worker has no local data the central shard's own counter is
    used. Otherwise both shards are scrolled and distinct ``record_id``
    values are counted, so a record present in both is counted once.

    Args:
        scope_prefix: Scope to count under; ``None`` or ``/`` counts all.

    Returns:
        Number of matching records. A shard that cannot be read contributes
        nothing beyond what was collected before the failure.
    """
    filt = self._build_scope_filter(scope_prefix)

    if not self._local_has_data:
        if not self._central_path.exists():
            return 0
        try:
            shard = EdgeShard.load(str(self._central_path))
            try:
                return shard.count(CountRequest(filter=filt))
            finally:
                # Close the shard on success and on failure alike.
                shard.close()
        except Exception:
            _logger.debug("count failed on central", exc_info=True)
            return 0

    seen_ids: set[str] = set()
    for shard_path in (self._local_path, self._central_path):
        if not shard_path.exists():
            continue
        try:
            shard = EdgeShard.load(str(shard_path))
            try:
                for pt in self._scroll_all(shard, filt=filt):
                    rid = (pt.payload or {}).get("record_id")
                    # Skip a malformed point instead of abandoning the
                    # rest of the shard.
                    if rid is not None:
                        seen_ids.add(rid)
            finally:
                shard.close()
        except Exception:
            _logger.debug("count failed on %s", shard_path, exc_info=True)
    return len(seen_ids)
def touch_records(self, record_ids: list[str]) -> None:
    """Set ``last_accessed`` to the current time for the given records.

    The timestamp is a naive UTC ISO-8601 string. It is written to the
    matching points in both the central and the local shard, since a record
    may live in either or both. Failures on one shard are logged at debug
    level and do not stop the other from being updated.

    Args:
        record_ids: IDs of the records to touch; an empty list is a no-op.
    """
    if not record_ids:
        return

    now = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
    point_ids: list[int | uuid.UUID | str] = [
        _uuid_to_point_id(rid) for rid in record_ids
    ]
    for shard_path in (self._central_path, self._local_path):
        if not shard_path.exists():
            continue
        try:
            shard = EdgeShard.load(str(shard_path))
            try:
                shard.update(
                    UpdateOperation.set_payload(point_ids, {"last_accessed": now})
                )
                shard.flush()
            finally:
                # Release the shard even when the update or flush raises.
                shard.close()
        except Exception:
            _logger.debug("touch_records failed on %s", shard_path, exc_info=True)
def _cleanup_orphaned_shards(self) -> None:
    """Sync and remove local shards left behind by dead worker processes.

    Every ``worker-<pid>`` directory under the base path whose PID is not
    this process and is not running is treated as orphaned. Its points are
    upserted into the central shard and the directory is removed. A shard
    that cannot be recovered is left on disk and a warning is logged.

    Liveness is checked conservatively: if the probe is inconclusive the
    worker is assumed alive and its shard is left alone. A PID that has
    been reused by an unrelated process therefore delays recovery until
    that process exits.
    """

    def _worker_alive(pid: int) -> bool:
        """Return False only when *pid* is known not to be running."""
        if os.name == "nt":
            # On Windows os.kill(pid, 0) is not a probe: signal 0 is
            # CTRL_C_EVENT, so it would send Ctrl+C to that process group.
            # Query the process handle instead.
            try:
                import ctypes
                from ctypes import wintypes

                kernel32 = ctypes.WinDLL("kernel32", use_last_error=True)
                kernel32.OpenProcess.restype = wintypes.HANDLE
                kernel32.OpenProcess.argtypes = (
                    wintypes.DWORD,
                    wintypes.BOOL,
                    wintypes.DWORD,
                )
                kernel32.GetExitCodeProcess.argtypes = (
                    wintypes.HANDLE,
                    ctypes.POINTER(wintypes.DWORD),
                )
                kernel32.CloseHandle.argtypes = (wintypes.HANDLE,)

                # 0x1000 = PROCESS_QUERY_LIMITED_INFORMATION
                handle = kernel32.OpenProcess(0x1000, False, pid)
                if not handle:
                    # 87 = ERROR_INVALID_PARAMETER: no such process.
                    # Anything else (e.g. access denied) means it exists.
                    return ctypes.get_last_error() != 87
                try:
                    exit_code = wintypes.DWORD()
                    if not kernel32.GetExitCodeProcess(
                        handle, ctypes.byref(exit_code)
                    ):
                        return True
                    return exit_code.value == 259  # STILL_ACTIVE
                finally:
                    kernel32.CloseHandle(handle)
            except Exception:
                return True
        try:
            os.kill(pid, 0)
        except ProcessLookupError:
            return False
        except PermissionError:
            # The process exists but belongs to another user.
            return True
        except OverflowError:
            # Too large to be a real PID, so nothing can be running under it.
            return False
        except OSError:
            return True
        return True

    if not self._base_path.exists():
        return

    own_pid = os.getpid()
    # Snapshot the listing: recovered directories are deleted in the loop.
    for entry in list(self._base_path.iterdir()):
        if not entry.is_dir() or not entry.name.startswith("worker-"):
            continue
        try:
            pid = int(entry.name.removeprefix("worker-"))
        except ValueError:
            continue
        # Non-positive values address process groups in os.kill and are
        # never produced by os.getpid(), so they are not worker shards.
        if pid <= 0 or pid == own_pid or _worker_alive(pid):
            continue

        _logger.debug("Worker %d is dead, shard is orphaned", pid)
        _logger.info("Cleaning up orphaned shard for dead worker %d", pid)
        try:
            orphan = EdgeShard.load(str(entry))
            try:
                points = self._scroll_all(orphan, with_vector=True)
            finally:
                orphan.close()

            if not points:
                shutil.rmtree(entry, ignore_errors=True)
                continue

            if self._config is None:
                # Learn the vector size from the orphan's own data so the
                # central shard can be created if it does not exist yet.
                for pt in points:
                    vec = pt.vector
                    if isinstance(vec, dict) and VECTOR_NAME in vec:
                        vec_data = vec[VECTOR_NAME]
                        if isinstance(vec_data, list) and len(vec_data) > 0:
                            self._vector_dim = len(vec_data)
                            self._config = self._build_config(self._vector_dim)
                            break

            if self._config is None:
                _logger.warning(
                    "Cannot recover orphaned shard %s: vector dimension unknown",
                    entry,
                )
                continue

            # Delete the orphan only after its points are safely in central.
            self._upsert_to_central(points)
            shutil.rmtree(entry, ignore_errors=True)
        except Exception:
            _logger.warning(
                "Failed to recover orphaned shard %s", entry, exc_info=True
            )
crewai.memory.storage.lancedb_storage import LanceDBStorage + + self._storage = LanceDBStorage(path=self.storage) else: self._storage = self.storage @@ -293,8 +298,10 @@ class Memory(BaseModel): future.result() # blocks until done; re-raises exceptions def close(self) -> None: - """Drain pending saves and shut down the background thread pool.""" + """Drain pending saves, flush storage, and shut down the background thread pool.""" self.drain_writes() + if hasattr(self._storage, "close"): + self._storage.close() self._save_pool.shutdown(wait=True) def _encode_batch( diff --git a/lib/crewai/tests/memory/test_qdrant_edge_storage.py b/lib/crewai/tests/memory/test_qdrant_edge_storage.py new file mode 100644 index 000000000..a5b36c0a2 --- /dev/null +++ b/lib/crewai/tests/memory/test_qdrant_edge_storage.py @@ -0,0 +1,353 @@ +"""Tests for Qdrant Edge storage backend.""" + +from __future__ import annotations + +import importlib +from datetime import datetime, timedelta, timezone +from pathlib import Path +from typing import TYPE_CHECKING, Any +from unittest.mock import MagicMock + +import pytest + +pytestmark = pytest.mark.skipif( + importlib.util.find_spec("qdrant_edge") is None, + reason="qdrant-edge-py not installed", +) + +if TYPE_CHECKING: + from crewai.memory.storage.qdrant_edge_storage import QdrantEdgeStorage + +from crewai.memory.types import MemoryRecord + + +def _make_storage(path: str, vector_dim: int = 4) -> QdrantEdgeStorage: + from crewai.memory.storage.qdrant_edge_storage import QdrantEdgeStorage + + return QdrantEdgeStorage(path=path, vector_dim=vector_dim) + + +@pytest.fixture +def storage(tmp_path: Path) -> QdrantEdgeStorage: + return _make_storage(str(tmp_path / "edge")) + + +def _rec( + content: str = "test", + scope: str = "/", + categories: list[str] | None = None, + importance: float = 0.5, + embedding: list[float] | None = None, + metadata: dict | None = None, + created_at: datetime | None = None, +) -> MemoryRecord: + return MemoryRecord( + 
content=content, + scope=scope, + categories=categories or [], + importance=importance, + embedding=embedding or [0.1, 0.2, 0.3, 0.4], + metadata=metadata or {}, + **({"created_at": created_at} if created_at else {}), + ) + + +# --- Basic CRUD --- + + +def test_save_search(storage: QdrantEdgeStorage) -> None: + r = _rec(content="test content", scope="/foo", categories=["cat1"], importance=0.8) + storage.save([r]) + results = storage.search([0.1, 0.2, 0.3, 0.4], scope_prefix="/foo", limit=5) + assert len(results) == 1 + rec, score = results[0] + assert rec.content == "test content" + assert rec.scope == "/foo" + assert score >= 0.0 + + +def test_delete_count(storage: QdrantEdgeStorage) -> None: + r = _rec(scope="/") + storage.save([r]) + assert storage.count() == 1 + n = storage.delete(scope_prefix="/") + assert n >= 1 + assert storage.count() == 0 + + +def test_update_get_record(storage: QdrantEdgeStorage) -> None: + r = _rec(content="original", scope="/a") + storage.save([r]) + r.content = "updated" + storage.update(r) + found = storage.get_record(r.id) + assert found is not None + assert found.content == "updated" + + +def test_get_record_not_found(storage: QdrantEdgeStorage) -> None: + assert storage.get_record("nonexistent-id") is None + + +# --- Scope operations --- + + +def test_list_scopes_get_scope_info(storage: QdrantEdgeStorage) -> None: + storage.save([ + _rec(content="a", scope="/"), + _rec(content="b", scope="/team"), + ]) + scopes = storage.list_scopes("/") + assert "/team" in scopes + info = storage.get_scope_info("/") + assert info.record_count >= 1 + assert info.path == "/" + + +def test_scope_prefix_filter(storage: QdrantEdgeStorage) -> None: + storage.save([ + _rec(content="sales note", scope="/crew/sales"), + _rec(content="eng note", scope="/crew/eng"), + _rec(content="other note", scope="/other"), + ]) + results = storage.search([0.1, 0.2, 0.3, 0.4], scope_prefix="/crew", limit=10) + assert len(results) == 2 + scopes = {r.scope for r, _ in 
results} + assert "/crew/sales" in scopes + assert "/crew/eng" in scopes + + +# --- Filtering --- + + +def test_category_filter(storage: QdrantEdgeStorage) -> None: + storage.save([ + _rec(content="cat1 item", categories=["cat1"]), + _rec(content="cat2 item", categories=["cat2"]), + ]) + results = storage.search( + [0.1, 0.2, 0.3, 0.4], categories=["cat1"], limit=10 + ) + assert len(results) == 1 + assert results[0][0].categories == ["cat1"] + + +def test_metadata_filter(storage: QdrantEdgeStorage) -> None: + storage.save([ + _rec(content="with key", metadata={"env": "prod"}), + _rec(content="without key", metadata={"env": "dev"}), + ]) + results = storage.search( + [0.1, 0.2, 0.3, 0.4], metadata_filter={"env": "prod"}, limit=10 + ) + assert len(results) == 1 + assert results[0][0].metadata["env"] == "prod" + + +# --- List & pagination --- + + +def test_list_records_pagination(storage: QdrantEdgeStorage) -> None: + records = [ + _rec( + content=f"item {i}", + created_at=datetime(2025, 1, 1) + timedelta(days=i), + ) + for i in range(5) + ] + storage.save(records) + page1 = storage.list_records(limit=2, offset=0) + page2 = storage.list_records(limit=2, offset=2) + assert len(page1) == 2 + assert len(page2) == 2 + # Newest first. 
+ assert page1[0].created_at >= page1[1].created_at + + +def test_list_categories(storage: QdrantEdgeStorage) -> None: + storage.save([ + _rec(categories=["a", "b"]), + _rec(categories=["b", "c"]), + ]) + cats = storage.list_categories() + assert cats.get("b", 0) == 2 + assert cats.get("a", 0) >= 1 + assert cats.get("c", 0) >= 1 + + +# --- Touch & reset --- + + +def test_touch_records(storage: QdrantEdgeStorage) -> None: + r = _rec() + storage.save([r]) + before = storage.get_record(r.id) + assert before is not None + old_accessed = before.last_accessed + storage.touch_records([r.id]) + after = storage.get_record(r.id) + assert after is not None + assert after.last_accessed >= old_accessed + + +def test_reset_full(storage: QdrantEdgeStorage) -> None: + storage.save([_rec(scope="/a"), _rec(scope="/b")]) + assert storage.count() == 2 + storage.reset() + assert storage.count() == 0 + + +def test_reset_scoped(storage: QdrantEdgeStorage) -> None: + storage.save([_rec(scope="/a"), _rec(scope="/b")]) + storage.reset(scope_prefix="/a") + assert storage.count() == 1 + + +# --- Dual-shard & sync --- + + +def test_flush_to_central(tmp_path: Path) -> None: + s = _make_storage(str(tmp_path / "edge")) + s.save([_rec(content="to sync")]) + assert s._local_has_data + s.flush_to_central() + assert not s._local_has_data + assert not s._local_path.exists() + # Central should have the record. + assert s.count() == 1 + + +def test_dual_shard_search(tmp_path: Path) -> None: + s = _make_storage(str(tmp_path / "edge")) + # Save and flush to central. + s.save([_rec(content="central record", scope="/a")]) + s.flush_to_central() + # Save to local only. + s._closed = False # Reset for continued use. + s.save([_rec(content="local record", scope="/b")]) + # Search should find both. 
+ results = s.search([0.1, 0.2, 0.3, 0.4], limit=10) + assert len(results) == 2 + contents = {r.content for r, _ in results} + assert "central record" in contents + assert "local record" in contents + + +def test_close_lifecycle(tmp_path: Path) -> None: + s = _make_storage(str(tmp_path / "edge")) + s.save([_rec(content="persisted")]) + s.close() + # Reopen a new storage — should find the record in central. + s2 = _make_storage(str(tmp_path / "edge")) + results = s2.search([0.1, 0.2, 0.3, 0.4], limit=5) + assert len(results) == 1 + assert results[0][0].content == "persisted" + s2.close() + + +def test_orphaned_shard_cleanup(tmp_path: Path) -> None: + base = tmp_path / "edge" + # Create a fake orphaned shard using a PID that doesn't exist. + fake_pid = 99999999 + s1 = _make_storage(str(base)) + # Manually create a shard at the orphaned path. + orphan_path = base / f"worker-{fake_pid}" + orphan_path.mkdir(parents=True, exist_ok=True) + from qdrant_edge import ( + EdgeConfig, + EdgeShard, + EdgeVectorParams, + Distance, + Point, + UpdateOperation, + ) + + config = EdgeConfig( + vectors={"memory": EdgeVectorParams(size=4, distance=Distance.Cosine)} + ) + orphan = EdgeShard.create(str(orphan_path), config) + orphan.update( + UpdateOperation.upsert_points([ + Point( + id=12345, + vector={"memory": [0.5, 0.5, 0.5, 0.5]}, + payload={ + "record_id": "orphan-uuid", + "content": "orphaned", + "scope": "/", + "scope_ancestors": ["/"], + "categories": [], + "metadata": {}, + "importance": 0.5, + "created_at": datetime.now(timezone.utc).replace(tzinfo=None).isoformat(), + "last_accessed": datetime.now(timezone.utc).replace(tzinfo=None).isoformat(), + "source": "", + "private": False, + }, + ) + ]) + ) + orphan.flush() + orphan.close() + s1.close() + + # Creating a new storage should detect and recover the orphaned shard. + s2 = _make_storage(str(base)) + assert not orphan_path.exists() + # The orphaned record should now be in central. 
+ results = s2.search([0.5, 0.5, 0.5, 0.5], limit=5) + assert len(results) >= 1 + assert any(r.content == "orphaned" for r, _ in results) + s2.close() + + +# --- Integration with Memory class --- + + +def test_memory_with_qdrant_edge(tmp_path: Path) -> None: + from crewai.memory.unified_memory import Memory + + mock_embedder = MagicMock() + mock_embedder.side_effect = lambda texts: [[0.1, 0.2, 0.3, 0.4] for _ in texts] + + storage = _make_storage(str(tmp_path / "edge")) + m = Memory( + storage=storage, + llm=MagicMock(), + embedder=mock_embedder, + ) + r = m.remember( + "We decided to use Qdrant Edge.", + scope="/project", + categories=["decision"], + importance=0.7, + ) + assert r.content == "We decided to use Qdrant Edge." + + matches = m.recall("Qdrant", scope="/project", limit=5, depth="shallow") + assert len(matches) >= 1 + m.close() + + +def test_memory_string_storage_qdrant_edge(tmp_path: Path) -> None: + """Test that storage='qdrant-edge' string instantiation works.""" + import os + + os.environ["CREWAI_STORAGE_DIR"] = str(tmp_path) + try: + from crewai.memory.unified_memory import Memory + + mock_embedder = MagicMock() + mock_embedder.side_effect = lambda texts: [[0.1, 0.2, 0.3, 0.4] for _ in texts] + + m = Memory( + storage="qdrant-edge", + llm=MagicMock(), + embedder=mock_embedder, + ) + from crewai.memory.storage.qdrant_edge_storage import QdrantEdgeStorage + + assert isinstance(m._storage, QdrantEdgeStorage) + m.close() + finally: + os.environ.pop("CREWAI_STORAGE_DIR", None) diff --git a/uv.lock b/uv.lock index 5eed2bdca..ced50114f 100644 --- a/uv.lock +++ b/uv.lock @@ -1205,6 +1205,9 @@ pandas = [ qdrant = [ { name = "qdrant-client", extra = ["fastembed"] }, ] +qdrant-edge = [ + { name = "qdrant-edge-py" }, +] tools = [ { name = "crewai-tools" }, ] @@ -1259,6 +1262,7 @@ requires-dist = [ { name = "python-dotenv", specifier = "~=1.1.1" }, { name = "pyyaml", specifier = "~=6.0" }, { name = "qdrant-client", extras = ["fastembed"], marker = "extra == 
'qdrant'", specifier = "~=1.14.3" }, + { name = "qdrant-edge-py", marker = "extra == 'qdrant-edge'", specifier = ">=0.6.0" }, { name = "regex", specifier = "~=2026.1.15" }, { name = "textual", specifier = ">=7.5.0" }, { name = "tiktoken", marker = "extra == 'embeddings'", specifier = "~=0.8.0" }, @@ -1268,7 +1272,7 @@ requires-dist = [ { name = "uv", specifier = "~=0.9.13" }, { name = "voyageai", marker = "extra == 'voyageai'", specifier = "~=0.3.5" }, ] -provides-extras = ["a2a", "anthropic", "aws", "azure-ai-inference", "bedrock", "docling", "embeddings", "file-processing", "google-genai", "litellm", "mem0", "openpyxl", "pandas", "qdrant", "tools", "voyageai", "watson"] +provides-extras = ["a2a", "anthropic", "aws", "azure-ai-inference", "bedrock", "docling", "embeddings", "file-processing", "google-genai", "litellm", "mem0", "openpyxl", "pandas", "qdrant", "qdrant-edge", "tools", "voyageai", "watson"] [[package]] name = "crewai-devtools" @@ -6613,6 +6617,27 @@ fastembed = [ { name = "fastembed", version = "0.7.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.13'" }, ] +[[package]] +name = "qdrant-edge-py" +version = "0.6.0" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1c/72/fce3df4e4b8882b5b00ab3d0a574bbeee2d39a8e520ccf246f456effd185/qdrant_edge_py-0.6.0-cp310-abi3-macosx_10_12_x86_64.whl", hash = "sha256:c9d463e7fa81541d60ab8671e6e92a9afd8c4a0e2cfb7e13ea8f5d76e70b877a", size = 9728290, upload-time = "2026-03-19T21:16:15.03Z" }, + { url = "https://files.pythonhosted.org/packages/41/99/70f4e87f7f2ef68c5f92104b914c0e756c22b4bd19957de30a213dadff22/qdrant_edge_py-0.6.0-cp310-abi3-macosx_11_0_arm64.whl", hash = "sha256:a18b0bf0355260466bb8d453f2cedc7a9e4f6a2e9d9c58489b859150a3c7e0a6", size = 9203390, upload-time = "2026-03-19T21:16:17.255Z" }, + { url = 
"https://files.pythonhosted.org/packages/80/55/998ea744a4cef59c69e86b7b2b57ca2f2d4b0f86c212c7b43dd90cc6360e/qdrant_edge_py-0.6.0-cp310-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cda53f31d8693d090ec564e6761037f57af6f342ac2eef82e1c160c00d80f331", size = 10287388, upload-time = "2026-03-19T21:16:19.215Z" }, + { url = "https://files.pythonhosted.org/packages/40/d2/9e24a9c57699fe6df9a4f3b6cd0d4c3c9f0bfdbd502a28d25fdfadd44ab5/qdrant_edge_py-0.6.0-cp310-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:80c5e8f8cf650e422a3d313e394bde2760c6206914cd9d6142c9c5e730a76639", size = 9752632, upload-time = "2026-03-19T21:16:21.409Z" }, + { url = "https://files.pythonhosted.org/packages/0c/3c/a01840efcae392e5a376a483b9a19705ed0f5bc030befbe3d25b58a6d3d4/qdrant_edge_py-0.6.0-cp310-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:d2ab0d209f693fd0d5225072441ed47eccee4f7044470a293c54a3ffdf963cfc", size = 10287245, upload-time = "2026-03-19T21:16:24.366Z" }, + { url = "https://files.pythonhosted.org/packages/7a/45/a3ec5e7d36c5dd4510e4f90d0adaf6aa3e66cff35884ff3edefce240fd77/qdrant_edge_py-0.6.0-cp310-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:9abd0c3aedfed380d4c4a82626004b746bd05cb6a8e28e1b2fe7467726dc8840", size = 9935881, upload-time = "2026-03-19T21:16:26.384Z" }, + { url = "https://files.pythonhosted.org/packages/66/0d/43c9033fbb12f0858d5af73b842acb02b3208fe1a31882def2ef23fd560c/qdrant_edge_py-0.6.0-cp310-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:ea51a917fc1b927d799d60e166337b6837ee3da39c23d4dc736b82b67497ff12", size = 10507046, upload-time = "2026-03-19T21:16:28.536Z" }, + { url = "https://files.pythonhosted.org/packages/73/33/b2ead1c51a59d31d19418e6d6ca8ea3ce0f32f76efdd48248a1a3791357f/qdrant_edge_py-0.6.0-cp310-abi3-win_amd64.whl", hash = "sha256:d8376e30b53fbb5d9ac8b0aea683173096d7a775b351110aee4337460c906e71", size = 9905482, upload-time = "2026-03-19T21:16:30.555Z" }, + { url = 
"https://files.pythonhosted.org/packages/09/be/a054ac8902e942b0d44e27e8c0e4d3593a34bb143726aa3d9bebd215e7f7/qdrant_edge_py-0.6.0-pp311-pypy311_pp73-macosx_10_12_x86_64.whl", hash = "sha256:6e94804d9aa0c973fe25c83aec16da8c0f9e6a955a0cb1668bd972e1ca4b5604", size = 9724896, upload-time = "2026-03-19T21:16:32.793Z" }, + { url = "https://files.pythonhosted.org/packages/19/30/285eed25d8bab071b9867937b1e0fdc002c0c1180ff43476e5044029e73c/qdrant_edge_py-0.6.0-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:2ca40da1fa22ff4fd05e669d76c1087d3354486bcb685e9b07b1ca0ab5ef6b97", size = 9199009, upload-time = "2026-03-19T21:16:34.954Z" }, + { url = "https://files.pythonhosted.org/packages/41/d7/b729bbd887476a0a3040fc95d2548e519601d69b2f9d7ece83daf7958372/qdrant_edge_py-0.6.0-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:12fde5356eeb83ce8031a339ca73ea0a1a9b98927843f5bf7fa5c0412ca5ff79", size = 10279079, upload-time = "2026-03-19T21:16:36.876Z" }, + { url = "https://files.pythonhosted.org/packages/74/2e/68ef2346b6971b8b4d6b479099618dc2879d8c2e357065f8910aeb8b6ed5/qdrant_edge_py-0.6.0-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:c110af3ddbd4a5dae0421457e4a6f1f83c24411ea1187d557367ef5499cb6bef", size = 9746991, upload-time = "2026-03-19T21:16:38.968Z" }, + { url = "https://files.pythonhosted.org/packages/cd/46/3bfcc5e13d1a7d110a2d1ecf86c63a781e71e543712232be59d7a3f34e96/qdrant_edge_py-0.6.0-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:839651466c217bb8f684a3a0b9ad0726c670fcc734b552eef3ad76fbb4f5a12b", size = 10282664, upload-time = "2026-03-19T21:16:40.952Z" }, + { url = "https://files.pythonhosted.org/packages/80/54/7ba6bbaa2b53a188b0a43a6c063007e9a58afa3e35326f63518efbc6f5e8/qdrant_edge_py-0.6.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:c7665230dc4a2412412765fbdf9053e32b32f4c60579881ed68140b4d0ba6915", size = 9901015, upload-time = "2026-03-19T21:16:43.407Z" }, +] + [[package]] name = 
"questionary" version = "2.1.1"