mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-04 06:29:22 +00:00
feat(valkey): shared cache config + ValkeyCache for A2A and file uploads
Extract duplicated Redis URL parsing into a shared cache_config utility. Introduce ValkeyCache as a lightweight async key/value cache using valkey-glide. Wire it into A2A task handling, agent card caching, and file upload caching. Part 1/4 of Valkey storage implementation. fix: async-safe embeddings and resilient drain_writes Add bytes→float validators on MemoryRecord and ItemState to handle Valkey returning embeddings as raw bytes. Make embed_texts() safe when called from an async context by using a thread pool. Improve drain_writes() with per-save timeouts and error logging instead of raising on failure. Part 3/4 of Valkey storage implementation. feat(valkey): ValkeyStorage vector memory backend Add ValkeyStorage, a distributed StorageBackend implementation using Valkey-GLIDE with Valkey Search for vector similarity. Wire it into Memory as the 'valkey' storage option. Pin scrapegraph-py<2 to fix unrelated upstream breakage. Part 4/4 of Valkey storage implementation. fix: use datetime.utcnow() for last_accessed consistency MemoryRecord defaults use utcnow() for created_at and last_accessed. Match that in ValkeyStorage.update_record() to avoid timezone inconsistency in recency scoring. feat(valkey): shared cache config + ValkeyCache for A2A and file uploads Extract duplicated Redis URL parsing into a shared cache_config utility. Introduce ValkeyCache as a lightweight async key/value cache using valkey-glide. Wire it into A2A task handling, agent card caching, and file upload caching. Part 1/4 of Valkey storage implementation. fix: handle non-numeric database path in cache URL parsing Extract _parse_db_from_path() helper that catches ValueError for paths like /mydb and defaults to 0 with a warning, instead of crashing. fix: async-safe embeddings and resilient drain_writes Add bytes→float validators on MemoryRecord and ItemState to handle Valkey returning embeddings as raw bytes. Make embed_texts() safe when called from an async context by using a thread pool. 
Improve drain_writes() with per-save timeouts and error logging instead of raising on failure. Part 3/4 of Valkey storage implementation. fix: catch concurrent.futures.TimeoutError for Python 3.10 compat In Python <3.11, concurrent.futures.TimeoutError is distinct from the builtin TimeoutError. Catch both so the timeout warning path works on all supported Python versions.
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
"""Cache for tracking uploaded files using aiocache."""
|
||||
"""Cache for tracking uploaded files using aiocache or ValkeyCache."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -10,10 +10,11 @@ from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
import hashlib
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, Protocol
|
||||
|
||||
from aiocache import Cache # type: ignore[import-untyped]
|
||||
from aiocache.serializers import PickleSerializer # type: ignore[import-untyped]
|
||||
from crewai.utilities.cache_config import parse_cache_url
|
||||
|
||||
from crewai_files.core.constants import DEFAULT_MAX_CACHE_ENTRIES, DEFAULT_TTL_SECONDS
|
||||
from crewai_files.uploaders.factory import ProviderType
|
||||
@@ -51,6 +52,33 @@ class CachedUpload:
|
||||
return False
|
||||
return datetime.now(timezone.utc) >= self.expires_at
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
    """Serialize this cached upload to a JSON-compatible dict.

    Datetime fields are emitted as ISO 8601 strings so the result can be
    passed straight to a JSON encoder.

    Returns:
        A dict with the keys ``file_id``, ``provider``, ``file_uri``,
        ``content_type``, ``uploaded_at`` and ``expires_at``.
        ``expires_at`` is ``None`` when no expiry is set.
    """
    return {
        "file_id": self.file_id,
        "provider": self.provider,
        "file_uri": self.file_uri,
        "content_type": self.content_type,
        "uploaded_at": self.uploaded_at.isoformat(),
        "expires_at": self.expires_at.isoformat() if self.expires_at else None,
    }
|
||||
|
||||
@classmethod
def from_dict(cls, data: dict[str, Any]) -> CachedUpload:
    """Build an instance from a dict shaped like the output of ``to_dict``.

    Args:
        data: Mapping with ``file_id``, ``provider``, ``content_type`` and
            ``uploaded_at`` (ISO 8601 string). ``file_uri`` and
            ``expires_at`` are optional.

    Returns:
        A new instance of ``cls`` populated from ``data``.

    Raises:
        KeyError: If a required key is missing from ``data``.
        ValueError: If a timestamp string is not valid ISO 8601.
    """
    return cls(
        file_id=data["file_id"],
        provider=data["provider"],
        file_uri=data.get("file_uri"),
        content_type=data["content_type"],
        uploaded_at=datetime.fromisoformat(data["uploaded_at"]),
        # A missing, None or empty expiry means "no expiry".
        expires_at=(
            datetime.fromisoformat(data["expires_at"])
            if data.get("expires_at")
            else None
        ),
    )
|
||||
|
||||
|
||||
def _make_key(file_hash: str, provider: str) -> str:
|
||||
"""Create a cache key from file hash and provider."""
|
||||
@@ -58,14 +86,7 @@ def _make_key(file_hash: str, provider: str) -> str:
|
||||
|
||||
|
||||
def _compute_file_hash_streaming(chunks: Iterator[bytes]) -> str:
|
||||
"""Compute SHA-256 hash from streaming chunks.
|
||||
|
||||
Args:
|
||||
chunks: Iterator of byte chunks.
|
||||
|
||||
Returns:
|
||||
Hexadecimal hash string.
|
||||
"""
|
||||
"""Compute SHA-256 hash from streaming chunks."""
|
||||
hasher = hashlib.sha256()
|
||||
for chunk in chunks:
|
||||
hasher.update(chunk)
|
||||
@@ -73,10 +94,7 @@ def _compute_file_hash_streaming(chunks: Iterator[bytes]) -> str:
|
||||
|
||||
|
||||
def _compute_file_hash(file: FileInput) -> str:
|
||||
"""Compute SHA-256 hash of file content.
|
||||
|
||||
Uses streaming for FilePath sources to avoid loading large files into memory.
|
||||
"""
|
||||
"""Compute SHA-256 hash of file content."""
|
||||
from crewai_files.core.sources import FilePath
|
||||
|
||||
source = file._file_source
|
||||
@@ -86,10 +104,73 @@ def _compute_file_hash(file: FileInput) -> str:
|
||||
return hashlib.sha256(content).hexdigest()
|
||||
|
||||
|
||||
class UploadCache:
|
||||
"""Async cache for tracking uploaded files using aiocache.
|
||||
class CacheBackend(Protocol):
|
||||
"""Protocol for cache backends used by UploadCache."""
|
||||
|
||||
Supports in-memory caching by default, with optional Redis backend
|
||||
async def get(self, key: str) -> CachedUpload | None: ...
|
||||
async def set(self, key: str, value: CachedUpload, ttl: int) -> None: ...
|
||||
async def delete(self, key: str) -> bool: ...
|
||||
|
||||
|
||||
class AiocacheBackend:
    """Cache backend backed by aiocache (memory or Redis).

    Adapts an aiocache ``Cache`` instance to the get/set/delete surface
    that ``UploadCache`` expects from its backend.
    """

    def __init__(self, cache: Cache) -> None:  # type: ignore[no-any-unimported]
        """Wrap an already-constructed aiocache ``Cache`` instance."""
        self._cache = cache

    async def get(self, key: str) -> CachedUpload | None:
        """Return the ``CachedUpload`` stored under ``key``.

        A miss, or any stored value that is not a ``CachedUpload``, is
        reported as ``None``.
        """
        entry = await self._cache.get(key)
        return entry if isinstance(entry, CachedUpload) else None

    async def set(self, key: str, value: CachedUpload, ttl: int) -> None:
        """Store ``value`` under ``key`` with a TTL in seconds."""
        await self._cache.set(key, value, ttl=ttl)

    async def delete(self, key: str) -> bool:
        """Delete ``key`` and report whether the underlying cache removed it.

        The underlying ``delete`` result is normalised to a bool: an
        integer result counts as removed only when positive, and any
        other result is coerced with ``bool()``.
        """
        outcome = await self._cache.delete(key)
        if isinstance(outcome, int):
            return outcome > 0
        return bool(outcome)
|
||||
|
||||
|
||||
class ValkeyCacheBackend:
    """Cache backend backed by ValkeyCache (JSON serialization).

    Entries are stored as the dict produced by ``CachedUpload.to_dict``
    and rebuilt with ``CachedUpload.from_dict`` on read.
    """

    def __init__(
        self,
        host: str = "localhost",
        port: int = 6379,
        db: int = 0,
        password: str | None = None,
        default_ttl: int | None = None,
    ) -> None:
        """Create the underlying ``ValkeyCache``.

        Args:
            host: Valkey server hostname.
            port: Valkey server port.
            db: Database number.
            password: Optional password.
            default_ttl: Default TTL in seconds passed to ``ValkeyCache``.
        """
        # Imported here so this module stays importable when the optional
        # Valkey dependency is not installed.
        from crewai.memory.storage.valkey_cache import ValkeyCache

        self._cache = ValkeyCache(
            host=host, port=port, db=db, password=password, default_ttl=default_ttl
        )

    async def get(self, key: str) -> CachedUpload | None:
        """Return the cached upload for ``key``, or ``None``.

        A missing key and a payload that cannot be turned back into a
        ``CachedUpload`` are both reported as ``None``; the latter is
        logged as a warning.
        """
        data = await self._cache.get(key)
        if data is None:
            return None
        try:
            return CachedUpload.from_dict(data)
        except (KeyError, ValueError, TypeError) as e:
            # TypeError covers payloads that are valid JSON but not the
            # expected shape (e.g. a list or string instead of a dict, or
            # a non-string timestamp). Treat them as a miss, not a crash.
            logger.warning(f"Failed to deserialize cached upload: {e}")
            return None

    async def set(self, key: str, value: CachedUpload, ttl: int) -> None:
        """Store ``value`` under ``key`` as its dict form with a TTL in seconds."""
        await self._cache.set(key, value.to_dict(), ttl=ttl)

    async def delete(self, key: str) -> bool:
        """Delete ``key``.

        Always returns ``True``: the underlying delete result is not
        inspected, so this does not tell the caller whether the key existed.
        """
        await self._cache.delete(key)
        return True  # ValkeyCache.delete is void
|
||||
|
||||
|
||||
class UploadCache:
|
||||
"""Async cache for tracking uploaded files.
|
||||
|
||||
Supports in-memory caching by default, with optional Redis or Valkey backend
|
||||
for distributed setups.
|
||||
|
||||
Attributes:
|
||||
@@ -110,7 +191,7 @@ class UploadCache:
|
||||
Args:
|
||||
ttl: Default TTL in seconds.
|
||||
namespace: Cache namespace.
|
||||
cache_type: Backend type ("memory" or "redis").
|
||||
cache_type: Backend type ("memory", "redis", or "valkey").
|
||||
max_entries: Maximum cache entries (None for unlimited).
|
||||
**cache_kwargs: Additional args for cache backend.
|
||||
"""
|
||||
@@ -120,18 +201,39 @@ class UploadCache:
|
||||
self._provider_keys: dict[ProviderType, set[str]] = {}
|
||||
self._key_access_order: list[str] = []
|
||||
|
||||
self._backend: CacheBackend = self._create_backend(
|
||||
cache_type, namespace, ttl, **cache_kwargs
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _create_backend(
|
||||
cache_type: str,
|
||||
namespace: str,
|
||||
ttl: int,
|
||||
**cache_kwargs: Any,
|
||||
) -> CacheBackend:
|
||||
"""Create the appropriate cache backend."""
|
||||
if cache_type == "valkey":
|
||||
conn = parse_cache_url() or {}
|
||||
return ValkeyCacheBackend(
|
||||
host=cache_kwargs.get("host", conn.get("host", "localhost")),
|
||||
port=cache_kwargs.get("port", conn.get("port", 6379)),
|
||||
db=cache_kwargs.get("db", conn.get("db", 0)),
|
||||
password=cache_kwargs.get("password", conn.get("password")),
|
||||
default_ttl=ttl,
|
||||
)
|
||||
if cache_type == "redis":
|
||||
self._cache = Cache(
|
||||
Cache.REDIS,
|
||||
serializer=PickleSerializer(),
|
||||
namespace=namespace,
|
||||
**cache_kwargs,
|
||||
)
|
||||
else:
|
||||
self._cache = Cache(
|
||||
serializer=PickleSerializer(),
|
||||
namespace=namespace,
|
||||
return AiocacheBackend(
|
||||
Cache(
|
||||
Cache.REDIS,
|
||||
serializer=PickleSerializer(),
|
||||
namespace=namespace,
|
||||
**cache_kwargs,
|
||||
)
|
||||
)
|
||||
return AiocacheBackend(
|
||||
Cache(serializer=PickleSerializer(), namespace=namespace)
|
||||
)
|
||||
|
||||
def _track_key(self, provider: ProviderType, key: str) -> None:
|
||||
"""Track a key for a provider (for cleanup) and access order."""
|
||||
@@ -157,11 +259,9 @@ class UploadCache:
|
||||
"""
|
||||
if self.max_entries is None:
|
||||
return 0
|
||||
|
||||
current_count = len(self)
|
||||
if current_count < self.max_entries:
|
||||
return 0
|
||||
|
||||
to_evict = max(1, self.max_entries // 10)
|
||||
return await self._evict_oldest(to_evict)
|
||||
|
||||
@@ -176,31 +276,24 @@ class UploadCache:
|
||||
"""
|
||||
evicted = 0
|
||||
keys_to_evict = self._key_access_order[:count]
|
||||
|
||||
for key in keys_to_evict:
|
||||
await self._cache.delete(key)
|
||||
await self._backend.delete(key)
|
||||
self._key_access_order.remove(key)
|
||||
for provider_keys in self._provider_keys.values():
|
||||
provider_keys.discard(key)
|
||||
evicted += 1
|
||||
|
||||
if evicted > 0:
|
||||
logger.debug(f"Evicted {evicted} oldest cache entries")
|
||||
|
||||
return evicted
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Async public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def aget(
|
||||
self, file: FileInput, provider: ProviderType
|
||||
) -> CachedUpload | None:
|
||||
"""Get a cached upload for a file.
|
||||
|
||||
Args:
|
||||
file: The file to look up.
|
||||
provider: The provider name.
|
||||
|
||||
Returns:
|
||||
Cached upload if found and not expired, None otherwise.
|
||||
"""
|
||||
"""Get a cached upload for a file."""
|
||||
file_hash = _compute_file_hash(file)
|
||||
return await self.aget_by_hash(file_hash, provider)
|
||||
|
||||
@@ -217,17 +310,14 @@ class UploadCache:
|
||||
Cached upload if found and not expired, None otherwise.
|
||||
"""
|
||||
key = _make_key(file_hash, provider)
|
||||
result = await self._cache.get(key)
|
||||
|
||||
result = await self._backend.get(key)
|
||||
if result is None:
|
||||
return None
|
||||
if isinstance(result, CachedUpload):
|
||||
if result.is_expired():
|
||||
await self._cache.delete(key)
|
||||
self._untrack_key(provider, key)
|
||||
return None
|
||||
return result
|
||||
return None
|
||||
if result.is_expired():
|
||||
await self._backend.delete(key)
|
||||
self._untrack_key(provider, key)
|
||||
return None
|
||||
return result
|
||||
|
||||
async def aset(
|
||||
self,
|
||||
@@ -237,18 +327,7 @@ class UploadCache:
|
||||
file_uri: str | None = None,
|
||||
expires_at: datetime | None = None,
|
||||
) -> CachedUpload:
|
||||
"""Cache an uploaded file.
|
||||
|
||||
Args:
|
||||
file: The file that was uploaded.
|
||||
provider: The provider name.
|
||||
file_id: Provider-specific file identifier.
|
||||
file_uri: Optional URI for accessing the file.
|
||||
expires_at: When the upload expires.
|
||||
|
||||
Returns:
|
||||
The created cache entry.
|
||||
"""
|
||||
"""Cache an uploaded file."""
|
||||
file_hash = _compute_file_hash(file)
|
||||
return await self.aset_by_hash(
|
||||
file_hash=file_hash,
|
||||
@@ -282,7 +361,6 @@ class UploadCache:
|
||||
The created cache entry.
|
||||
"""
|
||||
await self._evict_if_needed()
|
||||
|
||||
key = _make_key(file_hash, provider)
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
@@ -299,7 +377,7 @@ class UploadCache:
|
||||
if expires_at is not None:
|
||||
ttl = max(0, int((expires_at - now).total_seconds()))
|
||||
|
||||
await self._cache.set(key, cached, ttl=ttl)
|
||||
await self._backend.set(key, cached, ttl=ttl)
|
||||
self._track_key(provider, key)
|
||||
logger.debug(f"Cached upload: {file_id} for provider {provider}")
|
||||
return cached
|
||||
@@ -316,9 +394,7 @@ class UploadCache:
|
||||
"""
|
||||
file_hash = _compute_file_hash(file)
|
||||
key = _make_key(file_hash, provider)
|
||||
|
||||
result = await self._cache.delete(key)
|
||||
removed = bool(result > 0 if isinstance(result, int) else result)
|
||||
removed = await self._backend.delete(key)
|
||||
if removed:
|
||||
self._untrack_key(provider, key)
|
||||
return removed
|
||||
@@ -335,11 +411,10 @@ class UploadCache:
|
||||
"""
|
||||
if provider not in self._provider_keys:
|
||||
return False
|
||||
|
||||
for key in list(self._provider_keys[provider]):
|
||||
cached = await self._cache.get(key)
|
||||
if isinstance(cached, CachedUpload) and cached.file_id == file_id:
|
||||
await self._cache.delete(key)
|
||||
cached = await self._backend.get(key)
|
||||
if cached is not None and cached.file_id == file_id:
|
||||
await self._backend.delete(key)
|
||||
self._untrack_key(provider, key)
|
||||
return True
|
||||
return False
|
||||
@@ -351,17 +426,13 @@ class UploadCache:
|
||||
Number of entries removed.
|
||||
"""
|
||||
removed = 0
|
||||
|
||||
for provider, keys in list(self._provider_keys.items()):
|
||||
for key in list(keys):
|
||||
cached = await self._cache.get(key)
|
||||
if cached is None or (
|
||||
isinstance(cached, CachedUpload) and cached.is_expired()
|
||||
):
|
||||
await self._cache.delete(key)
|
||||
cached = await self._backend.get(key)
|
||||
if cached is None or cached.is_expired():
|
||||
await self._backend.delete(key)
|
||||
self._untrack_key(provider, key)
|
||||
removed += 1
|
||||
|
||||
if removed > 0:
|
||||
logger.debug(f"Cleared {removed} expired cache entries")
|
||||
return removed
|
||||
@@ -373,9 +444,12 @@ class UploadCache:
|
||||
Number of entries cleared.
|
||||
"""
|
||||
count = sum(len(keys) for keys in self._provider_keys.values())
|
||||
await self._cache.clear(namespace=self.namespace)
|
||||
# Delete all tracked keys individually (works for all backends)
|
||||
for keys in self._provider_keys.values():
|
||||
for key in keys:
|
||||
await self._backend.delete(key)
|
||||
self._provider_keys.clear()
|
||||
|
||||
self._key_access_order.clear()
|
||||
if count > 0:
|
||||
logger.debug(f"Cleared {count} cache entries")
|
||||
return count
|
||||
@@ -391,14 +465,17 @@ class UploadCache:
|
||||
"""
|
||||
if provider not in self._provider_keys:
|
||||
return []
|
||||
|
||||
results: list[CachedUpload] = []
|
||||
for key in list(self._provider_keys[provider]):
|
||||
cached = await self._cache.get(key)
|
||||
if isinstance(cached, CachedUpload) and not cached.is_expired():
|
||||
cached = await self._backend.get(key)
|
||||
if cached is not None and not cached.is_expired():
|
||||
results.append(cached)
|
||||
return results
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Sync wrappers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _run_sync(coro: Any) -> Any:
|
||||
"""Run an async coroutine from sync context without blocking event loop."""
|
||||
@@ -489,11 +566,7 @@ class UploadCache:
|
||||
return sum(len(keys) for keys in self._provider_keys.values())
|
||||
|
||||
def get_providers(self) -> builtins.set[ProviderType]:
|
||||
"""Get all provider names that have cached entries.
|
||||
|
||||
Returns:
|
||||
Set of provider names.
|
||||
"""
|
||||
"""Get all provider names that have cached entries."""
|
||||
return builtins.set(self._provider_keys.keys())
|
||||
|
||||
|
||||
@@ -506,17 +579,7 @@ def get_upload_cache(
|
||||
cache_type: str = "memory",
|
||||
**cache_kwargs: Any,
|
||||
) -> UploadCache:
|
||||
"""Get or create the default upload cache.
|
||||
|
||||
Args:
|
||||
ttl: Default TTL in seconds.
|
||||
namespace: Cache namespace.
|
||||
cache_type: Backend type ("memory" or "redis").
|
||||
**cache_kwargs: Additional args for cache backend.
|
||||
|
||||
Returns:
|
||||
The upload cache instance.
|
||||
"""
|
||||
"""Get or create the default upload cache."""
|
||||
global _default_cache
|
||||
if _default_cache is None:
|
||||
_default_cache = UploadCache(
|
||||
|
||||
@@ -110,6 +110,9 @@ file-processing = [
|
||||
qdrant-edge = [
|
||||
"qdrant-edge-py>=0.6.0",
|
||||
]
|
||||
valkey = [
|
||||
"valkey-glide>=1.3.0",
|
||||
]
|
||||
|
||||
|
||||
[tool.uv]
|
||||
|
||||
@@ -13,8 +13,12 @@ from types import MethodType
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from a2a.client.errors import A2AClientHTTPError
|
||||
from a2a.types import AgentCapabilities, AgentCard, AgentSkill
|
||||
from aiocache import cached # type: ignore[import-untyped]
|
||||
from a2a.types import (
|
||||
AgentCapabilities,
|
||||
AgentCard,
|
||||
AgentSkill,
|
||||
)
|
||||
from aiocache import cached, caches # type: ignore[import-untyped]
|
||||
from aiocache.serializers import PickleSerializer # type: ignore[import-untyped]
|
||||
import httpx
|
||||
|
||||
@@ -32,6 +36,7 @@ from crewai.events.types.a2a_events import (
|
||||
A2AAuthenticationFailedEvent,
|
||||
A2AConnectionErrorEvent,
|
||||
)
|
||||
from crewai.utilities.cache_config import get_aiocache_config
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -40,6 +45,18 @@ if TYPE_CHECKING:
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
_cache_configured = False
|
||||
|
||||
|
||||
def _ensure_cache_configured() -> None:
    """Configure aiocache on first use (lazy initialization).

    The first call applies ``get_aiocache_config()`` through
    ``caches.set_config`` and sets the module-level ``_cache_configured``
    flag. Later calls return without touching the configuration.
    """
    global _cache_configured
    if not _cache_configured:
        caches.set_config(get_aiocache_config())
        _cache_configured = True
|
||||
|
||||
|
||||
def _get_tls_verify(auth: ClientAuthScheme | None) -> ssl.SSLContext | bool | str:
|
||||
"""Get TLS verify parameter from auth scheme.
|
||||
|
||||
@@ -191,6 +208,7 @@ async def afetch_agent_card(
|
||||
else:
|
||||
auth_hash = _auth_store.compute_key("none", "")
|
||||
_auth_store.set(auth_hash, auth)
|
||||
_ensure_cache_configured()
|
||||
agent_card: AgentCard = await _afetch_agent_card_cached(
|
||||
endpoint, auth_hash, timeout
|
||||
)
|
||||
|
||||
@@ -9,9 +9,8 @@ from datetime import datetime
|
||||
from functools import wraps
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from a2a.server.agent_execution import RequestContext
|
||||
from a2a.server.events import EventQueue
|
||||
@@ -38,7 +37,6 @@ from a2a.utils import (
|
||||
from a2a.utils.errors import ServerError
|
||||
from aiocache import SimpleMemoryCache, caches # type: ignore[import-untyped]
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from crewai.a2a.utils.agent_card import _get_server_config
|
||||
from crewai.a2a.utils.content_type import validate_message_parts
|
||||
@@ -50,12 +48,18 @@ from crewai.events.types.a2a_events import (
|
||||
A2AServerTaskStartedEvent,
|
||||
)
|
||||
from crewai.task import Task
|
||||
from crewai.utilities.cache_config import (
|
||||
get_aiocache_config,
|
||||
parse_cache_url,
|
||||
use_valkey_cache,
|
||||
)
|
||||
from crewai.utilities.pydantic_schema_utils import create_model_from_schema
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.a2a.extensions.server import ExtensionContext, ServerExtensionRegistry
|
||||
from crewai.agent import Agent
|
||||
from crewai.memory.storage.valkey_cache import ValkeyCache
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -64,52 +68,61 @@ P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class RedisCacheConfig(TypedDict, total=False):
|
||||
"""Configuration for aiocache Redis backend."""
|
||||
# ---------------------------------------------------------------------------
|
||||
# Lazy cache initialisation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
cache: str
|
||||
endpoint: str
|
||||
port: int
|
||||
db: int
|
||||
password: str
|
||||
_task_cache: ValkeyCache | None = None
|
||||
_lazy_init_complete = False
|
||||
_cache_init_lock = threading.Lock()
|
||||
|
||||
# Cancellation polling interval in seconds.
|
||||
_CANCEL_POLL_INTERVAL = 0.1
|
||||
|
||||
# Configure aiocache at import time (matches upstream behaviour).
|
||||
# This is safe — it only touches aiocache, no optional dependencies.
|
||||
# The Valkey path is deferred to _ensure_task_cache() to avoid importing
|
||||
# valkey-glide at module level (it may not be installed).
|
||||
if not use_valkey_cache():
|
||||
caches.set_config(get_aiocache_config())
|
||||
|
||||
|
||||
def _parse_redis_url(url: str) -> RedisCacheConfig:
|
||||
"""Parse a Redis URL into aiocache configuration.
|
||||
def _ensure_task_cache() -> None:
|
||||
"""Initialise the Valkey task cache on first use (thread-safe).
|
||||
|
||||
Args:
|
||||
url: Redis connection URL (e.g., redis://localhost:6379/0).
|
||||
|
||||
Returns:
|
||||
Configuration dict for aiocache.RedisCache.
|
||||
For the aiocache path, configuration happens at module level above.
|
||||
This function only needs to run for the Valkey path.
|
||||
"""
|
||||
parsed = urlparse(url)
|
||||
config: RedisCacheConfig = {
|
||||
"cache": "aiocache.RedisCache",
|
||||
"endpoint": parsed.hostname or "localhost",
|
||||
"port": parsed.port or 6379,
|
||||
}
|
||||
if parsed.path and parsed.path != "/":
|
||||
try:
|
||||
config["db"] = int(parsed.path.lstrip("/"))
|
||||
except ValueError:
|
||||
pass
|
||||
if parsed.password:
|
||||
config["password"] = parsed.password
|
||||
return config
|
||||
global _task_cache, _lazy_init_complete
|
||||
if _lazy_init_complete:
|
||||
return
|
||||
|
||||
with _cache_init_lock:
|
||||
if _lazy_init_complete:
|
||||
return
|
||||
|
||||
_redis_url = os.environ.get("REDIS_URL")
|
||||
if use_valkey_cache():
|
||||
from crewai.memory.storage.valkey_cache import ValkeyCache
|
||||
|
||||
caches.set_config(
|
||||
{
|
||||
"default": _parse_redis_url(_redis_url)
|
||||
if _redis_url
|
||||
else {
|
||||
"cache": "aiocache.SimpleMemoryCache",
|
||||
}
|
||||
}
|
||||
)
|
||||
conn = parse_cache_url() or {}
|
||||
try:
|
||||
_task_cache = ValkeyCache(
|
||||
host=conn.get("host", "localhost"),
|
||||
port=conn.get("port", 6379),
|
||||
db=conn.get("db", 0),
|
||||
password=conn.get("password"),
|
||||
default_ttl=3600,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to initialize ValkeyCache for task cancellation, "
|
||||
"falling back to aiocache",
|
||||
extra={"error": str(e)},
|
||||
)
|
||||
caches.set_config(get_aiocache_config())
|
||||
_task_cache = None
|
||||
|
||||
_lazy_init_complete = True
|
||||
|
||||
|
||||
def cancellable(
|
||||
@@ -130,6 +143,8 @@ def cancellable(
|
||||
@wraps(fn)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||
"""Wrap function with cancellation monitoring."""
|
||||
_ensure_task_cache()
|
||||
|
||||
context: RequestContext | None = None
|
||||
for arg in args:
|
||||
if isinstance(arg, RequestContext):
|
||||
@@ -142,19 +157,34 @@ def cancellable(
|
||||
return await fn(*args, **kwargs)
|
||||
|
||||
task_id = context.task_id
|
||||
cache = caches.get("default")
|
||||
|
||||
async def poll_for_cancel() -> bool:
|
||||
"""Poll cache for cancellation flag."""
|
||||
async def poll_for_cancel_valkey() -> bool:
|
||||
"""Poll ValkeyCache for cancellation flag."""
|
||||
while True:
|
||||
if _task_cache is not None and await _task_cache.get(
|
||||
f"cancel:{task_id}"
|
||||
):
|
||||
return True
|
||||
await asyncio.sleep(_CANCEL_POLL_INTERVAL)
|
||||
|
||||
async def poll_for_cancel_aiocache() -> bool:
|
||||
"""Poll aiocache for cancellation flag."""
|
||||
cache = caches.get("default")
|
||||
while True:
|
||||
if await cache.get(f"cancel:{task_id}"):
|
||||
return True
|
||||
await asyncio.sleep(0.1)
|
||||
await asyncio.sleep(_CANCEL_POLL_INTERVAL)
|
||||
|
||||
async def watch_for_cancel() -> bool:
|
||||
"""Watch for cancellation events via pub/sub or polling."""
|
||||
if _task_cache is not None:
|
||||
# ValkeyCache: use polling (pub/sub not implemented yet)
|
||||
return await poll_for_cancel_valkey()
|
||||
|
||||
# aiocache: use pub/sub if Redis, otherwise poll
|
||||
cache = caches.get("default")
|
||||
if isinstance(cache, SimpleMemoryCache):
|
||||
return await poll_for_cancel()
|
||||
return await poll_for_cancel_aiocache()
|
||||
|
||||
try:
|
||||
client = cache.client
|
||||
@@ -168,7 +198,7 @@ def cancellable(
|
||||
"Cancel watcher Redis error, falling back to polling",
|
||||
extra={"task_id": task_id, "error": str(e)},
|
||||
)
|
||||
return await poll_for_cancel()
|
||||
return await poll_for_cancel_aiocache()
|
||||
return False
|
||||
|
||||
execute_task = asyncio.create_task(fn(*args, **kwargs))
|
||||
@@ -190,7 +220,12 @@ def cancellable(
|
||||
cancel_watch.cancel()
|
||||
return execute_task.result()
|
||||
finally:
|
||||
await cache.delete(f"cancel:{task_id}")
|
||||
# Clean up cancellation flag
|
||||
if _task_cache is not None:
|
||||
await _task_cache.delete(f"cancel:{task_id}")
|
||||
else:
|
||||
cache = caches.get("default")
|
||||
await cache.delete(f"cancel:{task_id}")
|
||||
|
||||
return wrapper
|
||||
|
||||
@@ -475,6 +510,8 @@ async def cancel(
|
||||
if task_id is None or context_id is None:
|
||||
raise ServerError(InvalidParamsError(message="task_id and context_id required"))
|
||||
|
||||
_ensure_task_cache()
|
||||
|
||||
if context.current_task and context.current_task.status.state in (
|
||||
TaskState.completed,
|
||||
TaskState.failed,
|
||||
@@ -482,11 +519,16 @@ async def cancel(
|
||||
):
|
||||
return context.current_task
|
||||
|
||||
cache = caches.get("default")
|
||||
|
||||
await cache.set(f"cancel:{task_id}", True, ttl=3600)
|
||||
if not isinstance(cache, SimpleMemoryCache):
|
||||
await cache.client.publish(f"cancel:{task_id}", "cancel")
|
||||
if _task_cache is not None:
|
||||
# Use ValkeyCache
|
||||
await _task_cache.set(f"cancel:{task_id}", True, ttl=3600)
|
||||
# Note: pub/sub not implemented for ValkeyCache yet, relies on polling
|
||||
else:
|
||||
# Use aiocache
|
||||
cache = caches.get("default")
|
||||
await cache.set(f"cancel:{task_id}", True, ttl=3600)
|
||||
if not isinstance(cache, SimpleMemoryCache):
|
||||
await cache.client.publish(f"cancel:{task_id}", "cancel")
|
||||
|
||||
await event_queue.enqueue_event(
|
||||
TaskStatusUpdateEvent(
|
||||
|
||||
@@ -18,7 +18,7 @@ import math
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from crewai.flow.flow import Flow, listen, start
|
||||
from crewai.memory.analyze import (
|
||||
@@ -68,6 +68,27 @@ class ItemState(BaseModel):
|
||||
plan: ConsolidationPlan | None = None
|
||||
result_record: MemoryRecord | None = None
|
||||
|
||||
@field_validator("similar_records", "result_record", mode="before")
@classmethod
def ensure_embedding_is_list(cls, v: Any) -> Any:
    """Ensure MemoryRecord embeddings are list[float], not bytes.

    Delegates to MemoryRecord.validate_embedding for consistent behavior
    (e.g. empty bytes → None).

    Accepts a single record, a list of records, or ``None``. Only
    ``MemoryRecord`` instances whose ``embedding`` is ``bytes`` are
    touched, and they are updated in place. Every other value is
    returned unchanged.
    """
    if v is None:
        return None

    # Treat the single-record field as a one-element batch so both
    # fields share the same conversion loop.
    candidates = v if isinstance(v, list) else [v]
    for record in candidates:
        if isinstance(record, MemoryRecord) and isinstance(record.embedding, bytes):
            record.embedding = MemoryRecord.validate_embedding(record.embedding)
    return v
|
||||
|
||||
|
||||
class EncodingState(BaseModel):
|
||||
"""Batch-level state for the encoding flow."""
|
||||
|
||||
198
lib/crewai/src/crewai/memory/storage/valkey_cache.py
Normal file
198
lib/crewai/src/crewai/memory/storage/valkey_cache.py
Normal file
@@ -0,0 +1,198 @@
|
||||
"""Valkey-based cache implementation for CrewAI.
|
||||
|
||||
This module provides a simple cache interface using Valkey-GLIDE client
|
||||
for caching operations with optional TTL support. It replaces Redis usage
|
||||
in A2A communication, file uploads, and agent card caching.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from glide import GlideClient, GlideClientConfiguration, NodeAddress
|
||||
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ValkeyCache:
|
||||
"""Simple cache interface using Valkey-GLIDE client.
|
||||
|
||||
Provides get/set/delete/exists operations for caching with optional TTL.
|
||||
Uses JSON serialization for complex values and lazy client initialization.
|
||||
|
||||
Example:
|
||||
>>> cache = ValkeyCache(host="localhost", port=6379)
|
||||
>>> await cache.set("key", {"data": "value"}, ttl=3600)
|
||||
>>> value = await cache.get("key")
|
||||
>>> await cache.delete("key")
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    host: str = "localhost",
    port: int = 6379,
    db: int = 0,
    password: str | None = None,
    default_ttl: int | None = None,
) -> None:
    """Initialize Valkey cache.

    Only stores the connection settings. The client itself is left as
    ``None`` here and created later by ``_get_client``.

    Args:
        host: Valkey server hostname.
        port: Valkey server port.
        db: Database number to use.
        password: Optional password for authentication.
        default_ttl: Default TTL in seconds (None = no expiration).
    """
    self._host = host
    self._port = port
    self._db = db
    self._password = password
    self._default_ttl = default_ttl
    # Created on first use by _get_client().
    self._client: GlideClient | None = None
|
||||
|
||||
async def _get_client(self) -> GlideClient:
    """Get or create Valkey client (lazy initialization).

    The client is created on the first call and reused afterwards.

    Returns:
        Initialized GlideClient instance.

    Raises:
        RuntimeError: If connection to Valkey fails.
        TimeoutError: If connection attempt times out (10 seconds).
    """
    import asyncio

    if self._client is not None:
        return self._client

    try:
        from glide import ServerCredentials

        config = GlideClientConfiguration(
            addresses=[NodeAddress(self._host, self._port)],
            database_id=self._db,
            credentials=(
                ServerCredentials(password=self._password)
                if self._password
                else None
            ),
        )

        # Bound the connection attempt to 10 seconds.
        try:
            self._client = await asyncio.wait_for(
                GlideClient.create(config), timeout=10.0
            )
        except asyncio.TimeoutError as e:
            _logger.error("Connection timeout connecting to Valkey")
            raise TimeoutError(
                "Connection timeout to Valkey. "
                "Ensure Valkey is running and accessible."
            ) from e

        _logger.info("Valkey cache client initialized")
    except (TimeoutError, RuntimeError):
        # Already one of the documented exception types; pass it through.
        raise
    except Exception as e:
        # Log only the exception type, not its message.
        _logger.error("Failed to create Valkey cache client: %s", type(e).__name__)
        raise RuntimeError(
            "Cannot connect to Valkey. Check connection settings."
        ) from e

    return self._client
|
||||
|
||||
async def get(self, key: str) -> Any | None:
|
||||
"""Get value from cache.
|
||||
|
||||
Args:
|
||||
key: Cache key.
|
||||
|
||||
Returns:
|
||||
Cached value (deserialized from JSON) or None if not found.
|
||||
"""
|
||||
client = await self._get_client()
|
||||
value = await client.get(key)
|
||||
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
_logger.warning(f"Failed to deserialize cached value for key: {key}")
|
||||
return None
|
||||
|
||||
async def set(
|
||||
self,
|
||||
key: str,
|
||||
value: Any,
|
||||
ttl: int | None = None,
|
||||
) -> None:
|
||||
"""Set value in cache.
|
||||
|
||||
Args:
|
||||
key: Cache key.
|
||||
value: Value to cache (will be serialized to JSON).
|
||||
ttl: TTL in seconds (None uses default_ttl, 0 = no expiration).
|
||||
|
||||
Raises:
|
||||
TypeError: If value is not JSON-serializable.
|
||||
"""
|
||||
from glide import ExpirySet, ExpiryType
|
||||
|
||||
client = await self._get_client()
|
||||
try:
|
||||
serialized = json.dumps(value)
|
||||
except (TypeError, ValueError) as e:
|
||||
_logger.error("Cannot serialize value for key %r: %s", key, e)
|
||||
raise TypeError(
|
||||
f"Value for cache key {key!r} is not JSON-serializable: {e}"
|
||||
) from e
|
||||
|
||||
ttl_to_use = ttl if ttl is not None else self._default_ttl
|
||||
|
||||
if ttl_to_use and ttl_to_use > 0:
|
||||
# Set with expiration using SET command with EX option
|
||||
await client.set(
|
||||
key,
|
||||
serialized,
|
||||
expiry=ExpirySet(ExpiryType.SEC, ttl_to_use),
|
||||
)
|
||||
else:
|
||||
await client.set(key, serialized)
|
||||
|
||||
async def delete(self, key: str) -> None:
|
||||
"""Delete value from cache.
|
||||
|
||||
Args:
|
||||
key: Cache key to delete.
|
||||
"""
|
||||
client = await self._get_client()
|
||||
await client.delete([key])
|
||||
|
||||
async def exists(self, key: str) -> bool:
|
||||
"""Check if key exists in cache.
|
||||
|
||||
Args:
|
||||
key: Cache key to check.
|
||||
|
||||
Returns:
|
||||
True if key exists, False otherwise.
|
||||
"""
|
||||
client = await self._get_client()
|
||||
result = await client.exists([key])
|
||||
return result > 0
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close Valkey client connection."""
|
||||
if self._client:
|
||||
await self._client.close()
|
||||
self._client = None
|
||||
_logger.debug("Valkey cache client closed")
|
||||
1893
lib/crewai/src/crewai/memory/storage/valkey_storage.py
Normal file
1893
lib/crewai/src/crewai/memory/storage/valkey_storage.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -2,13 +2,17 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import concurrent.futures
|
||||
from datetime import datetime
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
# When searching the vector store, we ask for more results than the caller
|
||||
# requested so that post-search steps (composite scoring, deduplication,
|
||||
# category filtering) have enough candidates to fill the final result set.
|
||||
@@ -57,6 +61,23 @@ class MemoryRecord(BaseModel):
|
||||
repr=False,
|
||||
description="Vector embedding for semantic search. Excluded from serialization to save tokens.",
|
||||
)
|
||||
|
||||
@field_validator("embedding", mode="before")
@classmethod
def validate_embedding(cls, v: Any) -> list[float] | None:
    """Coerce raw embedding input to ``list[float]`` or ``None``.

    Args:
        v: ``None``, a bytes-like buffer (``bytes``, ``bytearray`` or
            ``memoryview``) holding packed float32 values, or an iterable
            of numbers.

    Returns:
        ``None`` for ``None`` or an empty buffer; otherwise a list of
        Python floats.
    """
    if v is None:
        return None
    if isinstance(v, (bytes, bytearray, memoryview)):
        # Any bytes-like buffer is decoded as float32. Without this,
        # bytearray/memoryview would fall through to the generic branch
        # and be turned into one float per raw byte.
        import numpy as np

        if len(v) == 0:
            return None
        return np.frombuffer(v, dtype=np.float32).tolist()
    return [float(x) for x in v]
|
||||
|
||||
source: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
@@ -304,7 +325,11 @@ def embed_text(embedder: Any, text: str) -> list[float]:
|
||||
"""
|
||||
if not text or not text.strip():
|
||||
return []
|
||||
|
||||
# Just call the embedder directly - the blocking issue needs to be fixed
|
||||
# at a higher level (making Memory.recall() async)
|
||||
result = embedder([text])
|
||||
|
||||
if not result:
|
||||
return []
|
||||
first = result[0]
|
||||
@@ -315,12 +340,27 @@ def embed_text(embedder: Any, text: str) -> list[float]:
|
||||
return list(first)
|
||||
|
||||
|
||||
# Reusable thread pool for running embedder calls from sync context
|
||||
# when an async event loop is already running. Uses max_workers=2 so
|
||||
# a single slow/timed-out call doesn't block subsequent embeds.
|
||||
_EMBED_POOL = concurrent.futures.ThreadPoolExecutor(max_workers=2)
|
||||
|
||||
|
||||
def embed_texts(embedder: Any, texts: list[str]) -> list[list[float]]:
|
||||
"""Embed multiple texts in a single API call.
|
||||
|
||||
The embedder already accepts ``list[str]``, so this just calls it once
|
||||
with the full batch and normalises the output format.
|
||||
|
||||
When called from an async context, offloads the embedder to a thread pool
|
||||
so the embedding work doesn't run on the event loop thread. The calling
|
||||
thread still blocks on the result (unavoidable for a sync function), but
|
||||
this prevents the embedder from starving the event loop's I/O callbacks.
|
||||
The pool uses ``max_workers=2`` so a single timed-out call doesn't block
|
||||
subsequent embeds.
|
||||
|
||||
Note: the proper long-term fix is making ``Memory.recall()`` async.
|
||||
|
||||
Args:
|
||||
embedder: Callable that accepts a list of strings and returns embeddings.
|
||||
texts: List of texts to embed.
|
||||
@@ -328,6 +368,8 @@ def embed_texts(embedder: Any, texts: list[str]) -> list[list[float]]:
|
||||
Returns:
|
||||
List of embeddings, one per input text. Empty texts produce empty lists.
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
if not texts:
|
||||
return []
|
||||
# Filter out empty texts, remembering their positions
|
||||
@@ -337,7 +379,28 @@ def embed_texts(embedder: Any, texts: list[str]) -> list[list[float]]:
|
||||
if not valid:
|
||||
return [[] for _ in texts]
|
||||
|
||||
result = embedder([t for _, t in valid])
|
||||
texts_to_embed = [t for _, t in valid]
|
||||
|
||||
# Check if we're in an async context
|
||||
result: Any
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
# We're in an async context but this is a sync function.
|
||||
# Offload to thread pool so the embedder doesn't run on the
|
||||
# event loop thread. The .result() call blocks this thread
|
||||
# (acceptable — callers like Memory.recall() are sync).
|
||||
try:
|
||||
result = _EMBED_POOL.submit(embedder, texts_to_embed).result(timeout=30)
|
||||
except concurrent.futures.TimeoutError:
|
||||
_logger.warning(
|
||||
"Embedder timed out after 30s, returning empty embeddings. "
|
||||
"The worker thread may still be running."
|
||||
)
|
||||
return [[] for _ in texts]
|
||||
except RuntimeError:
|
||||
# Not in async context, run directly
|
||||
result = embedder(texts_to_embed)
|
||||
|
||||
embeddings: list[list[float]] = [[] for _ in texts]
|
||||
for (orig_idx, _), emb in zip(valid, result, strict=False):
|
||||
if hasattr(emb, "tolist"):
|
||||
|
||||
@@ -2,9 +2,11 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import concurrent.futures
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
import contextvars
|
||||
from datetime import datetime
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Literal
|
||||
@@ -36,6 +38,9 @@ from crewai.rag.embeddings.factory import build_embedder
|
||||
from crewai.rag.embeddings.providers.openai.types import OpenAIProviderSpec
|
||||
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
@@ -211,6 +216,18 @@ class Memory(BaseModel):
|
||||
from crewai.memory.storage.lancedb_storage import LanceDBStorage
|
||||
|
||||
self._storage = LanceDBStorage()
|
||||
elif self.storage == "valkey":
|
||||
from crewai.memory.storage.valkey_storage import ValkeyStorage
|
||||
from crewai.utilities.cache_config import parse_cache_url
|
||||
|
||||
conn = parse_cache_url() or {}
|
||||
self._storage = ValkeyStorage(
|
||||
host=conn.get("host", "localhost"),
|
||||
port=conn.get("port", 6379),
|
||||
db=conn.get("db", 0),
|
||||
password=conn.get("password"),
|
||||
use_tls=conn.get("use_tls", False),
|
||||
)
|
||||
else:
|
||||
from crewai.memory.storage.lancedb_storage import LanceDBStorage
|
||||
|
||||
@@ -316,16 +333,60 @@ class Memory(BaseModel):
|
||||
except Exception: # noqa: S110
|
||||
pass # swallow everything during shutdown
|
||||
|
||||
def drain_writes(self) -> None:
|
||||
def drain_writes(self, timeout_per_save: float = 60.0) -> None:
|
||||
"""Block until all pending background saves have completed.
|
||||
|
||||
Called automatically by ``recall()`` and should be called by the
|
||||
crew at shutdown to ensure no saves are lost.
|
||||
|
||||
Args:
|
||||
timeout_per_save: Maximum seconds to wait per save operation.
|
||||
Default 60s. If a save times out, logs warning
|
||||
but continues to avoid blocking crew completion.
|
||||
"""
|
||||
with self._pending_lock:
|
||||
pending = list(self._pending_saves)
|
||||
for future in pending:
|
||||
future.result() # blocks until done; re-raises exceptions
|
||||
|
||||
if pending:
|
||||
_logger.debug(
|
||||
"[DRAIN_WRITES] Waiting for %d pending saves...", len(pending)
|
||||
)
|
||||
|
||||
failed_saves = 0
|
||||
for i, future in enumerate(pending):
|
||||
try:
|
||||
_logger.debug(
|
||||
"[DRAIN_WRITES] Waiting for save %d/%d...", i + 1, len(pending)
|
||||
)
|
||||
future.result(timeout=timeout_per_save)
|
||||
_logger.debug(
|
||||
"[DRAIN_WRITES] Save %d/%d completed", i + 1, len(pending)
|
||||
)
|
||||
except (TimeoutError, concurrent.futures.TimeoutError): # noqa: PERF203
|
||||
failed_saves += 1
|
||||
_logger.warning(
|
||||
"[DRAIN_WRITES] Save %d/%d timed out after %ss. "
|
||||
"This save will be abandoned. Consider increasing timeout or checking "
|
||||
"LLM/embedder performance.",
|
||||
i + 1,
|
||||
len(pending),
|
||||
timeout_per_save,
|
||||
)
|
||||
# Don't raise - just log and continue to avoid blocking crew completion
|
||||
except Exception as e:
|
||||
failed_saves += 1
|
||||
_logger.error(
|
||||
"[DRAIN_WRITES] Save %d/%d failed: %s", i + 1, len(pending), e
|
||||
)
|
||||
# Don't raise - just log and continue
|
||||
|
||||
if failed_saves > 0:
|
||||
_logger.warning(
|
||||
"[DRAIN_WRITES] %d/%d saves failed or timed out. "
|
||||
"Some memories may not have been persisted.",
|
||||
failed_saves,
|
||||
len(pending),
|
||||
)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Drain pending saves, flush storage, and shut down the background thread pool."""
|
||||
|
||||
78
lib/crewai/src/crewai/utilities/cache_config.py
Normal file
78
lib/crewai/src/crewai/utilities/cache_config.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""Shared cache configuration helpers for Valkey/Redis URL parsing."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_cache_url() -> dict[str, Any] | None:
    """Parse VALKEY_URL or REDIS_URL from environment.

    Priority: VALKEY_URL > REDIS_URL.

    Returns:
        Dict with ``host``, ``port``, ``db``, ``password`` and ``use_tls``
        keys, or None if no URL is set. The password is percent-decoded;
        an invalid port falls back to 6379 with a warning.
    """
    from urllib.parse import unquote

    url = os.environ.get("VALKEY_URL") or os.environ.get("REDIS_URL")
    if not url:
        return None
    parsed = urlparse(url)

    try:
        # Accessing .port raises ValueError for a non-numeric or
        # out-of-range port component.
        port = parsed.port or 6379
    except ValueError:
        # The URL itself is not logged because it may contain a password.
        _logger.warning("Invalid port in cache URL, using default 6379")
        port = 6379

    # urlparse returns userinfo verbatim, so decode percent-escapes
    # (e.g. "p%40ss" -> "p@ss") before handing the password to a client.
    password = parsed.password
    if password is not None:
        password = unquote(password)

    return {
        "host": parsed.hostname or "localhost",
        "port": port,
        "db": _parse_db_from_path(parsed.path),
        "password": password,
        "use_tls": parsed.scheme in ("rediss", "valkeys"),
    }
|
||||
|
||||
|
||||
def _parse_db_from_path(path: str | None) -> int:
|
||||
"""Parse database number from URL path, defaulting to 0."""
|
||||
if not path or path == "/":
|
||||
return 0
|
||||
try:
|
||||
return int(path.lstrip("/"))
|
||||
except ValueError:
|
||||
_logger.warning(
|
||||
"Invalid database number in URL path: %s, using default 0", path
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
def get_aiocache_config() -> dict[str, Any]:
    """Build an aiocache configuration dict from environment.

    When ``parse_cache_url()`` yields connection settings (VALKEY_URL or
    REDIS_URL is set), the default cache is ``aiocache.RedisCache`` pointed
    at that endpoint. Otherwise the default cache is
    ``aiocache.SimpleMemoryCache``.

    Returns:
        Configuration dict suitable for ``aiocache.caches.set_config()``.
    """
    conn = parse_cache_url()
    if conn is None:
        # No URL configured: in-memory cache.
        return {"default": {"cache": "aiocache.SimpleMemoryCache"}}

    redis_settings: dict[str, Any] = {
        "cache": "aiocache.RedisCache",
        "endpoint": conn["host"],
        "port": conn["port"],
        "db": conn.get("db", 0),
        "password": conn.get("password"),
    }
    return {"default": redis_settings}
|
||||
|
||||
|
||||
def use_valkey_cache() -> bool:
    """Report whether VALKEY_URL is set to a non-empty value."""
    configured_url = os.getenv("VALKEY_URL", "")
    return configured_url != ""
|
||||
511
lib/crewai/tests/memory/storage/test_valkey_cache.py
Normal file
511
lib/crewai/tests/memory/storage/test_valkey_cache.py
Normal file
@@ -0,0 +1,511 @@
|
||||
"""Tests for ValkeyCache implementation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.memory.storage.valkey_cache import ValkeyCache
|
||||
|
||||
|
||||
@pytest.fixture
def mock_glide_client() -> AsyncMock:
    """Provide an AsyncMock standing in for GlideClient.

    Each command the cache calls (get, set, delete, exists, close) is given
    its own AsyncMock so tests can set return values and inspect calls.
    """
    fake_client = AsyncMock()
    for command in ("get", "set", "delete", "exists", "close"):
        setattr(fake_client, command, AsyncMock())
    return fake_client
|
||||
|
||||
|
||||
@pytest.fixture
def valkey_cache(mock_glide_client: AsyncMock) -> ValkeyCache:
    """Provide a ValkeyCache whose client lookup returns the mocked client."""
    instance = ValkeyCache(host="localhost", port=6379, db=0)

    async def _fake_get_client() -> AsyncMock:
        # Mirror the real method's side effect of caching the client.
        instance._client = mock_glide_client
        return mock_glide_client

    # Replace the lazy initializer so no real connection is attempted.
    instance._get_client = _fake_get_client  # type: ignore[method-assign]
    return instance
|
||||
|
||||
|
||||
class TestValkeyCacheBasicOperations:
    """Checks for get/set/delete/exists against a mocked client."""

    @pytest.mark.asyncio
    async def test_set_and_get_string_value(
        self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
    ) -> None:
        """A string is JSON-encoded on set and decoded again on get."""
        encoded = json.dumps("test_value")
        mock_glide_client.get.return_value = encoded

        await valkey_cache.set("test_key", "test_value")

        # The client received the key and the JSON payload positionally.
        mock_glide_client.set.assert_called_once()
        positional = mock_glide_client.set.call_args[0]
        assert positional[0] == "test_key"
        assert positional[1] == encoded

        fetched = await valkey_cache.get("test_key")

        mock_glide_client.get.assert_called_once_with("test_key")
        assert fetched == "test_value"

    @pytest.mark.asyncio
    async def test_set_and_get_dict_value(
        self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
    ) -> None:
        """A dictionary survives the JSON round trip."""
        payload = {"key1": "value1", "key2": 42, "key3": [1, 2, 3]}
        encoded = json.dumps(payload)
        mock_glide_client.get.return_value = encoded

        await valkey_cache.set("dict_key", payload)

        mock_glide_client.set.assert_called_once()
        positional = mock_glide_client.set.call_args[0]
        assert positional[0] == "dict_key"
        assert positional[1] == encoded

        fetched = await valkey_cache.get("dict_key")

        assert fetched == payload

    @pytest.mark.asyncio
    async def test_set_and_get_list_value(
        self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
    ) -> None:
        """A heterogeneous list survives the JSON round trip."""
        payload = [1, "two", 3.0, {"nested": "dict"}]
        mock_glide_client.get.return_value = json.dumps(payload)

        await valkey_cache.set("list_key", payload)
        fetched = await valkey_cache.get("list_key")

        assert fetched == payload

    @pytest.mark.asyncio
    async def test_get_nonexistent_key_returns_none(
        self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
    ) -> None:
        """A missing key (client returns None) yields None."""
        mock_glide_client.get.return_value = None

        fetched = await valkey_cache.get("nonexistent_key")

        assert fetched is None
        mock_glide_client.get.assert_called_once_with("nonexistent_key")

    @pytest.mark.asyncio
    async def test_delete_key(
        self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
    ) -> None:
        """delete() forwards the key to the client as a one-element list."""
        await valkey_cache.delete("test_key")

        mock_glide_client.delete.assert_called_once_with(["test_key"])

    @pytest.mark.asyncio
    async def test_exists_returns_true_for_existing_key(
        self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
    ) -> None:
        """A client count of 1 is reported as True."""
        mock_glide_client.exists.return_value = 1

        found = await valkey_cache.exists("existing_key")

        assert found is True
        mock_glide_client.exists.assert_called_once_with(["existing_key"])

    @pytest.mark.asyncio
    async def test_exists_returns_false_for_nonexistent_key(
        self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
    ) -> None:
        """A client count of 0 is reported as False."""
        mock_glide_client.exists.return_value = 0

        found = await valkey_cache.exists("nonexistent_key")

        assert found is False
        mock_glide_client.exists.assert_called_once_with(["nonexistent_key"])
|
||||
|
||||
|
||||
class TestValkeyCacheTTL:
    """Tests for ValkeyCache TTL functionality.

    Expiry checks compare the command arguments of the ``expiry`` object
    passed to the client against a freshly built seconds-based expiry, so
    the TTL value is verified and not just the presence of the keyword.
    """

    @staticmethod
    def _cache_with_default_ttl(client: AsyncMock, default_ttl: int) -> ValkeyCache:
        """Build a cache with ``default_ttl`` whose client lookup returns ``client``."""
        cache = ValkeyCache(host="localhost", port=6379, default_ttl=default_ttl)

        async def _fake_get_client() -> AsyncMock:
            cache._client = client
            return client

        cache._get_client = _fake_get_client  # type: ignore[method-assign]
        return cache

    @staticmethod
    def _assert_expiry_seconds(set_mock: AsyncMock, seconds: int) -> None:
        """Assert the last ``set`` call carried an expiry of ``seconds`` seconds."""
        from glide import ExpirySet, ExpiryType

        keyword_args = set_mock.call_args[1]
        assert "expiry" in keyword_args
        expected = ExpirySet(ExpiryType.SEC, seconds)
        assert keyword_args["expiry"].get_cmd_args() == expected.get_cmd_args()

    @pytest.mark.asyncio
    async def test_set_with_explicit_ttl(
        self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
    ) -> None:
        """Test setting a value with explicit TTL."""
        await valkey_cache.set("ttl_key", "value", ttl=3600)

        mock_glide_client.set.assert_called_once()
        positional = mock_glide_client.set.call_args[0]
        assert positional[0] == "ttl_key"
        assert positional[1] == json.dumps("value")
        self._assert_expiry_seconds(mock_glide_client.set, 3600)

    @pytest.mark.asyncio
    async def test_set_with_default_ttl(
        self, mock_glide_client: AsyncMock
    ) -> None:
        """Test setting a value with default TTL from constructor."""
        cache = self._cache_with_default_ttl(mock_glide_client, 1800)

        await cache.set("default_ttl_key", "value")

        mock_glide_client.set.assert_called_once()
        self._assert_expiry_seconds(mock_glide_client.set, 1800)

    @pytest.mark.asyncio
    async def test_set_without_ttl(
        self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
    ) -> None:
        """Test setting a value without TTL (no expiration)."""
        await valkey_cache.set("no_ttl_key", "value")

        mock_glide_client.set.assert_called_once()
        call_args = mock_glide_client.set.call_args
        assert call_args[0][0] == "no_ttl_key"
        assert call_args[0][1] == json.dumps("value")
        # No expiry keyword, or an explicit None, both mean "no expiration".
        assert call_args[1].get("expiry") is None

    @pytest.mark.asyncio
    async def test_set_with_zero_ttl_no_expiration(
        self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
    ) -> None:
        """Test setting a value with TTL=0 means no expiration."""
        await valkey_cache.set("zero_ttl_key", "value", ttl=0)

        mock_glide_client.set.assert_called_once()
        assert mock_glide_client.set.call_args[1].get("expiry") is None

    @pytest.mark.asyncio
    async def test_explicit_ttl_overrides_default(
        self, mock_glide_client: AsyncMock
    ) -> None:
        """Test explicit TTL overrides default TTL."""
        cache = self._cache_with_default_ttl(mock_glide_client, 1800)

        await cache.set("override_key", "value", ttl=7200)

        # The explicit 7200 must win over the constructor default of 1800.
        mock_glide_client.set.assert_called_once()
        self._assert_expiry_seconds(mock_glide_client.set, 7200)
|
||||
|
||||
|
||||
class TestValkeyCacheJSONSerialization:
    """JSON round-trip edge cases for ValkeyCache."""

    @pytest.mark.asyncio
    async def test_serialize_none_value(
        self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
    ) -> None:
        """None is stored as JSON null and read back as None."""
        mock_glide_client.get.return_value = json.dumps(None)

        await valkey_cache.set("none_key", None)
        fetched = await valkey_cache.get("none_key")

        assert fetched is None

    @pytest.mark.asyncio
    async def test_serialize_boolean_values(
        self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
    ) -> None:
        """True and False come back as the bool singletons."""
        mock_glide_client.get.side_effect = [json.dumps(True), json.dumps(False)]

        await valkey_cache.set("true_key", True)
        await valkey_cache.set("false_key", False)

        fetched_true = await valkey_cache.get("true_key")
        fetched_false = await valkey_cache.get("false_key")

        assert fetched_true is True
        assert fetched_false is False

    @pytest.mark.asyncio
    async def test_serialize_numeric_values(
        self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
    ) -> None:
        """Integers and floats, including zero and negatives, round-trip."""
        samples = [
            ("int_key", 42),
            ("float_key", 3.14159),
            ("zero_key", 0),
            ("negative_key", -100),
        ]
        # The mocked client answers successive get() calls in sample order.
        mock_glide_client.get.side_effect = [json.dumps(v) for _, v in samples]

        for cache_key, number in samples:
            await valkey_cache.set(cache_key, number)

        for cache_key, number in samples:
            assert await valkey_cache.get(cache_key) == number

    @pytest.mark.asyncio
    async def test_serialize_empty_collections(
        self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
    ) -> None:
        """Empty list, dict and string are preserved, not collapsed to None."""
        samples = [
            ("empty_list", []),
            ("empty_dict", {}),
            ("empty_string", ""),
        ]
        mock_glide_client.get.side_effect = [json.dumps(v) for _, v in samples]

        for cache_key, empty_value in samples:
            await valkey_cache.set(cache_key, empty_value)

        for cache_key, empty_value in samples:
            assert await valkey_cache.get(cache_key) == empty_value

    @pytest.mark.asyncio
    async def test_serialize_nested_structures(
        self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
    ) -> None:
        """Deeply nested dict/list structures round-trip unchanged."""
        innermost = [1, 2, {"level4": "deep"}]
        nested_data = {
            "level1": {"level2": {"level3": innermost}},
            "list": [{"a": 1}, {"b": 2}],
        }
        mock_glide_client.get.return_value = json.dumps(nested_data)

        await valkey_cache.set("nested_key", nested_data)
        fetched = await valkey_cache.get("nested_key")

        assert fetched == nested_data

    @pytest.mark.asyncio
    async def test_deserialize_invalid_json_returns_none(
        self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
    ) -> None:
        """Malformed JSON in the store yields None and one warning."""
        mock_glide_client.get.return_value = "invalid json {{"

        with patch("crewai.memory.storage.valkey_cache._logger") as fake_logger:
            fetched = await valkey_cache.get("invalid_key")

        assert fetched is None
        fake_logger.warning.assert_called_once()

    @pytest.mark.asyncio
    async def test_serialize_unicode_strings(
        self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
    ) -> None:
        """Non-ASCII text (CJK, emoji, Cyrillic) round-trips unchanged."""
        unicode_data = "Hello 世界 🌍 Привет"
        mock_glide_client.get.return_value = json.dumps(unicode_data)

        await valkey_cache.set("unicode_key", unicode_data)
        fetched = await valkey_cache.get("unicode_key")

        assert fetched == unicode_data
|
||||
|
||||
|
||||
class TestValkeyCacheConnectionManagement:
    """Tests for ValkeyCache connection management."""

    @pytest.mark.asyncio
    async def test_lazy_client_initialization(self) -> None:
        """Test client is initialized lazily on first use."""
        cache = ValkeyCache(host="localhost", port=6379)

        # Nothing is connected at construction time.
        assert cache._client is None

        with patch("crewai.memory.storage.valkey_cache.GlideClient") as mock_glide:
            mock_client = AsyncMock()
            mock_glide.create = AsyncMock(return_value=mock_client)
            mock_client.get = AsyncMock(return_value=None)

            # The first operation triggers client creation.
            await cache.get("test_key")

            assert cache._client is not None
            mock_glide.create.assert_called_once()

    @pytest.mark.asyncio
    async def test_client_reuse_across_operations(
        self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
    ) -> None:
        """Test client is reused across multiple operations."""
        mock_glide_client.get.return_value = json.dumps("value")
        mock_glide_client.exists.return_value = 1

        await valkey_cache.get("key1")
        await valkey_cache.set("key2", "value2")
        await valkey_cache.exists("key3")
        await valkey_cache.delete("key4")

        # Repeated lookups hand back the very same client object.
        client1 = await valkey_cache._get_client()
        client2 = await valkey_cache._get_client()
        assert client1 is client2

    @pytest.mark.asyncio
    async def test_close_connection(
        self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
    ) -> None:
        """Test closing the client connection."""
        await valkey_cache._get_client()
        assert valkey_cache._client is not None

        await valkey_cache.close()

        # close() must close the client and drop the reference.
        mock_glide_client.close.assert_called_once()
        assert valkey_cache._client is None

    @pytest.mark.asyncio
    async def test_connection_error_raises_runtime_error(self) -> None:
        """Test connection error raises RuntimeError with descriptive message."""
        cache = ValkeyCache(host="invalid-host", port=9999)

        with patch("crewai.memory.storage.valkey_cache.GlideClient") as mock_glide:
            mock_glide.create = AsyncMock(side_effect=Exception("Connection refused"))

            with pytest.raises(RuntimeError) as exc_info:
                await cache._get_client()

            assert "Cannot connect to Valkey" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_authentication_with_password(self) -> None:
        """Test client initialization with password authentication."""
        cache = ValkeyCache(
            host="localhost",
            port=6379,
            password="secret_password",
        )

        with patch("crewai.memory.storage.valkey_cache.GlideClient") as mock_glide:
            mock_client = AsyncMock()
            mock_glide.create = AsyncMock(return_value=mock_client)

            await cache._get_client()

            mock_glide.create.assert_called_once()
            config = mock_glide.create.call_args[0][0]
            # The attribute exists even without a password, so check that
            # credentials were actually built and carry the given password.
            assert config.credentials is not None
            assert config.credentials.password == "secret_password"
|
||||
|
||||
|
||||
class TestValkeyCacheEdgeCases:
    """Unusual keys, large payloads, concurrency and bad values."""

    @pytest.mark.asyncio
    async def test_set_with_special_characters_in_key(
        self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
    ) -> None:
        """Keys containing common separator characters are accepted by set()."""
        separator_keys = (
            "key:with:colons",
            "key/with/slashes",
            "key-with-dashes",
            "key_with_underscores",
            "key.with.dots",
        )

        for cache_key in separator_keys:
            await valkey_cache.set(cache_key, "value")
            mock_glide_client.set.assert_called()

    @pytest.mark.asyncio
    async def test_large_value_serialization(
        self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
    ) -> None:
        """A 10,000-element list round-trips unchanged."""
        large_list = list(range(10000))
        mock_glide_client.get.return_value = json.dumps(large_list)

        await valkey_cache.set("large_key", large_list)
        fetched = await valkey_cache.get("large_key")

        assert fetched == large_list

    @pytest.mark.asyncio
    async def test_concurrent_operations(
        self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
    ) -> None:
        """Ten set() calls scheduled together all reach the client."""
        import asyncio

        mock_glide_client.get.return_value = json.dumps("value")

        pending_sets = []
        for i in range(10):
            pending_sets.append(valkey_cache.set(f"key{i}", f"value{i}"))
        await asyncio.gather(*pending_sets)

        assert mock_glide_client.set.call_count == 10

    @pytest.mark.asyncio
    async def test_set_non_serializable_value_raises_type_error(
        self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
    ) -> None:
        """A value json.dumps cannot encode raises TypeError before any write."""
        from datetime import datetime

        with pytest.raises(TypeError, match="not JSON-serializable"):
            await valkey_cache.set("bad_key", datetime.now())

        # Serialization failed, so nothing was sent to the client.
        mock_glide_client.set.assert_not_called()
|
||||
3074
lib/crewai/tests/memory/storage/test_valkey_storage.py
Normal file
3074
lib/crewai/tests/memory/storage/test_valkey_storage.py
Normal file
File diff suppressed because it is too large
Load Diff
267
lib/crewai/tests/memory/storage/test_valkey_storage_errors.py
Normal file
267
lib/crewai/tests/memory/storage/test_valkey_storage_errors.py
Normal file
@@ -0,0 +1,267 @@
|
||||
"""Tests for ValkeyStorage error handling."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.memory.storage.valkey_storage import ValkeyStorage
|
||||
from crewai.memory.types import MemoryRecord
|
||||
|
||||
|
||||
@pytest.fixture
def mock_glide_client() -> AsyncMock:
    """Build an ``AsyncMock`` that stands in for a GlideClient.

    ``hset``, ``zrange`` and ``hgetall`` get canned return values; ``zadd``,
    ``sadd`` and ``close`` are plain async mocks.
    """
    glide = AsyncMock()

    canned_results = {"hset": 1, "zrange": [], "hgetall": {}}
    for command, result in canned_results.items():
        setattr(glide, command, AsyncMock(return_value=result))

    for command in ("zadd", "sadd", "close"):
        setattr(glide, command, AsyncMock())

    return glide
|
||||
|
||||
|
||||
@pytest.fixture
def valkey_storage(mock_glide_client: AsyncMock) -> ValkeyStorage:
    """Provide a ``ValkeyStorage`` whose ``_get_client`` yields the mock client."""
    instance = ValkeyStorage(host="localhost", port=6379, db=0)

    async def _fake_get_client() -> AsyncMock:
        # Record the client on the instance, then hand the same mock back.
        instance._client = mock_glide_client
        return mock_glide_client

    # Swap the real client lookup for the stub above.
    instance._get_client = _fake_get_client  # type: ignore[method-assign]
    return instance
|
||||
|
||||
|
||||
class TestSerializationErrors:
    """Checks on how ``_record_to_dict`` reports serialization failures."""

    def test_serialization_error_raises_descriptive_exception(
        self, valkey_storage: ValkeyStorage
    ) -> None:
        """Unserializable metadata yields a ValueError that names the record."""
        record_fields = {
            "id": "test-id",
            "content": "test content",
            "scope": "/test",
            "categories": ["test"],
            # A bare object() cannot be encoded as JSON.
            "metadata": {"bad_key": object()},
            "importance": 0.5,
            "created_at": datetime.now(),
            "last_accessed": datetime.now(),
            "embedding": [0.1, 0.2, 0.3],
        }
        unserializable = MemoryRecord(**record_fields)

        with pytest.raises(ValueError, match="Failed to serialize record test-id"):
            valkey_storage._record_to_dict(unserializable)

    def test_serialization_error_includes_cause(
        self, valkey_storage: ValkeyStorage
    ) -> None:
        """The raised ValueError chains the underlying TypeError as its cause."""
        # The record itself is valid; the failure is injected by making
        # json.dumps raise while the record is being converted.
        valid_record = MemoryRecord(
            id="test-id-2",
            content="test content",
            scope="/test",
            categories=["valid"],
            metadata={"key": "value"},
            importance=0.5,
            created_at=datetime.now(),
            last_accessed=datetime.now(),
            embedding=[0.1, 0.2, 0.3],
        )

        broken_dumps = patch("json.dumps", side_effect=TypeError("Cannot serialize"))
        with broken_dumps, pytest.raises(ValueError) as failure:
            valkey_storage._record_to_dict(valid_record)

        # The original exception must be preserved via ``raise ... from``.
        cause = failure.value.__cause__
        assert cause is not None
        assert isinstance(cause, TypeError)
|
||||
|
||||
|
||||
class TestDeserializationErrors:
    """Checks on how ``_dict_to_record`` copes with malformed stored data."""

    def test_deserialization_error_logs_and_returns_none(
        self, valkey_storage: ValkeyStorage, caplog: pytest.LogCaptureFixture
    ) -> None:
        """A hash lacking required fields is logged and converted to None."""
        # Only id and content are present; scope, categories, metadata and
        # the remaining fields are deliberately left out.
        incomplete = {
            "id": "test-id",
            "content": "test content",
        }

        outcome = valkey_storage._dict_to_record(incomplete)

        assert outcome is None
        assert "Failed to deserialize record test-id" in caplog.text

    def test_deserialization_with_invalid_json_categories_uses_tag_fallback(
        self, valkey_storage: ValkeyStorage
    ) -> None:
        """A categories string that is not JSON is parsed as comma-separated tags."""
        stored = {
            "id": "test-id-json",
            "content": "test content",
            "scope": "/test",
            # Not JSON, so it is handled as TAG-format text.
            "categories": "not valid json [",
            "metadata": "{}",
            "importance": "0.5",
            "created_at": "2024-01-01T12:00:00",
            "last_accessed": "2024-01-01T12:00:00",
            "source": "",
            "private": "false",
        }

        parsed = valkey_storage._dict_to_record(stored)

        # With no comma in the text, the whole string becomes one category.
        assert parsed is not None
        assert parsed.id == "test-id-json"
        assert parsed.categories == ["not valid json ["]

    def test_deserialization_with_invalid_datetime_returns_none(
        self, valkey_storage: ValkeyStorage, caplog: pytest.LogCaptureFixture
    ) -> None:
        """An unparseable ``created_at`` value is logged and converted to None."""
        stored = {
            "id": "test-id-datetime",
            "content": "test content",
            "scope": "/test",
            "categories": '["test"]',
            "metadata": "{}",
            "importance": "0.5",
            # Deliberately not an ISO timestamp.
            "created_at": "not a datetime",
            "last_accessed": "2024-01-01T12:00:00",
            "source": "",
            "private": "false",
        }

        outcome = valkey_storage._dict_to_record(stored)

        assert outcome is None
        assert "Failed to deserialize record test-id-datetime" in caplog.text

    def test_deserialization_with_invalid_float_returns_none(
        self, valkey_storage: ValkeyStorage, caplog: pytest.LogCaptureFixture
    ) -> None:
        """A non-numeric ``importance`` value is logged and converted to None."""
        stored = {
            "id": "test-id-float",
            "content": "test content",
            "scope": "/test",
            "categories": '["test"]',
            "metadata": "{}",
            # Deliberately not a number.
            "importance": "not a float",
            "created_at": "2024-01-01T12:00:00",
            "last_accessed": "2024-01-01T12:00:00",
            "source": "",
            "private": "false",
        }

        outcome = valkey_storage._dict_to_record(stored)

        assert outcome is None
        assert "Failed to deserialize record test-id-float" in caplog.text

    def test_deserialization_with_bytes_keys_uses_tag_fallback(
        self, valkey_storage: ValkeyStorage
    ) -> None:
        """Bytes keys and values are accepted, with TAG fallback for categories."""
        # Keys and values are bytes here; ``source`` and ``private`` are
        # omitted from this payload.
        raw_hash = {
            b"id": b"test-id-bytes",
            b"content": b"test content",
            b"scope": b"/test",
            # Not JSON, so it is handled as TAG-format text.
            b"categories": b"invalid json [",
            b"metadata": b"{}",
            b"importance": b"0.5",
            b"created_at": b"2024-01-01T12:00:00",
            b"last_accessed": b"2024-01-01T12:00:00",
        }

        parsed = valkey_storage._dict_to_record(raw_hash)

        # With no comma in the text, the whole string becomes one category.
        assert parsed is not None
        assert parsed.id == "test-id-bytes"
        assert parsed.categories == ["invalid json ["]
|
||||
|
||||
|
||||
class TestRetryBehaviorIntegration:
    """Illustrations of the fail-then-succeed pattern a retry would rely on."""

    @pytest.mark.asyncio
    async def test_mock_client_operation_with_retry_pattern(
        self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Show a client call that fails once and succeeds on the next attempt.

        NOTE(review): both calls go straight to the mock client, so no
        ``ValkeyStorage`` method is exercised by this test.
        """
        from glide import ClosingError

        stored_hash = {
            b"id": b"test-id",
            b"content": b"test content",
            b"scope": b"/test",
            b"categories": b'["test"]',
            b"metadata": b"{}",
            b"importance": b"0.5",
            b"created_at": b"2024-01-01T12:00:00",
            b"last_accessed": b"2024-01-01T12:00:00",
            b"source": b"",
            b"private": b"false",
            b"embedding": b"",
        }
        # First invocation raises, the following one returns the hash.
        mock_glide_client.hgetall.side_effect = [
            ClosingError("Connection lost"),
            stored_hash,
        ]

        with pytest.raises(ClosingError):
            await mock_glide_client.hgetall("record:test-id")

        second_attempt = await mock_glide_client.hgetall("record:test-id")
        assert second_attempt is not None

    @pytest.mark.asyncio
    async def test_serialization_error_not_retried(
        self, valkey_storage: ValkeyStorage
    ) -> None:
        """A serialization failure is raised as a ValueError right away."""
        unserializable = MemoryRecord(
            id="test-id",
            content="test content",
            scope="/test",
            categories=["test"],
            # A bare object() cannot be encoded as JSON.
            metadata={"bad": object()},
            importance=0.5,
            created_at=datetime.now(),
            last_accessed=datetime.now(),
            embedding=[0.1, 0.2, 0.3],
        )

        # This is a data error rather than a connection error.
        with pytest.raises(ValueError, match="Failed to serialize"):
            valkey_storage._record_to_dict(unserializable)
|
||||
1110
lib/crewai/tests/memory/storage/test_valkey_storage_scope.py
Normal file
1110
lib/crewai/tests/memory/storage/test_valkey_storage_scope.py
Normal file
File diff suppressed because it is too large
Load Diff
998
lib/crewai/tests/memory/storage/test_valkey_storage_search.py
Normal file
998
lib/crewai/tests/memory/storage/test_valkey_storage_search.py
Normal file
@@ -0,0 +1,998 @@
|
||||
"""Tests for ValkeyStorage vector search operation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.memory.storage.valkey_storage import ValkeyStorage
|
||||
from crewai.memory.types import MemoryRecord
|
||||
|
||||
|
||||
@pytest.fixture
def mock_glide_client() -> AsyncMock:
    """Return an ``AsyncMock`` GlideClient with the commands these tests touch.

    ``hset`` resolves to ``1``, ``zrange`` to an empty list and ``hgetall``
    to an empty dict; ``zadd``, ``sadd`` and ``close`` are plain async mocks.
    """
    fake_client = AsyncMock()

    for name, canned in (("hset", 1), ("zrange", []), ("hgetall", {})):
        setattr(fake_client, name, AsyncMock(return_value=canned))

    for name in ("zadd", "sadd", "close"):
        setattr(fake_client, name, AsyncMock())

    return fake_client
|
||||
|
||||
|
||||
@pytest.fixture
def valkey_storage(mock_glide_client: AsyncMock) -> ValkeyStorage:
    """Return a ``ValkeyStorage`` wired to the mock client.

    ``_get_client`` is replaced by a coroutine that stores the mock on
    ``_client`` and returns it.
    """
    backend = ValkeyStorage(host="localhost", port=6379, db=0)

    async def _stub_get_client() -> AsyncMock:
        backend._client = mock_glide_client
        return mock_glide_client

    backend._get_client = _stub_get_client  # type: ignore[method-assign]
    return backend
|
||||
|
||||
|
||||
def create_mock_ft_search_response(
    records: list[tuple[MemoryRecord, float]]
) -> list[int | dict[str, dict[str, str]]]:
    """Build a fake FT.SEARCH reply in the native dict layout.

    Args:
        records: ``(MemoryRecord, score)`` pairs to place in the reply, where
            ``score`` is the similarity the test expects to get back.

    Returns:
        ``[0]`` when ``records`` is empty, otherwise
        ``[total_count, {doc_key: {field: value, ...}, ...}]`` with one
        ``record:<id>`` entry per record.
    """
    if not records:
        return [0]

    def _as_fields(entry: MemoryRecord, similarity: float) -> dict[str, str]:
        """Render one record as the string field mapping of a search hit."""
        doc = {
            "id": entry.id,
            "content": entry.content,
            "scope": entry.scope,
            "categories": json.dumps(entry.categories),
            "metadata": json.dumps(entry.metadata),
            "importance": str(entry.importance),
            "created_at": entry.created_at.isoformat(),
            "last_accessed": entry.last_accessed.isoformat(),
            "source": entry.source or "",
            "private": "true" if entry.private else "false",
            # The reply carries a cosine distance rather than a similarity,
            # so the similarity is converted: distance = 2 * (1 - similarity).
            "score": str(2.0 * (1.0 - similarity)),
        }
        # The embedding field is emitted only for a non-empty embedding.
        if entry.embedding:
            doc["embedding"] = json.dumps(entry.embedding)
        return doc

    documents = {
        f"record:{entry.id}": _as_fields(entry, similarity)
        for entry, similarity in records
    }
    return [len(records), documents]
|
||||
|
||||
|
||||
class TestValkeyStorageVectorSearch:
|
||||
"""Tests for ValkeyStorage vector search operation."""
|
||||
|
||||
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_with_no_filters_returns_all_records(
    self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
    valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
) -> None:
    """An unfiltered search issues a bare KNN query and returns every hit."""
    first = MemoryRecord(
        id="record-1",
        content="First test record",
        scope="/test",
        categories=["cat1"],
        metadata={"key": "value1"},
        importance=0.8,
        created_at=datetime(2024, 1, 1, 10, 0, 0),
        last_accessed=datetime(2024, 1, 1, 11, 0, 0),
        embedding=[0.1, 0.2, 0.3, 0.4],
    )
    second = MemoryRecord(
        id="record-2",
        content="Second test record",
        scope="/test",
        categories=["cat2"],
        metadata={"key": "value2"},
        importance=0.6,
        created_at=datetime(2024, 1, 2, 10, 0, 0),
        last_accessed=datetime(2024, 1, 2, 11, 0, 0),
        embedding=[0.2, 0.3, 0.4, 0.5],
    )

    # The patched ft.list returns the index name, simulating an existing index.
    mock_ft_list.return_value = [b"memory_index"]
    # The patched ft.search answers with both records.
    mock_ft_search.return_value = create_mock_ft_search_response(
        [(first, 0.95), (second, 0.85)]
    )

    hits = await valkey_storage.asearch([0.1, 0.2, 0.3, 0.4], limit=10)

    mock_ft_search.assert_called_once()

    # The query string is the third positional argument of ft.search.
    knn_query = mock_ft_search.call_args[0][2]
    assert "*=>[KNN 10 @embedding $BLOB AS score]" in knn_query
    assert "@scope" not in knn_query
    assert "@categories" not in knn_query

    # Both records come back, paired with the expected similarity scores.
    assert [(hit[0].id, hit[1]) for hit in hits] == [
        ("record-1", 0.95),
        ("record-2", 0.85),
    ]
|
||||
|
||||
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_with_scope_filter_only(
    self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
    valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
) -> None:
    """A scope prefix alone becomes a ``@scope`` prefix clause before the KNN part."""
    scoped = MemoryRecord(
        id="record-1",
        content="Record in scope",
        scope="/agent/task",
        embedding=[0.1, 0.2, 0.3, 0.4],
    )
    mock_ft_list.return_value = [b"memory_index"]
    mock_ft_search.return_value = create_mock_ft_search_response([(scoped, 0.9)])

    hits = await valkey_storage.asearch(
        [0.1, 0.2, 0.3, 0.4],
        scope_prefix="/agent",
        limit=10,
    )

    # The query string is the third positional argument of ft.search.
    issued_query = mock_ft_search.call_args[0][2]
    assert "(@scope:{/agent*})=>[KNN 10 @embedding $BLOB AS score]" in issued_query

    assert len(hits) == 1
    top_record = hits[0][0]
    assert top_record.id == "record-1"
    assert top_record.scope == "/agent/task"
|
||||
|
||||
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_with_category_filter_only(
    self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
    valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
) -> None:
    """Several categories are OR-ed together in one ``@categories`` clause."""
    planning_record = MemoryRecord(
        id="record-1",
        content="Record with planning category",
        scope="/test",
        categories=["planning"],
        embedding=[0.1, 0.2, 0.3, 0.4],
    )
    mock_ft_list.return_value = [b"memory_index"]
    mock_ft_search.return_value = create_mock_ft_search_response(
        [(planning_record, 0.88)]
    )

    hits = await valkey_storage.asearch(
        [0.1, 0.2, 0.3, 0.4],
        categories=["planning", "execution"],
        limit=10,
    )

    # The requested categories are joined with "|" inside a single clause.
    issued_query = mock_ft_search.call_args[0][2]
    expected_clause = (
        "(@categories:{planning|execution})=>[KNN 10 @embedding $BLOB AS score]"
    )
    assert expected_clause in issued_query

    assert len(hits) == 1
    top_record = hits[0][0]
    assert top_record.id == "record-1"
    assert "planning" in top_record.categories
|
||||
|
||||
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_with_metadata_filter_only(
    self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
    valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
) -> None:
    """Each metadata key/value pair appears as its own clause in the query."""
    wanted_metadata = {"agent_id": "agent-1", "priority": "high"}
    tagged_record = MemoryRecord(
        id="record-1",
        content="Record with metadata",
        scope="/test",
        metadata={"agent_id": "agent-1", "priority": "high"},
        embedding=[0.1, 0.2, 0.3, 0.4],
    )
    mock_ft_list.return_value = [b"memory_index"]
    mock_ft_search.return_value = create_mock_ft_search_response(
        [(tagged_record, 0.92)]
    )

    hits = await valkey_storage.asearch(
        [0.1, 0.2, 0.3, 0.4],
        metadata_filter=wanted_metadata,
        limit=10,
    )

    issued_query = mock_ft_search.call_args[0][2]
    # The dash may or may not be escaped; either spelling is accepted.
    agent_clauses = ("@agent_id:{agent\\-1}", "@agent_id:{agent-1}")
    assert any(clause in issued_query for clause in agent_clauses)
    assert "@priority:{high}" in issued_query
    assert "=>[KNN 10 @embedding $BLOB AS score]" in issued_query

    assert len(hits) == 1
    top_record = hits[0][0]
    assert top_record.id == "record-1"
    assert top_record.metadata["agent_id"] == "agent-1"
    assert top_record.metadata["priority"] == "high"
|
||||
|
||||
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_with_combined_filters(
    self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
    valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
) -> None:
    """Scope, category and metadata filters all show up in a single query."""
    matching_record = MemoryRecord(
        id="record-1",
        content="Record matching all filters",
        scope="/agent/task",
        categories=["planning"],
        metadata={"agent_id": "agent-1"},
        embedding=[0.1, 0.2, 0.3, 0.4],
    )
    mock_ft_list.return_value = [b"memory_index"]
    mock_ft_search.return_value = create_mock_ft_search_response(
        [(matching_record, 0.93)]
    )

    hits = await valkey_storage.asearch(
        [0.1, 0.2, 0.3, 0.4],
        scope_prefix="/agent",
        categories=["planning"],
        metadata_filter={"agent_id": "agent-1"},
        limit=10,
    )

    issued_query = mock_ft_search.call_args[0][2]
    assert "@scope:{/agent*}" in issued_query
    assert "@categories:{planning}" in issued_query
    # The dash may or may not be escaped; either spelling is accepted.
    agent_clauses = ("@agent_id:{agent\\-1}", "@agent_id:{agent-1}")
    assert any(clause in issued_query for clause in agent_clauses)
    assert "=>[KNN 10 @embedding $BLOB AS score]" in issued_query

    assert len(hits) == 1
    assert hits[0][0].id == "record-1"
|
||||
|
||||
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_respects_limit_parameter(
    self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
    valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
) -> None:
    """The ``limit`` argument is used as the K of the KNN clause."""
    # Five candidates with steadily decreasing similarity scores.
    scored_candidates = [
        (
            MemoryRecord(
                id=f"record-{n}",
                content=f"Record {n}",
                scope="/test",
                embedding=[0.1 * n, 0.2 * n, 0.3 * n, 0.4 * n],
            ),
            0.9 - (n * 0.1),
        )
        for n in range(1, 6)
    ]

    mock_ft_list.return_value = [b"memory_index"]
    # The mocked search hands back only the first three candidates.
    mock_ft_search.return_value = create_mock_ft_search_response(
        scored_candidates[:3]
    )

    hits = await valkey_storage.asearch([0.1, 0.2, 0.3, 0.4], limit=3)

    issued_query = mock_ft_search.call_args[0][2]
    assert "=>[KNN 3 @embedding $BLOB AS score]" in issued_query

    assert len(hits) == 3
|
||||
|
||||
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_respects_min_score_parameter(
    self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
    valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
) -> None:
    """Hits scoring below ``min_score`` are dropped from the results."""
    high = MemoryRecord(
        id="record-1",
        content="High score record",
        scope="/test",
        embedding=[0.1, 0.2, 0.3, 0.4],
    )
    medium = MemoryRecord(
        id="record-2",
        content="Medium score record",
        scope="/test",
        embedding=[0.2, 0.3, 0.4, 0.5],
    )
    low = MemoryRecord(
        id="record-3",
        content="Low score record",
        scope="/test",
        embedding=[0.3, 0.4, 0.5, 0.6],
    )

    mock_ft_list.return_value = [b"memory_index"]
    mock_ft_search.return_value = create_mock_ft_search_response(
        [(high, 0.95), (medium, 0.75), (low, 0.55)]
    )

    hits = await valkey_storage.asearch(
        [0.1, 0.2, 0.3, 0.4],
        limit=10,
        min_score=0.7,
    )

    # Only the two records scoring at least 0.7 survive the threshold.
    assert [(hit[0].id, hit[1]) for hit in hits] == [
        ("record-1", 0.95),
        ("record-2", 0.75),
    ]
|
||||
|
||||
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_returns_results_ordered_by_descending_score(
    self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
    valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
) -> None:
    """Hits supplied out of order come back sorted from best to worst score."""
    medium = MemoryRecord(
        id="record-1",
        content="Medium score",
        scope="/test",
        embedding=[0.1, 0.2, 0.3, 0.4],
    )
    highest = MemoryRecord(
        id="record-2",
        content="Highest score",
        scope="/test",
        embedding=[0.2, 0.3, 0.4, 0.5],
    )
    lowest = MemoryRecord(
        id="record-3",
        content="Lowest score",
        scope="/test",
        embedding=[0.3, 0.4, 0.5, 0.6],
    )

    mock_ft_list.return_value = [b"memory_index"]
    # The mocked reply deliberately lists the hits in a non-sorted order.
    mock_ft_search.return_value = create_mock_ft_search_response(
        [(medium, 0.75), (highest, 0.95), (lowest, 0.55)]
    )

    hits = await valkey_storage.asearch([0.1, 0.2, 0.3, 0.4], limit=10)

    assert [(hit[0].id, hit[1]) for hit in hits] == [
        ("record-2", 0.95),
        ("record-1", 0.75),
        ("record-3", 0.55),
    ]

    # Each score must be at least as large as the one that follows it.
    scores = [hit[1] for hit in hits]
    assert all(
        current >= following
        for current, following in zip(scores, scores[1:])
    )
|
||||
|
||||
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_with_empty_results(
    self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
    valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
) -> None:
    """A reply holding only a zero total count produces no results."""
    mock_ft_list.return_value = [b"memory_index"]
    # The reply consists of the total count alone, with no documents.
    mock_ft_search.return_value = [0]

    hits = await valkey_storage.asearch([0.1, 0.2, 0.3, 0.4], limit=10)

    assert len(hits) == 0
|
||||
|
||||
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_with_special_characters_in_scope(
    self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
    valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
) -> None:
    """A scope prefix containing ``:`` still ends up in the ``@scope`` clause."""
    colon_scoped = MemoryRecord(
        id="record-1",
        content="Record with special scope",
        scope="/agent:task-1",
        embedding=[0.1, 0.2, 0.3, 0.4],
    )
    mock_ft_list.return_value = [b"memory_index"]
    mock_ft_search.return_value = create_mock_ft_search_response(
        [(colon_scoped, 0.9)]
    )

    await valkey_storage.asearch(
        [0.1, 0.2, 0.3, 0.4],
        scope_prefix="/agent:task",
        limit=10,
    )

    # Either the escaped or the raw colon form of the clause is accepted.
    issued_query = mock_ft_search.call_args[0][2]
    scope_clauses = ("@scope:{/agent\\:task*}", "@scope:{/agent:task*}")
    assert any(clause in issued_query for clause in scope_clauses)
|
||||
|
||||
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_with_special_characters_in_categories(
    self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
    valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
) -> None:
    """A category containing ``:`` still ends up in the ``@categories`` clause."""
    colon_category = MemoryRecord(
        id="record-1",
        content="Record with special category",
        scope="/test",
        categories=["plan:execute"],
        embedding=[0.1, 0.2, 0.3, 0.4],
    )
    mock_ft_list.return_value = [b"memory_index"]
    mock_ft_search.return_value = create_mock_ft_search_response(
        [(colon_category, 0.9)]
    )

    await valkey_storage.asearch(
        [0.1, 0.2, 0.3, 0.4],
        categories=["plan:execute"],
        limit=10,
    )

    # Either the escaped or the raw colon form of the clause is accepted.
    issued_query = mock_ft_search.call_args[0][2]
    category_clauses = (
        "@categories:{plan\\:execute}",
        "@categories:{plan:execute}",
    )
    assert any(clause in issued_query for clause in category_clauses)
|
||||
|
||||
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_with_numeric_metadata_values(
    self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
    valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
) -> None:
    """Integer and float metadata values are rendered as text in the query."""
    numeric_record = MemoryRecord(
        id="record-1",
        content="Record with numeric metadata",
        scope="/test",
        metadata={"count": 42, "score": 3.14},
        embedding=[0.1, 0.2, 0.3, 0.4],
    )
    mock_ft_list.return_value = [b"memory_index"]
    mock_ft_search.return_value = create_mock_ft_search_response(
        [(numeric_record, 0.9)]
    )

    await valkey_storage.asearch(
        [0.1, 0.2, 0.3, 0.4],
        metadata_filter={"count": 42, "score": 3.14},
        limit=10,
    )

    issued_query = mock_ft_search.call_args[0][2]
    assert "@count:{42}" in issued_query
    # The float is matched in two pieces, so an escaped decimal point passes too.
    assert "@score:{3" in issued_query
    assert "14}" in issued_query
|
||||
|
||||
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_with_embedding_blob_parameter(
    self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
    valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
) -> None:
    """A search options object is passed to ``ft.search`` as its fourth argument.

    NOTE(review): only the presence of the options object is asserted; its
    BLOB parameter is not inspected by this test.
    """
    sample_record = MemoryRecord(
        id="record-1",
        content="Test record",
        scope="/test",
        embedding=[0.1, 0.2, 0.3, 0.4],
    )
    mock_ft_list.return_value = [b"memory_index"]
    mock_ft_search.return_value = create_mock_ft_search_response(
        [(sample_record, 0.9)]
    )

    await valkey_storage.asearch([0.1, 0.2, 0.3, 0.4], limit=10)

    # Positional argument 3 (zero-based) holds the search options.
    positional_args = mock_ft_search.call_args[0]
    search_options = positional_args[3]
    assert search_options is not None
|
||||
|
||||
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_results_sorted_by_score(
    self,
    mock_ft_list: AsyncMock,
    mock_ft_search: AsyncMock,
    valkey_storage: ValkeyStorage,
    mock_glide_client: AsyncMock,
) -> None:
    """Test search results come back ordered by score (descending).

    The mocked FT.SEARCH response already lists the better match first, so
    this verifies that asearch issues a single search and returns the
    records in descending score order rather than only that it was called.
    """
    closer = MemoryRecord(
        id="record-1",
        content="Test record",
        scope="/test",
        embedding=[0.1, 0.2, 0.3, 0.4],
    )
    farther = MemoryRecord(
        id="record-2",
        content="Less similar record",
        scope="/test",
        embedding=[0.2, 0.3, 0.4, 0.5],
    )

    mock_ft_list.return_value = [b"memory_index"]
    mock_ft_search.return_value = create_mock_ft_search_response(
        [(closer, 0.9), (farther, 0.5)]
    )

    query_embedding = [0.1, 0.2, 0.3, 0.4]
    results = await valkey_storage.asearch(query_embedding, limit=10)

    # Exactly one FT.SEARCH round trip is expected.
    mock_ft_search.assert_called_once()

    # Each result is a (record, score) pair; the best match must come first.
    assert [item[0].id for item in results] == ["record-1", "record-2"]
    scores = [item[1] for item in results]
    assert scores == sorted(scores, reverse=True)
|
||||
|
||||
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_with_return_fields(
    self,
    mock_ft_list: AsyncMock,
    mock_ft_search: AsyncMock,
    valkey_storage: ValkeyStorage,
    mock_glide_client: AsyncMock,
) -> None:
    """Check that the search options request the full set of return fields."""
    stored = MemoryRecord(
        id="record-1",
        content="Test record",
        scope="/test",
        embedding=[0.1, 0.2, 0.3, 0.4],
    )
    mock_ft_list.return_value = [b"memory_index"]
    mock_ft_search.return_value = create_mock_ft_search_response([(stored, 0.9)])

    await valkey_storage.asearch([0.1, 0.2, 0.3, 0.4], limit=10)

    # The options object is the fourth positional argument given to ft.search.
    search_options = mock_ft_search.call_args.args[3]
    assert search_options is not None

    requested_fields = search_options.return_fields
    assert requested_fields is not None
    # Eleven entries: every record field plus the score.
    assert len(requested_fields) == 11
|
||||
|
||||
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.VectorFieldAttributesHnsw")
@patch("crewai.memory.storage.valkey_storage.ft.create")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_handles_valkey_search_not_available(
    self,
    mock_ft_list: AsyncMock,
    mock_ft_create: AsyncMock,
    mock_vector_attrs: AsyncMock,
    valkey_storage: ValkeyStorage,
    mock_glide_client: AsyncMock,
) -> None:
    """Check that asearch raises RuntimeError when the index cannot be created."""
    # ft.list reports no existing indexes.
    mock_ft_list.return_value = []
    # ft.create fails with an "unknown command" error.
    mock_ft_create.side_effect = Exception("ERR unknown command 'ft.create'")

    expected_message = "Valkey Search module is not available"
    with pytest.raises(RuntimeError, match=expected_message):
        await valkey_storage.asearch([0.1, 0.2, 0.3, 0.4], limit=10)
|
||||
|
||||
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_handles_ft_search_error(
    self,
    mock_ft_list: AsyncMock,
    mock_ft_search: AsyncMock,
    valkey_storage: ValkeyStorage,
    mock_glide_client: AsyncMock,
) -> None:
    """Check that an "unknown command" failure from ft.search becomes RuntimeError."""
    # The index is listed, so the failure comes from the search call itself.
    mock_ft_list.return_value = [b"memory_index"]
    mock_ft_search.side_effect = Exception("ERR unknown command 'FT.SEARCH'")

    expected_message = "Valkey Search module is not available"
    with pytest.raises(RuntimeError, match=expected_message):
        await valkey_storage.asearch([0.1, 0.2, 0.3, 0.4], limit=10)
|
||||
|
||||
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_handles_malformed_ft_search_response(
    self,
    mock_ft_list: AsyncMock,
    mock_ft_search: AsyncMock,
    valkey_storage: ValkeyStorage,
    mock_glide_client: AsyncMock,
) -> None:
    """Check that a None response from ft.search yields no results."""
    mock_ft_list.return_value = [b"memory_index"]
    # None stands in for a malformed FT.SEARCH reply.
    mock_ft_search.return_value = None

    results = await valkey_storage.asearch([0.1, 0.2, 0.3, 0.4], limit=10)

    # No exception is expected; the caller just gets nothing back.
    assert len(results) == 0
|
||||
|
||||
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_handles_missing_score_field(
    self,
    mock_ft_list: AsyncMock,
    mock_ft_search: AsyncMock,
    valkey_storage: ValkeyStorage,
    mock_glide_client: AsyncMock,
) -> None:
    """Check that a document without a score field is returned with score 0.0."""
    stored = MemoryRecord(
        id="record-1",
        content="Test record",
        scope="/test",
        embedding=[0.1, 0.2, 0.3, 0.4],
    )

    # Document fields in dict format, deliberately without a "score" entry.
    fields = {
        "id": stored.id,
        "content": stored.content,
        "scope": stored.scope,
        "categories": str(stored.categories),
        "metadata": str(stored.metadata),
        "importance": str(stored.importance),
        "created_at": stored.created_at.isoformat(),
        "last_accessed": stored.last_accessed.isoformat(),
        "source": stored.source or "",
        "private": "false",
    }

    mock_ft_list.return_value = [b"memory_index"]
    mock_ft_search.return_value = [1, {f"record:{stored.id}": fields}]

    results = await valkey_storage.asearch([0.1, 0.2, 0.3, 0.4], limit=10)

    assert len(results) == 1
    top = results[0]
    assert top[0].id == "record-1"
    # With no score in the document the result carries 0.0.
    assert top[1] == 0.0
|
||||
|
||||
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_filters_out_records_with_deserialization_errors(
    self,
    mock_ft_list: AsyncMock,
    mock_ft_search: AsyncMock,
    valkey_storage: ValkeyStorage,
    mock_glide_client: AsyncMock,
) -> None:
    """Check that documents which cannot be deserialized are dropped from results."""
    good = MemoryRecord(
        id="valid-record",
        content="Valid record",
        scope="/test",
        embedding=[0.1, 0.2, 0.3, 0.4],
    )

    # A complete document that should deserialize cleanly.
    complete_fields = {
        "id": good.id,
        "content": good.content,
        "scope": good.scope,
        "categories": str(good.categories),
        "metadata": str(good.metadata),
        "importance": str(good.importance),
        "created_at": good.created_at.isoformat(),
        "last_accessed": good.last_accessed.isoformat(),
        "source": good.source or "",
        "private": "false",
        "score": "0.1",
    }
    # A document lacking content, scope and the other required fields.
    incomplete_fields = {
        "id": "invalid-record",
        "score": "0.2",
    }
    docs = {
        f"record:{good.id}": complete_fields,
        "record:invalid-record": incomplete_fields,
    }

    mock_ft_list.return_value = [b"memory_index"]
    mock_ft_search.return_value = [2, docs]

    results = await valkey_storage.asearch([0.1, 0.2, 0.3, 0.4], limit=10)

    # Only the complete document survives.
    assert len(results) == 1
    assert results[0][0].id == "valid-record"
|
||||
|
||||
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_converts_cosine_distance_to_similarity(
    self,
    mock_ft_list: AsyncMock,
    mock_ft_search: AsyncMock,
    valkey_storage: ValkeyStorage,
    mock_glide_client: AsyncMock,
) -> None:
    """Check that a raw distance of 0.1 is reported as a similarity of about 0.95."""
    stored = MemoryRecord(
        id="record-1",
        content="Test record",
        scope="/test",
        embedding=[0.1, 0.2, 0.3, 0.4],
    )

    # Document in dict format whose "score" holds a distance of 0.1.
    fields = {
        "id": stored.id,
        "content": stored.content,
        "scope": stored.scope,
        "categories": str(stored.categories),
        "metadata": str(stored.metadata),
        "importance": str(stored.importance),
        "created_at": stored.created_at.isoformat(),
        "last_accessed": stored.last_accessed.isoformat(),
        "source": stored.source or "",
        "private": "false",
        "score": "0.1",
    }

    mock_ft_list.return_value = [b"memory_index"]
    mock_ft_search.return_value = [1, {f"record:{stored.id}": fields}]

    results = await valkey_storage.asearch([0.1, 0.2, 0.3, 0.4], limit=10)

    assert len(results) == 1
    top = results[0]
    assert top[0].id == "record-1"
    # Expected similarity: 1 - (0.1 / 2) = 0.95, compared with a 0.01 tolerance.
    assert abs(top[1] - 0.95) < 0.01
|
||||
|
||||
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
def test_search_sync_wrapper(
    self,
    mock_ft_list: AsyncMock,
    mock_ft_search: AsyncMock,
    valkey_storage: ValkeyStorage,
    mock_glide_client: AsyncMock,
) -> None:
    """Test that the sync search wrapper reaches ft.search and returns its results."""
    record1 = MemoryRecord(
        id="record-1",
        content="Test record",
        scope="/test",
        embedding=[0.1, 0.2, 0.3, 0.4],
    )

    mock_ft_list.return_value = [b"memory_index"]
    mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.9)])

    query_embedding = [0.1, 0.2, 0.3, 0.4]
    results = valkey_storage.search(query_embedding, limit=10)

    # Verify ft.search was called
    assert mock_ft_search.call_count >= 1

    # Verify results
    assert len(results) == 1
    assert results[0][0].id == "record-1"
    # The score is a computed float, so compare with a tolerance instead of
    # exact equality to avoid a rounding-dependent failure.
    assert results[0][1] == pytest.approx(0.9)
|
||||
|
||||
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_with_multiple_categories_uses_or_logic(
    self,
    mock_ft_list: AsyncMock,
    mock_ft_search: AsyncMock,
    valkey_storage: ValkeyStorage,
    mock_glide_client: AsyncMock,
) -> None:
    """Check that several categories are joined with "|" in one tag filter."""
    stored = MemoryRecord(
        id="record-1",
        content="Record with one matching category",
        scope="/test",
        categories=["planning"],
        embedding=[0.1, 0.2, 0.3, 0.4],
    )
    mock_ft_list.return_value = [b"memory_index"]
    mock_ft_search.return_value = create_mock_ft_search_response([(stored, 0.9)])

    results = await valkey_storage.asearch(
        [0.1, 0.2, 0.3, 0.4],
        categories=["planning", "execution", "review"],
        limit=10,
    )

    # The query string is the third positional argument given to ft.search.
    query = mock_ft_search.call_args.args[2]
    assert "@categories:{planning|execution|review}" in query

    # The record carrying just one of the requested categories is returned.
    assert len(results) == 1
    assert results[0][0].id == "record-1"
|
||||
|
||||
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_with_multiple_metadata_filters_uses_and_logic(
    self,
    mock_ft_list: AsyncMock,
    mock_ft_search: AsyncMock,
    valkey_storage: ValkeyStorage,
    mock_glide_client: AsyncMock,
) -> None:
    """Check that every metadata filter key contributes a clause to the query."""
    stored = MemoryRecord(
        id="record-1",
        content="Record matching all metadata",
        scope="/test",
        metadata={"agent_id": "agent-1", "priority": "high", "status": "active"},
        embedding=[0.1, 0.2, 0.3, 0.4],
    )
    mock_ft_list.return_value = [b"memory_index"]
    mock_ft_search.return_value = create_mock_ft_search_response([(stored, 0.9)])

    results = await valkey_storage.asearch(
        [0.1, 0.2, 0.3, 0.4],
        metadata_filter={"agent_id": "agent-1", "priority": "high", "status": "active"},
        limit=10,
    )

    # The query string is the third positional argument given to ft.search.
    query = mock_ft_search.call_args.args[2]
    assert "@agent_id:" in query
    assert "@priority:" in query
    assert "@status:" in query

    # The record that matches all three filters is returned.
    assert len(results) == 1
    assert results[0][0].id == "record-1"
|
||||
|
||||
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_with_zero_limit_returns_empty(
    self,
    mock_ft_list: AsyncMock,
    mock_ft_search: AsyncMock,
    valkey_storage: ValkeyStorage,
    mock_glide_client: AsyncMock,
) -> None:
    """Check that asearch with limit=0 produces no results."""
    mock_ft_list.return_value = [b"memory_index"]
    # A response holding only a zero count and no documents.
    mock_ft_search.return_value = [0]

    results = await valkey_storage.asearch([0.1, 0.2, 0.3, 0.4], limit=0)

    assert len(results) == 0
|
||||
|
||||
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_with_min_score_one_filters_all(
    self,
    mock_ft_list: AsyncMock,
    mock_ft_search: AsyncMock,
    valkey_storage: ValkeyStorage,
    mock_glide_client: AsyncMock,
) -> None:
    """Check that min_score=1.0 rejects a match scored at 0.99."""
    near_perfect = MemoryRecord(
        id="record-1",
        content="High score but not perfect",
        scope="/test",
        embedding=[0.1, 0.2, 0.3, 0.4],
    )
    mock_ft_list.return_value = [b"memory_index"]
    mock_ft_search.return_value = create_mock_ft_search_response(
        [(near_perfect, 0.99)]
    )

    results = await valkey_storage.asearch(
        [0.1, 0.2, 0.3, 0.4],
        limit=10,
        min_score=1.0,
    )

    # The only candidate falls short of the threshold, so nothing remains.
    assert len(results) == 0
|
||||
|
||||
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_with_min_score_zero_returns_all(
    self,
    mock_ft_list: AsyncMock,
    mock_ft_search: AsyncMock,
    valkey_storage: ValkeyStorage,
    mock_glide_client: AsyncMock,
) -> None:
    """Check that min_score=0.0 keeps both the high- and low-scoring record."""
    strong_match = MemoryRecord(
        id="record-1",
        content="High score",
        scope="/test",
        embedding=[0.1, 0.2, 0.3, 0.4],
    )
    weak_match = MemoryRecord(
        id="record-2",
        content="Low score",
        scope="/test",
        embedding=[0.2, 0.3, 0.4, 0.5],
    )
    scored_records = [(strong_match, 0.95), (weak_match, 0.05)]

    mock_ft_list.return_value = [b"memory_index"]
    mock_ft_search.return_value = create_mock_ft_search_response(scored_records)

    results = await valkey_storage.asearch(
        [0.1, 0.2, 0.3, 0.4],
        limit=10,
        min_score=0.0,
    )

    # Nothing is filtered, and the records come back in the order supplied.
    assert len(results) == 2
    assert results[0][0].id == "record-1"
    assert results[1][0].id == "record-2"
|
||||
115
lib/crewai/tests/memory/test_embedding_safety.py
Normal file
115
lib/crewai/tests/memory/test_embedding_safety.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""Tests for embedding safety: bytes→float validators and async-safe embed_texts."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from crewai.memory.types import MemoryRecord, embed_text, embed_texts
|
||||
|
||||
|
||||
class TestMemoryRecordEmbeddingValidator:
    """Checks how MemoryRecord normalises the ``embedding`` value it is given."""

    def test_none_embedding_stays_none(self) -> None:
        """A None embedding is left as None."""
        record = MemoryRecord(content="test", embedding=None)
        assert record.embedding is None

    def test_list_of_floats_passes_through(self) -> None:
        """A list of floats is stored unchanged."""
        record = MemoryRecord(content="test", embedding=[0.1, 0.2, 0.3])
        assert record.embedding == [0.1, 0.2, 0.3]

    def test_bytes_converted_to_list_float(self) -> None:
        """Raw float32 bytes are decoded into a list of Python floats."""
        payload = np.array([0.1, 0.2, 0.3], dtype=np.float32).tobytes()
        record = MemoryRecord(content="test", embedding=payload)
        decoded = record.embedding
        assert decoded is not None
        assert len(decoded) == 3
        assert all(isinstance(value, float) for value in decoded)
        # float32 round-tripping is inexact, hence the tolerance.
        np.testing.assert_allclose(decoded, [0.1, 0.2, 0.3], atol=1e-6)

    def test_empty_bytes_becomes_none(self) -> None:
        """An empty bytes payload results in no embedding."""
        record = MemoryRecord(content="test", embedding=b"")
        assert record.embedding is None

    def test_list_of_ints_converted_to_floats(self) -> None:
        """Integer entries are coerced to floats."""
        record = MemoryRecord(content="test", embedding=[1, 2, 3])
        assert record.embedding == [1.0, 2.0, 3.0]
        assert all(isinstance(value, float) for value in record.embedding)

    def test_numpy_array_converted_to_list(self) -> None:
        """A numpy array is turned into a plain list of the same length."""
        vector = np.array([0.5, 0.6], dtype=np.float32)
        record = MemoryRecord(content="test", embedding=vector)
        assert record.embedding is not None
        assert isinstance(record.embedding, list)
        assert len(record.embedding) == 2
|
||||
|
||||
|
||||
class TestEmbedTextsAsyncSafety:
    """Checks embed_texts from sync code and from inside a running event loop."""

    def test_embed_texts_sync_context(self) -> None:
        """embed_texts returns the embedder's vectors in a plain sync call."""
        vectors = [[0.1, 0.2], [0.3, 0.4]]
        embedder = MagicMock(return_value=vectors)
        embedded = embed_texts(embedder, ["hello", "world"])
        assert len(embedded) == 2
        assert embedded[0] == [0.1, 0.2]
        embedder.assert_called_once()

    def test_embed_texts_empty_input(self) -> None:
        """An empty input list yields an empty list without calling the embedder."""
        embedder = MagicMock()
        embedded = embed_texts(embedder, [])
        assert embedded == []
        embedder.assert_not_called()

    def test_embed_texts_all_empty_strings(self) -> None:
        """Blank inputs each map to an empty vector and the embedder is not called."""
        embedder = MagicMock()
        embedded = embed_texts(embedder, ["", " ", ""])
        assert embedded == [[], [], []]
        embedder.assert_not_called()

    def test_embed_texts_skips_empty_preserves_positions(self) -> None:
        """Only non-empty texts reach the embedder; output positions match input."""
        embedder = MagicMock(return_value=[[0.1, 0.2]])
        embedded = embed_texts(embedder, ["", "hello", ""])
        assert embedded == [[], [0.1, 0.2], []]
        embedder.assert_called_once_with(["hello"])

    def test_embed_texts_in_async_context(self) -> None:
        """embed_texts still returns the embedding when called inside a coroutine."""
        embedder = MagicMock(return_value=[[0.1, 0.2]])

        async def call_from_coroutine() -> list[list[float]]:
            # embed_texts is invoked synchronously while the loop is running.
            return embed_texts(embedder, ["hello"])

        embedded = asyncio.run(call_from_coroutine())
        assert embedded == [[0.1, 0.2]]
        embedder.assert_called_once()
|
||||
|
||||
|
||||
class TestEmbedText:
    """Checks embed_text, the single-string embedding helper."""

    def test_empty_string_returns_empty(self) -> None:
        """An empty string yields [] and never reaches the embedder."""
        embedder = MagicMock()
        vector = embed_text(embedder, "")
        assert vector == []
        embedder.assert_not_called()

    def test_whitespace_only_returns_empty(self) -> None:
        """A whitespace-only string yields [] and never reaches the embedder."""
        embedder = MagicMock()
        vector = embed_text(embedder, " ")
        assert vector == []
        embedder.assert_not_called()

    def test_normal_text_returns_embedding(self) -> None:
        """The first vector produced by the embedder is returned."""
        embedder = MagicMock(return_value=[[0.1, 0.2, 0.3]])
        vector = embed_text(embedder, "hello")
        assert vector == [0.1, 0.2, 0.3]

    def test_numpy_array_result_converted(self) -> None:
        """A numpy vector from the embedder comes back as a plain list."""
        raw_vector = np.array([0.1, 0.2], dtype=np.float32)
        embedder = MagicMock(return_value=[raw_vector])
        vector = embed_text(embedder, "hello")
        assert isinstance(vector, list)
        assert len(vector) == 2
|
||||
125
lib/crewai/tests/utilities/test_cache_config.py
Normal file
125
lib/crewai/tests/utilities/test_cache_config.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""Tests for shared cache configuration helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.utilities.cache_config import (
|
||||
get_aiocache_config,
|
||||
parse_cache_url,
|
||||
use_valkey_cache,
|
||||
)
|
||||
|
||||
|
||||
class TestParseCacheUrl:
    """Checks parse_cache_url() against different environment setups."""

    def test_returns_none_when_no_env_vars(self) -> None:
        """With an empty environment there is nothing to parse."""
        with patch.dict(os.environ, {}, clear=True):
            assert parse_cache_url() is None

    def test_parses_valkey_url(self) -> None:
        """Host, port and db are read from VALKEY_URL; password is None."""
        env = {"VALKEY_URL": "redis://myhost:6380/2"}
        with patch.dict(os.environ, env, clear=True):
            parsed = parse_cache_url()
        assert parsed is not None
        assert parsed["host"] == "myhost"
        assert parsed["port"] == 6380
        assert parsed["db"] == 2
        assert parsed["password"] is None

    def test_parses_redis_url(self) -> None:
        """REDIS_URL is used when it is the only variable set."""
        env = {"REDIS_URL": "redis://localhost:6379/0"}
        with patch.dict(os.environ, env, clear=True):
            parsed = parse_cache_url()
        assert parsed is not None
        assert parsed["host"] == "localhost"
        assert parsed["port"] == 6379
        assert parsed["db"] == 0

    def test_valkey_url_takes_priority_over_redis_url(self) -> None:
        """When both variables are set, the VALKEY_URL values win."""
        env = {
            "VALKEY_URL": "redis://valkey-host:6380/1",
            "REDIS_URL": "redis://redis-host:6379/0",
        }
        with patch.dict(os.environ, env, clear=True):
            parsed = parse_cache_url()
        assert parsed is not None
        assert parsed["host"] == "valkey-host"
        assert parsed["port"] == 6380

    def test_parses_password(self) -> None:
        """The password part of the URL is extracted."""
        env = {"VALKEY_URL": "redis://:s3cret@myhost:6379/0"}
        with patch.dict(os.environ, env, clear=True):
            parsed = parse_cache_url()
        assert parsed is not None
        assert parsed["password"] == "s3cret"

    def test_defaults_for_minimal_url(self) -> None:
        """A host-only URL gets port 6379, db 0 and no password."""
        env = {"VALKEY_URL": "redis://myhost"}
        with patch.dict(os.environ, env, clear=True):
            parsed = parse_cache_url()
        assert parsed is not None
        assert parsed["host"] == "myhost"
        assert parsed["port"] == 6379
        assert parsed["db"] == 0
        assert parsed["password"] is None

    def test_non_numeric_db_path_defaults_to_zero(self) -> None:
        """A non-numeric path such as /mydb results in db 0."""
        env = {"VALKEY_URL": "redis://myhost:6379/mydb"}
        with patch.dict(os.environ, env, clear=True):
            parsed = parse_cache_url()
        assert parsed is not None
        assert parsed["db"] == 0
|
||||
|
||||
|
||||
class TestGetAiocacheConfig:
    """Checks the aiocache configuration built by get_aiocache_config()."""

    def test_returns_memory_cache_when_no_url(self) -> None:
        """Without any cache URL the default backend is the in-memory cache."""
        with patch.dict(os.environ, {}, clear=True):
            config = get_aiocache_config()
        default_backend = config["default"]
        assert default_backend["cache"] == "aiocache.SimpleMemoryCache"

    def test_returns_redis_cache_when_url_set(self) -> None:
        """With VALKEY_URL set the default backend is RedisCache with parsed values."""
        env = {"VALKEY_URL": "redis://myhost:6380/2"}
        with patch.dict(os.environ, env, clear=True):
            config = get_aiocache_config()
        default_backend = config["default"]
        assert default_backend["cache"] == "aiocache.RedisCache"
        assert default_backend["endpoint"] == "myhost"
        assert default_backend["port"] == 6380
        assert default_backend["db"] == 2
|
||||
|
||||
|
||||
class TestUseValkeyCache:
    """Checks when use_valkey_cache() reports that Valkey should be used."""

    def test_returns_false_when_not_set(self) -> None:
        """An empty environment means Valkey caching is off."""
        with patch.dict(os.environ, {}, clear=True):
            assert use_valkey_cache() is False

    def test_returns_true_when_set(self) -> None:
        """Setting VALKEY_URL turns Valkey caching on."""
        env = {"VALKEY_URL": "redis://localhost:6379"}
        with patch.dict(os.environ, env, clear=True):
            assert use_valkey_cache() is True

    def test_returns_false_when_only_redis_url_set(self) -> None:
        """REDIS_URL on its own does not turn Valkey caching on."""
        env = {"REDIS_URL": "redis://localhost:6379"}
        with patch.dict(os.environ, env, clear=True):
            assert use_valkey_cache() is False
|
||||
Reference in New Issue
Block a user