Compare commits

...

1 Commits

Author SHA1 Message Date
Matthias Howell
c5a9a8da50 feat(valkey): shared cache config + ValkeyCache for A2A and file uploads
Extract duplicated Redis URL parsing into a shared cache_config utility.
Introduce ValkeyCache as a lightweight async key/value cache using
valkey-glide. Wire it into A2A task handling, agent card caching, and
file upload caching.

Part 1/4 of Valkey storage implementation.

fix: async-safe embeddings and resilient drain_writes

Add bytes→float validators on MemoryRecord and ItemState to handle
Valkey returning embeddings as raw bytes. Make embed_texts() safe when
called from an async context by using a thread pool. Improve
drain_writes() with per-save timeouts and error logging instead of
raising on failure.

Part 3/4 of Valkey storage implementation.

feat(valkey): ValkeyStorage vector memory backend

Add ValkeyStorage, a distributed StorageBackend implementation using
Valkey-GLIDE with Valkey Search for vector similarity. Wire it into
Memory as the 'valkey' storage option. Pin scrapegraph-py<2 to fix
unrelated upstream breakage.

Part 4/4 of Valkey storage implementation.

fix: use datetime.utcnow() for last_accessed consistency

MemoryRecord defaults use utcnow() for created_at and last_accessed.
Match that in ValkeyStorage.update_record() to avoid timezone
inconsistency in recency scoring.

feat(valkey): shared cache config + ValkeyCache for A2A and file uploads

Extract duplicated Redis URL parsing into a shared cache_config utility.
Introduce ValkeyCache as a lightweight async key/value cache using
valkey-glide. Wire it into A2A task handling, agent card caching, and
file upload caching.

Part 1/4 of Valkey storage implementation.

fix: handle non-numeric database path in cache URL parsing

Extract _parse_db_from_path() helper that catches ValueError for
paths like /mydb and defaults to 0 with a warning, instead of
crashing.

fix: async-safe embeddings and resilient drain_writes

Add bytes→float validators on MemoryRecord and ItemState to handle
Valkey returning embeddings as raw bytes. Make embed_texts() safe when
called from an async context by using a thread pool. Improve
drain_writes() with per-save timeouts and error logging instead of
raising on failure.

Part 3/4 of Valkey storage implementation.

fix: catch concurrent.futures.TimeoutError for Python 3.10 compat

In Python <3.11, concurrent.futures.TimeoutError is distinct from the
builtin TimeoutError. Catch both so the timeout warning path works
on all supported Python versions.
2026-05-13 11:00:36 -04:00
19 changed files with 8848 additions and 164 deletions

View File

@@ -1,4 +1,4 @@
"""Cache for tracking uploaded files using aiocache."""
"""Cache for tracking uploaded files using aiocache or ValkeyCache."""
from __future__ import annotations
@@ -10,10 +10,11 @@ from dataclasses import dataclass
from datetime import datetime, timezone
import hashlib
import logging
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Protocol
from aiocache import Cache # type: ignore[import-untyped]
from aiocache.serializers import PickleSerializer # type: ignore[import-untyped]
from crewai.utilities.cache_config import parse_cache_url
from crewai_files.core.constants import DEFAULT_MAX_CACHE_ENTRIES, DEFAULT_TTL_SECONDS
from crewai_files.uploaders.factory import ProviderType
@@ -51,6 +52,33 @@ class CachedUpload:
return False
return datetime.now(timezone.utc) >= self.expires_at
def to_dict(self) -> dict[str, Any]:
    """Serialize to a JSON-compatible dict.

    Datetimes are encoded as ISO-8601 strings so the result can be
    round-tripped through JSON-based cache backends.

    Returns:
        Dict with the upload's fields; ``expires_at`` is None when the
        upload has no recorded expiry.
    """
    return {
        "file_id": self.file_id,
        "provider": self.provider,
        "file_uri": self.file_uri,
        "content_type": self.content_type,
        "uploaded_at": self.uploaded_at.isoformat(),
        # expires_at may legitimately be None (no expiry recorded).
        "expires_at": self.expires_at.isoformat() if self.expires_at else None,
    }
@classmethod
def from_dict(cls, data: dict[str, Any]) -> CachedUpload:
    """Deserialize from a dict.

    Inverse of ``to_dict``; expects ISO-8601 strings for the datetime
    fields.

    Args:
        data: Mapping produced by ``to_dict``.

    Returns:
        A reconstructed CachedUpload instance.

    Raises:
        KeyError: If a required field is missing from ``data``.
        ValueError: If a datetime string is not valid ISO-8601.
    """
    return cls(
        file_id=data["file_id"],
        provider=data["provider"],
        # Optional fields use .get() so payloads without them still load.
        file_uri=data.get("file_uri"),
        content_type=data["content_type"],
        uploaded_at=datetime.fromisoformat(data["uploaded_at"]),
        expires_at=(
            datetime.fromisoformat(data["expires_at"])
            if data.get("expires_at")
            else None
        ),
    )
def _make_key(file_hash: str, provider: str) -> str:
"""Create a cache key from file hash and provider."""
@@ -58,14 +86,7 @@ def _make_key(file_hash: str, provider: str) -> str:
def _compute_file_hash_streaming(chunks: Iterator[bytes]) -> str:
"""Compute SHA-256 hash from streaming chunks.
Args:
chunks: Iterator of byte chunks.
Returns:
Hexadecimal hash string.
"""
"""Compute SHA-256 hash from streaming chunks."""
hasher = hashlib.sha256()
for chunk in chunks:
hasher.update(chunk)
@@ -73,10 +94,7 @@ def _compute_file_hash_streaming(chunks: Iterator[bytes]) -> str:
def _compute_file_hash(file: FileInput) -> str:
"""Compute SHA-256 hash of file content.
Uses streaming for FilePath sources to avoid loading large files into memory.
"""
"""Compute SHA-256 hash of file content."""
from crewai_files.core.sources import FilePath
source = file._file_source
@@ -86,10 +104,73 @@ def _compute_file_hash(file: FileInput) -> str:
return hashlib.sha256(content).hexdigest()
class UploadCache:
"""Async cache for tracking uploaded files using aiocache.
class CacheBackend(Protocol):
"""Protocol for cache backends used by UploadCache."""
Supports in-memory caching by default, with optional Redis backend
async def get(self, key: str) -> CachedUpload | None: ...
async def set(self, key: str, value: CachedUpload, ttl: int) -> None: ...
async def delete(self, key: str) -> bool: ...
class AiocacheBackend:
    """Cache backend backed by aiocache (memory or Redis).

    Stores CachedUpload objects directly; aiocache's PickleSerializer
    handles (de)serialization, so no dict conversion is needed here.
    """

    def __init__(self, cache: Cache) -> None:  # type: ignore[no-any-unimported]
        # The aiocache Cache instance (SimpleMemoryCache or RedisCache).
        self._cache = cache

    async def get(self, key: str) -> CachedUpload | None:
        """Return the cached upload for key, or None on miss/type mismatch."""
        result = await self._cache.get(key)
        # Guard against foreign values under the same key: only a
        # CachedUpload is a valid hit.
        if isinstance(result, CachedUpload):
            return result
        return None

    async def set(self, key: str, value: CachedUpload, ttl: int) -> None:
        """Store value under key with the given TTL in seconds."""
        await self._cache.set(key, value, ttl=ttl)

    async def delete(self, key: str) -> bool:
        """Delete key; return True if an entry was actually removed."""
        result = await self._cache.delete(key)
        # Redis backends return a deletion count (int); memory backends
        # may return a bool — normalize both to bool.
        return bool(result > 0 if isinstance(result, int) else result)
class ValkeyCacheBackend:
    """Cache backend backed by ValkeyCache (JSON serialization).

    CachedUpload values are converted via to_dict()/from_dict() because
    ValkeyCache stores JSON, not pickled objects.
    """

    def __init__(
        self,
        host: str = "localhost",
        port: int = 6379,
        db: int = 0,
        password: str | None = None,
        default_ttl: int | None = None,
    ) -> None:
        """Create the backend; the underlying client connects lazily.

        Args:
            host: Valkey server hostname.
            port: Valkey server port.
            db: Database number.
            password: Optional authentication password.
            default_ttl: Default TTL in seconds (None = no expiration).
        """
        # Imported here so valkey-glide is only required when this
        # backend is actually selected.
        from crewai.memory.storage.valkey_cache import ValkeyCache

        self._cache = ValkeyCache(
            host=host, port=port, db=db, password=password, default_ttl=default_ttl
        )

    async def get(self, key: str) -> CachedUpload | None:
        """Return the cached upload for key, or None on miss or bad payload."""
        data = await self._cache.get(key)
        if data is None:
            return None
        try:
            return CachedUpload.from_dict(data)
        except (KeyError, ValueError) as e:
            # Corrupt or schema-mismatched JSON is treated as a miss.
            logger.warning(f"Failed to deserialize cached upload: {e}")
            return None

    async def set(self, key: str, value: CachedUpload, ttl: int) -> None:
        """Store value under key as JSON with the given TTL in seconds."""
        await self._cache.set(key, value.to_dict(), ttl=ttl)

    async def delete(self, key: str) -> bool:
        """Delete key; always reports True (ValkeyCache.delete is void)."""
        await self._cache.delete(key)
        return True  # ValkeyCache.delete is void
class UploadCache:
"""Async cache for tracking uploaded files.
Supports in-memory caching by default, with optional Redis or Valkey backend
for distributed setups.
Attributes:
@@ -110,7 +191,7 @@ class UploadCache:
Args:
ttl: Default TTL in seconds.
namespace: Cache namespace.
cache_type: Backend type ("memory" or "redis").
cache_type: Backend type ("memory", "redis", or "valkey").
max_entries: Maximum cache entries (None for unlimited).
**cache_kwargs: Additional args for cache backend.
"""
@@ -120,18 +201,39 @@ class UploadCache:
self._provider_keys: dict[ProviderType, set[str]] = {}
self._key_access_order: list[str] = []
self._backend: CacheBackend = self._create_backend(
cache_type, namespace, ttl, **cache_kwargs
)
@staticmethod
def _create_backend(
cache_type: str,
namespace: str,
ttl: int,
**cache_kwargs: Any,
) -> CacheBackend:
"""Create the appropriate cache backend."""
if cache_type == "valkey":
conn = parse_cache_url() or {}
return ValkeyCacheBackend(
host=cache_kwargs.get("host", conn.get("host", "localhost")),
port=cache_kwargs.get("port", conn.get("port", 6379)),
db=cache_kwargs.get("db", conn.get("db", 0)),
password=cache_kwargs.get("password", conn.get("password")),
default_ttl=ttl,
)
if cache_type == "redis":
self._cache = Cache(
Cache.REDIS,
serializer=PickleSerializer(),
namespace=namespace,
**cache_kwargs,
)
else:
self._cache = Cache(
serializer=PickleSerializer(),
namespace=namespace,
return AiocacheBackend(
Cache(
Cache.REDIS,
serializer=PickleSerializer(),
namespace=namespace,
**cache_kwargs,
)
)
return AiocacheBackend(
Cache(serializer=PickleSerializer(), namespace=namespace)
)
def _track_key(self, provider: ProviderType, key: str) -> None:
"""Track a key for a provider (for cleanup) and access order."""
@@ -157,11 +259,9 @@ class UploadCache:
"""
if self.max_entries is None:
return 0
current_count = len(self)
if current_count < self.max_entries:
return 0
to_evict = max(1, self.max_entries // 10)
return await self._evict_oldest(to_evict)
@@ -176,31 +276,24 @@ class UploadCache:
"""
evicted = 0
keys_to_evict = self._key_access_order[:count]
for key in keys_to_evict:
await self._cache.delete(key)
await self._backend.delete(key)
self._key_access_order.remove(key)
for provider_keys in self._provider_keys.values():
provider_keys.discard(key)
evicted += 1
if evicted > 0:
logger.debug(f"Evicted {evicted} oldest cache entries")
return evicted
# ------------------------------------------------------------------
# Async public API
# ------------------------------------------------------------------
async def aget(
self, file: FileInput, provider: ProviderType
) -> CachedUpload | None:
"""Get a cached upload for a file.
Args:
file: The file to look up.
provider: The provider name.
Returns:
Cached upload if found and not expired, None otherwise.
"""
"""Get a cached upload for a file."""
file_hash = _compute_file_hash(file)
return await self.aget_by_hash(file_hash, provider)
@@ -217,17 +310,14 @@ class UploadCache:
Cached upload if found and not expired, None otherwise.
"""
key = _make_key(file_hash, provider)
result = await self._cache.get(key)
result = await self._backend.get(key)
if result is None:
return None
if isinstance(result, CachedUpload):
if result.is_expired():
await self._cache.delete(key)
self._untrack_key(provider, key)
return None
return result
return None
if result.is_expired():
await self._backend.delete(key)
self._untrack_key(provider, key)
return None
return result
async def aset(
self,
@@ -237,18 +327,7 @@ class UploadCache:
file_uri: str | None = None,
expires_at: datetime | None = None,
) -> CachedUpload:
"""Cache an uploaded file.
Args:
file: The file that was uploaded.
provider: The provider name.
file_id: Provider-specific file identifier.
file_uri: Optional URI for accessing the file.
expires_at: When the upload expires.
Returns:
The created cache entry.
"""
"""Cache an uploaded file."""
file_hash = _compute_file_hash(file)
return await self.aset_by_hash(
file_hash=file_hash,
@@ -282,7 +361,6 @@ class UploadCache:
The created cache entry.
"""
await self._evict_if_needed()
key = _make_key(file_hash, provider)
now = datetime.now(timezone.utc)
@@ -299,7 +377,7 @@ class UploadCache:
if expires_at is not None:
ttl = max(0, int((expires_at - now).total_seconds()))
await self._cache.set(key, cached, ttl=ttl)
await self._backend.set(key, cached, ttl=ttl)
self._track_key(provider, key)
logger.debug(f"Cached upload: {file_id} for provider {provider}")
return cached
@@ -316,9 +394,7 @@ class UploadCache:
"""
file_hash = _compute_file_hash(file)
key = _make_key(file_hash, provider)
result = await self._cache.delete(key)
removed = bool(result > 0 if isinstance(result, int) else result)
removed = await self._backend.delete(key)
if removed:
self._untrack_key(provider, key)
return removed
@@ -335,11 +411,10 @@ class UploadCache:
"""
if provider not in self._provider_keys:
return False
for key in list(self._provider_keys[provider]):
cached = await self._cache.get(key)
if isinstance(cached, CachedUpload) and cached.file_id == file_id:
await self._cache.delete(key)
cached = await self._backend.get(key)
if cached is not None and cached.file_id == file_id:
await self._backend.delete(key)
self._untrack_key(provider, key)
return True
return False
@@ -351,17 +426,13 @@ class UploadCache:
Number of entries removed.
"""
removed = 0
for provider, keys in list(self._provider_keys.items()):
for key in list(keys):
cached = await self._cache.get(key)
if cached is None or (
isinstance(cached, CachedUpload) and cached.is_expired()
):
await self._cache.delete(key)
cached = await self._backend.get(key)
if cached is None or cached.is_expired():
await self._backend.delete(key)
self._untrack_key(provider, key)
removed += 1
if removed > 0:
logger.debug(f"Cleared {removed} expired cache entries")
return removed
@@ -373,9 +444,12 @@ class UploadCache:
Number of entries cleared.
"""
count = sum(len(keys) for keys in self._provider_keys.values())
await self._cache.clear(namespace=self.namespace)
# Delete all tracked keys individually (works for all backends)
for keys in self._provider_keys.values():
for key in keys:
await self._backend.delete(key)
self._provider_keys.clear()
self._key_access_order.clear()
if count > 0:
logger.debug(f"Cleared {count} cache entries")
return count
@@ -391,14 +465,17 @@ class UploadCache:
"""
if provider not in self._provider_keys:
return []
results: list[CachedUpload] = []
for key in list(self._provider_keys[provider]):
cached = await self._cache.get(key)
if isinstance(cached, CachedUpload) and not cached.is_expired():
cached = await self._backend.get(key)
if cached is not None and not cached.is_expired():
results.append(cached)
return results
# ------------------------------------------------------------------
# Sync wrappers
# ------------------------------------------------------------------
@staticmethod
def _run_sync(coro: Any) -> Any:
"""Run an async coroutine from sync context without blocking event loop."""
@@ -489,11 +566,7 @@ class UploadCache:
return sum(len(keys) for keys in self._provider_keys.values())
def get_providers(self) -> builtins.set[ProviderType]:
"""Get all provider names that have cached entries.
Returns:
Set of provider names.
"""
"""Get all provider names that have cached entries."""
return builtins.set(self._provider_keys.keys())
@@ -506,17 +579,7 @@ def get_upload_cache(
cache_type: str = "memory",
**cache_kwargs: Any,
) -> UploadCache:
"""Get or create the default upload cache.
Args:
ttl: Default TTL in seconds.
namespace: Cache namespace.
cache_type: Backend type ("memory" or "redis").
**cache_kwargs: Additional args for cache backend.
Returns:
The upload cache instance.
"""
"""Get or create the default upload cache."""
global _default_cache
if _default_cache is None:
_default_cache = UploadCache(

View File

@@ -110,6 +110,9 @@ file-processing = [
qdrant-edge = [
"qdrant-edge-py>=0.6.0",
]
valkey = [
"valkey-glide>=1.3.0",
]
[tool.uv]

View File

@@ -13,8 +13,12 @@ from types import MethodType
from typing import TYPE_CHECKING
from a2a.client.errors import A2AClientHTTPError
from a2a.types import AgentCapabilities, AgentCard, AgentSkill
from aiocache import cached # type: ignore[import-untyped]
from a2a.types import (
AgentCapabilities,
AgentCard,
AgentSkill,
)
from aiocache import cached, caches # type: ignore[import-untyped]
from aiocache.serializers import PickleSerializer # type: ignore[import-untyped]
import httpx
@@ -32,6 +36,7 @@ from crewai.events.types.a2a_events import (
A2AAuthenticationFailedEvent,
A2AConnectionErrorEvent,
)
from crewai.utilities.cache_config import get_aiocache_config
if TYPE_CHECKING:
@@ -40,6 +45,18 @@ if TYPE_CHECKING:
from crewai.task import Task
_cache_configured = False
def _ensure_cache_configured() -> None:
"""Configure aiocache on first use (lazy initialization)."""
global _cache_configured
if _cache_configured:
return
caches.set_config(get_aiocache_config())
_cache_configured = True
def _get_tls_verify(auth: ClientAuthScheme | None) -> ssl.SSLContext | bool | str:
"""Get TLS verify parameter from auth scheme.
@@ -191,6 +208,7 @@ async def afetch_agent_card(
else:
auth_hash = _auth_store.compute_key("none", "")
_auth_store.set(auth_hash, auth)
_ensure_cache_configured()
agent_card: AgentCard = await _afetch_agent_card_cached(
endpoint, auth_hash, timeout
)

View File

@@ -9,9 +9,8 @@ from datetime import datetime
from functools import wraps
import json
import logging
import os
import threading
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast
from urllib.parse import urlparse
from a2a.server.agent_execution import RequestContext
from a2a.server.events import EventQueue
@@ -38,7 +37,6 @@ from a2a.utils import (
from a2a.utils.errors import ServerError
from aiocache import SimpleMemoryCache, caches # type: ignore[import-untyped]
from pydantic import BaseModel
from typing_extensions import TypedDict
from crewai.a2a.utils.agent_card import _get_server_config
from crewai.a2a.utils.content_type import validate_message_parts
@@ -50,12 +48,18 @@ from crewai.events.types.a2a_events import (
A2AServerTaskStartedEvent,
)
from crewai.task import Task
from crewai.utilities.cache_config import (
get_aiocache_config,
parse_cache_url,
use_valkey_cache,
)
from crewai.utilities.pydantic_schema_utils import create_model_from_schema
if TYPE_CHECKING:
from crewai.a2a.extensions.server import ExtensionContext, ServerExtensionRegistry
from crewai.agent import Agent
from crewai.memory.storage.valkey_cache import ValkeyCache
logger = logging.getLogger(__name__)
@@ -64,52 +68,61 @@ P = ParamSpec("P")
T = TypeVar("T")
class RedisCacheConfig(TypedDict, total=False):
"""Configuration for aiocache Redis backend."""
# ---------------------------------------------------------------------------
# Lazy cache initialisation
# ---------------------------------------------------------------------------
cache: str
endpoint: str
port: int
db: int
password: str
_task_cache: ValkeyCache | None = None
_lazy_init_complete = False
_cache_init_lock = threading.Lock()
# Cancellation polling interval in seconds.
_CANCEL_POLL_INTERVAL = 0.1
# Configure aiocache at import time (matches upstream behaviour).
# This is safe — it only touches aiocache, no optional dependencies.
# The Valkey path is deferred to _ensure_task_cache() to avoid importing
# valkey-glide at module level (it may not be installed).
if not use_valkey_cache():
caches.set_config(get_aiocache_config())
def _parse_redis_url(url: str) -> RedisCacheConfig:
"""Parse a Redis URL into aiocache configuration.
def _ensure_task_cache() -> None:
"""Initialise the Valkey task cache on first use (thread-safe).
Args:
url: Redis connection URL (e.g., redis://localhost:6379/0).
Returns:
Configuration dict for aiocache.RedisCache.
For the aiocache path, configuration happens at module level above.
This function only needs to run for the Valkey path.
"""
parsed = urlparse(url)
config: RedisCacheConfig = {
"cache": "aiocache.RedisCache",
"endpoint": parsed.hostname or "localhost",
"port": parsed.port or 6379,
}
if parsed.path and parsed.path != "/":
try:
config["db"] = int(parsed.path.lstrip("/"))
except ValueError:
pass
if parsed.password:
config["password"] = parsed.password
return config
global _task_cache, _lazy_init_complete
if _lazy_init_complete:
return
with _cache_init_lock:
if _lazy_init_complete:
return
_redis_url = os.environ.get("REDIS_URL")
if use_valkey_cache():
from crewai.memory.storage.valkey_cache import ValkeyCache
caches.set_config(
{
"default": _parse_redis_url(_redis_url)
if _redis_url
else {
"cache": "aiocache.SimpleMemoryCache",
}
}
)
conn = parse_cache_url() or {}
try:
_task_cache = ValkeyCache(
host=conn.get("host", "localhost"),
port=conn.get("port", 6379),
db=conn.get("db", 0),
password=conn.get("password"),
default_ttl=3600,
)
except Exception as e:
logger.error(
"Failed to initialize ValkeyCache for task cancellation, "
"falling back to aiocache",
extra={"error": str(e)},
)
caches.set_config(get_aiocache_config())
_task_cache = None
_lazy_init_complete = True
def cancellable(
@@ -130,6 +143,8 @@ def cancellable(
@wraps(fn)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
"""Wrap function with cancellation monitoring."""
_ensure_task_cache()
context: RequestContext | None = None
for arg in args:
if isinstance(arg, RequestContext):
@@ -142,19 +157,34 @@ def cancellable(
return await fn(*args, **kwargs)
task_id = context.task_id
cache = caches.get("default")
async def poll_for_cancel() -> bool:
"""Poll cache for cancellation flag."""
async def poll_for_cancel_valkey() -> bool:
"""Poll ValkeyCache for cancellation flag."""
while True:
if _task_cache is not None and await _task_cache.get(
f"cancel:{task_id}"
):
return True
await asyncio.sleep(_CANCEL_POLL_INTERVAL)
async def poll_for_cancel_aiocache() -> bool:
"""Poll aiocache for cancellation flag."""
cache = caches.get("default")
while True:
if await cache.get(f"cancel:{task_id}"):
return True
await asyncio.sleep(0.1)
await asyncio.sleep(_CANCEL_POLL_INTERVAL)
async def watch_for_cancel() -> bool:
"""Watch for cancellation events via pub/sub or polling."""
if _task_cache is not None:
# ValkeyCache: use polling (pub/sub not implemented yet)
return await poll_for_cancel_valkey()
# aiocache: use pub/sub if Redis, otherwise poll
cache = caches.get("default")
if isinstance(cache, SimpleMemoryCache):
return await poll_for_cancel()
return await poll_for_cancel_aiocache()
try:
client = cache.client
@@ -168,7 +198,7 @@ def cancellable(
"Cancel watcher Redis error, falling back to polling",
extra={"task_id": task_id, "error": str(e)},
)
return await poll_for_cancel()
return await poll_for_cancel_aiocache()
return False
execute_task = asyncio.create_task(fn(*args, **kwargs))
@@ -190,7 +220,12 @@ def cancellable(
cancel_watch.cancel()
return execute_task.result()
finally:
await cache.delete(f"cancel:{task_id}")
# Clean up cancellation flag
if _task_cache is not None:
await _task_cache.delete(f"cancel:{task_id}")
else:
cache = caches.get("default")
await cache.delete(f"cancel:{task_id}")
return wrapper
@@ -475,6 +510,8 @@ async def cancel(
if task_id is None or context_id is None:
raise ServerError(InvalidParamsError(message="task_id and context_id required"))
_ensure_task_cache()
if context.current_task and context.current_task.status.state in (
TaskState.completed,
TaskState.failed,
@@ -482,11 +519,16 @@ async def cancel(
):
return context.current_task
cache = caches.get("default")
await cache.set(f"cancel:{task_id}", True, ttl=3600)
if not isinstance(cache, SimpleMemoryCache):
await cache.client.publish(f"cancel:{task_id}", "cancel")
if _task_cache is not None:
# Use ValkeyCache
await _task_cache.set(f"cancel:{task_id}", True, ttl=3600)
# Note: pub/sub not implemented for ValkeyCache yet, relies on polling
else:
# Use aiocache
cache = caches.get("default")
await cache.set(f"cancel:{task_id}", True, ttl=3600)
if not isinstance(cache, SimpleMemoryCache):
await cache.client.publish(f"cancel:{task_id}", "cancel")
await event_queue.enqueue_event(
TaskStatusUpdateEvent(

View File

@@ -18,7 +18,7 @@ import math
from typing import Any
from uuid import uuid4
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from crewai.flow.flow import Flow, listen, start
from crewai.memory.analyze import (
@@ -68,6 +68,27 @@ class ItemState(BaseModel):
plan: ConsolidationPlan | None = None
result_record: MemoryRecord | None = None
@field_validator("similar_records", "result_record", mode="before")
@classmethod
def ensure_embedding_is_list(cls, v: Any) -> Any:
    """Ensure MemoryRecord embeddings are list[float], not bytes.

    Delegates to MemoryRecord.validate_embedding for consistent behavior
    (e.g. empty bytes → None).

    Args:
        v: Incoming value — None, a MemoryRecord, or a list of
            MemoryRecord (already-constructed instances bypass the
            MemoryRecord field validator, hence this explicit fix-up).

    Returns:
        The same value with any bytes embeddings converted in place.
    """
    if v is None:
        return None
    if isinstance(v, list):
        # similar_records: repair each record that carries raw bytes.
        for record in v:
            if isinstance(record, MemoryRecord) and isinstance(
                record.embedding, bytes
            ):
                record.embedding = MemoryRecord.validate_embedding(record.embedding)
        return v
    # result_record: single-record case.
    if isinstance(v, MemoryRecord) and isinstance(v.embedding, bytes):
        v.embedding = MemoryRecord.validate_embedding(v.embedding)
    return v
class EncodingState(BaseModel):
"""Batch-level state for the encoding flow."""

View File

@@ -0,0 +1,198 @@
"""Valkey-based cache implementation for CrewAI.
This module provides a simple cache interface using Valkey-GLIDE client
for caching operations with optional TTL support. It replaces Redis usage
in A2A communication, file uploads, and agent card caching.
"""
from __future__ import annotations
import json
import logging
from typing import Any
from glide import GlideClient, GlideClientConfiguration, NodeAddress
_logger = logging.getLogger(__name__)
class ValkeyCache:
    """Async key/value cache backed by a Valkey-GLIDE client.

    Offers get/set/delete/exists/close operations with optional TTL.
    Values are JSON-serialized, and the underlying client is created
    lazily on first use.

    Example:
        >>> cache = ValkeyCache(host="localhost", port=6379)
        >>> await cache.set("key", {"data": "value"}, ttl=3600)
        >>> value = await cache.get("key")
        >>> await cache.delete("key")
    """

    def __init__(
        self,
        host: str = "localhost",
        port: int = 6379,
        db: int = 0,
        password: str | None = None,
        default_ttl: int | None = None,
    ) -> None:
        """Record connection settings; no connection is made yet.

        Args:
            host: Valkey server hostname.
            port: Valkey server port.
            db: Database number to use.
            password: Optional password for authentication.
            default_ttl: Default TTL in seconds (None = no expiration).
        """
        self._host = host
        self._port = port
        self._db = db
        self._password = password
        self._default_ttl = default_ttl
        self._client: GlideClient | None = None

    async def _get_client(self) -> GlideClient:
        """Return the Valkey client, creating it on first call.

        Returns:
            Initialized GlideClient instance.

        Raises:
            RuntimeError: If connection to Valkey fails.
            TimeoutError: If connection attempt times out (10 seconds).
        """
        import asyncio

        if self._client is not None:
            return self._client

        try:
            from glide import ServerCredentials

            credentials = (
                ServerCredentials(password=self._password)
                if self._password
                else None
            )
            config = GlideClientConfiguration(
                addresses=[NodeAddress(self._host, self._port)],
                database_id=self._db,
                credentials=credentials,
            )
            try:
                # Bound the connect attempt so a dead server fails fast.
                self._client = await asyncio.wait_for(
                    GlideClient.create(config), timeout=10.0
                )
            except asyncio.TimeoutError as e:
                _logger.error("Connection timeout connecting to Valkey")
                raise TimeoutError(
                    "Connection timeout to Valkey. "
                    "Ensure Valkey is running and accessible."
                ) from e
            _logger.info("Valkey cache client initialized")
        except (TimeoutError, RuntimeError):
            # Already the errors we want callers to see — re-raise as-is.
            raise
        except Exception as e:
            _logger.error(
                "Failed to create Valkey cache client: %s", type(e).__name__
            )
            raise RuntimeError(
                "Cannot connect to Valkey. Check connection settings."
            ) from e
        return self._client

    async def get(self, key: str) -> Any | None:
        """Fetch and JSON-decode the value stored under ``key``.

        Args:
            key: Cache key.

        Returns:
            Cached value (deserialized from JSON) or None if not found.
        """
        client = await self._get_client()
        raw = await client.get(key)
        if raw is None:
            return None
        try:
            return json.loads(raw)
        except json.JSONDecodeError:
            # A corrupt / non-JSON payload is treated as a cache miss.
            _logger.warning(f"Failed to deserialize cached value for key: {key}")
            return None

    async def set(
        self,
        key: str,
        value: Any,
        ttl: int | None = None,
    ) -> None:
        """JSON-encode ``value`` and store it under ``key``.

        Args:
            key: Cache key.
            value: Value to cache (will be serialized to JSON).
            ttl: TTL in seconds (None uses default_ttl, 0 = no expiration).

        Raises:
            TypeError: If value is not JSON-serializable.
        """
        from glide import ExpirySet, ExpiryType

        client = await self._get_client()
        try:
            payload = json.dumps(value)
        except (TypeError, ValueError) as e:
            _logger.error("Cannot serialize value for key %r: %s", key, e)
            raise TypeError(
                f"Value for cache key {key!r} is not JSON-serializable: {e}"
            ) from e

        effective_ttl = self._default_ttl if ttl is None else ttl
        if effective_ttl and effective_ttl > 0:
            # SET with EX-style expiry in seconds.
            await client.set(
                key,
                payload,
                expiry=ExpirySet(ExpiryType.SEC, effective_ttl),
            )
        else:
            await client.set(key, payload)

    async def delete(self, key: str) -> None:
        """Delete ``key`` from the cache (no-op if absent).

        Args:
            key: Cache key to delete.
        """
        client = await self._get_client()
        await client.delete([key])

    async def exists(self, key: str) -> bool:
        """Report whether ``key`` is present in the cache.

        Args:
            key: Cache key to check.

        Returns:
            True if key exists, False otherwise.
        """
        client = await self._get_client()
        return await client.exists([key]) > 0

    async def close(self) -> None:
        """Close the Valkey client connection, if one was created."""
        if self._client:
            await self._client.close()
            self._client = None
            _logger.debug("Valkey cache client closed")

File diff suppressed because it is too large Load Diff

View File

@@ -2,13 +2,17 @@
from __future__ import annotations
import concurrent.futures
from datetime import datetime
import logging
from typing import Any
from uuid import uuid4
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
_logger = logging.getLogger(__name__)
# When searching the vector store, we ask for more results than the caller
# requested so that post-search steps (composite scoring, deduplication,
# category filtering) have enough candidates to fill the final result set.
@@ -57,6 +61,23 @@ class MemoryRecord(BaseModel):
repr=False,
description="Vector embedding for semantic search. Excluded from serialization to save tokens.",
)
@field_validator("embedding", mode="before")
@classmethod
def validate_embedding(cls, v: Any) -> list[float] | None:
    """Ensure embedding is always list[float] or None, never bytes.

    Valkey can return stored embeddings as raw bytes; this validator
    normalizes them back to a float list before model construction.

    Args:
        v: None, raw bytes, or any iterable of numbers.

    Returns:
        list[float], or None for None/empty-bytes input.
    """
    if v is None:
        return None
    if isinstance(v, bytes):
        # Convert bytes to list[float] if needed
        import numpy as np
        if len(v) == 0:
            # Empty blob means "no embedding", not a zero-length vector.
            return None
        # NOTE(review): assumes the bytes are packed float32 in native
        # byte order — confirm against the storage writer.
        arr = np.frombuffer(v, dtype=np.float32)
        return [float(x) for x in arr]
    # Coerce any numeric iterable (e.g. numpy array) to plain floats.
    return [float(x) for x in v]
source: str | None = Field(
default=None,
description=(
@@ -304,7 +325,11 @@ def embed_text(embedder: Any, text: str) -> list[float]:
"""
if not text or not text.strip():
return []
# Just call the embedder directly - the blocking issue needs to be fixed
# at a higher level (making Memory.recall() async)
result = embedder([text])
if not result:
return []
first = result[0]
@@ -315,12 +340,27 @@ def embed_text(embedder: Any, text: str) -> list[float]:
return list(first)
# Reusable thread pool for running embedder calls from sync context
# when an async event loop is already running. Uses max_workers=2 so
# a single slow/timed-out call doesn't block subsequent embeds.
_EMBED_POOL = concurrent.futures.ThreadPoolExecutor(max_workers=2)
def embed_texts(embedder: Any, texts: list[str]) -> list[list[float]]:
"""Embed multiple texts in a single API call.
The embedder already accepts ``list[str]``, so this just calls it once
with the full batch and normalises the output format.
When called from an async context, offloads the embedder to a thread pool
so the embedding work doesn't run on the event loop thread. The calling
thread still blocks on the result (unavoidable for a sync function), but
this prevents the embedder from starving the event loop's I/O callbacks.
The pool uses ``max_workers=2`` so a single timed-out call doesn't block
subsequent embeds.
Note: the proper long-term fix is making ``Memory.recall()`` async.
Args:
embedder: Callable that accepts a list of strings and returns embeddings.
texts: List of texts to embed.
@@ -328,6 +368,8 @@ def embed_texts(embedder: Any, texts: list[str]) -> list[list[float]]:
Returns:
List of embeddings, one per input text. Empty texts produce empty lists.
"""
import asyncio
if not texts:
return []
# Filter out empty texts, remembering their positions
@@ -337,7 +379,28 @@ def embed_texts(embedder: Any, texts: list[str]) -> list[list[float]]:
if not valid:
return [[] for _ in texts]
result = embedder([t for _, t in valid])
texts_to_embed = [t for _, t in valid]
# Check if we're in an async context
result: Any
try:
asyncio.get_running_loop()
# We're in an async context but this is a sync function.
# Offload to thread pool so the embedder doesn't run on the
# event loop thread. The .result() call blocks this thread
# (acceptable — callers like Memory.recall() are sync).
try:
result = _EMBED_POOL.submit(embedder, texts_to_embed).result(timeout=30)
except concurrent.futures.TimeoutError:
_logger.warning(
"Embedder timed out after 30s, returning empty embeddings. "
"The worker thread may still be running."
)
return [[] for _ in texts]
except RuntimeError:
# Not in async context, run directly
result = embedder(texts_to_embed)
embeddings: list[list[float]] = [[] for _ in texts]
for (orig_idx, _), emb in zip(valid, result, strict=False):
if hasattr(emb, "tolist"):

View File

@@ -2,9 +2,11 @@
from __future__ import annotations
import concurrent.futures
from concurrent.futures import Future, ThreadPoolExecutor
import contextvars
from datetime import datetime
import logging
import threading
import time
from typing import TYPE_CHECKING, Annotated, Any, Literal
@@ -36,6 +38,9 @@ from crewai.rag.embeddings.factory import build_embedder
from crewai.rag.embeddings.providers.openai.types import OpenAIProviderSpec
_logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
@@ -211,6 +216,18 @@ class Memory(BaseModel):
from crewai.memory.storage.lancedb_storage import LanceDBStorage
self._storage = LanceDBStorage()
elif self.storage == "valkey":
from crewai.memory.storage.valkey_storage import ValkeyStorage
from crewai.utilities.cache_config import parse_cache_url
conn = parse_cache_url() or {}
self._storage = ValkeyStorage(
host=conn.get("host", "localhost"),
port=conn.get("port", 6379),
db=conn.get("db", 0),
password=conn.get("password"),
use_tls=conn.get("use_tls", False),
)
else:
from crewai.memory.storage.lancedb_storage import LanceDBStorage
@@ -316,16 +333,60 @@ class Memory(BaseModel):
except Exception: # noqa: S110
pass # swallow everything during shutdown
def drain_writes(self, timeout_per_save: float = 60.0) -> None:
    """Block until all pending background saves have completed.

    Called automatically by ``recall()`` and should be called by the
    crew at shutdown to ensure no saves are lost.

    Args:
        timeout_per_save: Maximum seconds to wait per save operation.
                          Default 60s. If a save times out, logs warning
                          but continues to avoid blocking crew completion.
    """
    # Snapshot the pending list under the lock so saves submitted while
    # we drain don't mutate the list we iterate over.
    with self._pending_lock:
        pending = list(self._pending_saves)
    if pending:
        _logger.debug(
            "[DRAIN_WRITES] Waiting for %d pending saves...", len(pending)
        )
    failed_saves = 0
    for i, future in enumerate(pending):
        try:
            _logger.debug(
                "[DRAIN_WRITES] Waiting for save %d/%d...", i + 1, len(pending)
            )
            future.result(timeout=timeout_per_save)
            _logger.debug(
                "[DRAIN_WRITES] Save %d/%d completed", i + 1, len(pending)
            )
        # Catch both builtin TimeoutError and concurrent.futures.TimeoutError:
        # they are aliases on Python >= 3.11 but distinct classes before that.
        except (TimeoutError, concurrent.futures.TimeoutError):  # noqa: PERF203
            failed_saves += 1
            _logger.warning(
                "[DRAIN_WRITES] Save %d/%d timed out after %ss. "
                "This save will be abandoned. Consider increasing timeout or checking "
                "LLM/embedder performance.",
                i + 1,
                len(pending),
                timeout_per_save,
            )
            # Don't raise - just log and continue to avoid blocking crew completion
        except Exception as e:
            failed_saves += 1
            _logger.error(
                "[DRAIN_WRITES] Save %d/%d failed: %s", i + 1, len(pending), e
            )
            # Don't raise - just log and continue
    if failed_saves > 0:
        _logger.warning(
            "[DRAIN_WRITES] %d/%d saves failed or timed out. "
            "Some memories may not have been persisted.",
            failed_saves,
            len(pending),
        )
def close(self) -> None:
"""Drain pending saves, flush storage, and shut down the background thread pool."""

View File

@@ -0,0 +1,78 @@
"""Shared cache configuration helpers for Valkey/Redis URL parsing."""
from __future__ import annotations
import logging
import os
from typing import Any
from urllib.parse import urlparse
_logger = logging.getLogger(__name__)
def parse_cache_url() -> dict[str, Any] | None:
"""Parse VALKEY_URL or REDIS_URL from environment.
Priority: VALKEY_URL > REDIS_URL.
Returns:
Dict with host, port, db, password keys, or None if no URL is set.
"""
url = os.environ.get("VALKEY_URL") or os.environ.get("REDIS_URL")
if not url:
return None
parsed = urlparse(url)
return {
"host": parsed.hostname or "localhost",
"port": parsed.port or 6379,
"db": _parse_db_from_path(parsed.path),
"password": parsed.password,
"use_tls": parsed.scheme in ("rediss", "valkeys"),
}
def _parse_db_from_path(path: str | None) -> int:
"""Parse database number from URL path, defaulting to 0."""
if not path or path == "/":
return 0
try:
return int(path.lstrip("/"))
except ValueError:
_logger.warning(
"Invalid database number in URL path: %s, using default 0", path
)
return 0
def get_aiocache_config() -> dict[str, Any]:
    """Build an aiocache configuration dict from environment.

    Uses VALKEY_URL or REDIS_URL (both are Redis-wire-compatible) to
    configure ``aiocache.RedisCache``. Falls back to
    ``aiocache.SimpleMemoryCache`` when neither variable is set.

    Returns:
        Configuration dict suitable for ``aiocache.caches.set_config()``.
    """
    conn = parse_cache_url()
    if conn is None:
        # No cache URL configured: use an in-process memory cache.
        return {
            "default": {
                "cache": "aiocache.SimpleMemoryCache",
            }
        }
    backend = {
        "cache": "aiocache.RedisCache",
        "endpoint": conn["host"],
        "port": conn["port"],
        "db": conn.get("db", 0),
        "password": conn.get("password"),
    }
    return {"default": backend}
def use_valkey_cache() -> bool:
    """Return True if VALKEY_URL is set in the environment."""
    value = os.environ.get("VALKEY_URL")
    # An unset or empty variable both mean "not configured".
    return value is not None and value != ""

View File

@@ -0,0 +1,511 @@
"""Tests for ValkeyCache implementation."""
from __future__ import annotations
import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from crewai.memory.storage.valkey_cache import ValkeyCache
@pytest.fixture
def mock_glide_client() -> AsyncMock:
    """Create a mock GlideClient for testing."""
    client = AsyncMock()
    # Explicitly install AsyncMock attributes for every client method the
    # cache exercises, so call assertions start from a clean state.
    for method_name in ("get", "set", "delete", "exists", "close"):
        setattr(client, method_name, AsyncMock())
    return client
@pytest.fixture
def valkey_cache(mock_glide_client: AsyncMock) -> ValkeyCache:
    """Create a ValkeyCache instance with mocked client."""
    cache = ValkeyCache(host="localhost", port=6379, db=0)

    async def _fake_get_client() -> AsyncMock:
        # Mirror the lazy-connect contract of the real _get_client:
        # install the shared mock as the client and return it.
        cache._client = mock_glide_client
        return mock_glide_client

    cache._get_client = _fake_get_client  # type: ignore[method-assign]
    return cache
class TestValkeyCacheBasicOperations:
"""Tests for basic ValkeyCache operations (get/set/delete/exists)."""
@pytest.mark.asyncio
async def test_set_and_get_string_value(
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
) -> None:
"""Test setting and getting a string value."""
# Mock get to return serialized string
mock_glide_client.get.return_value = json.dumps("test_value")
# Set value
await valkey_cache.set("test_key", "test_value")
# Verify set was called
mock_glide_client.set.assert_called_once()
call_args = mock_glide_client.set.call_args
assert call_args[0][0] == "test_key"
assert call_args[0][1] == json.dumps("test_value")
# Get value
result = await valkey_cache.get("test_key")
# Verify get was called and result is correct
mock_glide_client.get.assert_called_once_with("test_key")
assert result == "test_value"
@pytest.mark.asyncio
async def test_set_and_get_dict_value(
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
) -> None:
"""Test setting and getting a dictionary value."""
test_dict = {"key1": "value1", "key2": 42, "key3": [1, 2, 3]}
mock_glide_client.get.return_value = json.dumps(test_dict)
# Set value
await valkey_cache.set("dict_key", test_dict)
# Verify set was called with serialized dict
mock_glide_client.set.assert_called_once()
call_args = mock_glide_client.set.call_args
assert call_args[0][0] == "dict_key"
assert call_args[0][1] == json.dumps(test_dict)
# Get value
result = await valkey_cache.get("dict_key")
# Verify result matches original dict
assert result == test_dict
@pytest.mark.asyncio
async def test_set_and_get_list_value(
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
) -> None:
"""Test setting and getting a list value."""
test_list = [1, "two", 3.0, {"nested": "dict"}]
mock_glide_client.get.return_value = json.dumps(test_list)
await valkey_cache.set("list_key", test_list)
result = await valkey_cache.get("list_key")
assert result == test_list
@pytest.mark.asyncio
async def test_get_nonexistent_key_returns_none(
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
) -> None:
"""Test getting a non-existent key returns None."""
mock_glide_client.get.return_value = None
result = await valkey_cache.get("nonexistent_key")
assert result is None
mock_glide_client.get.assert_called_once_with("nonexistent_key")
@pytest.mark.asyncio
async def test_delete_key(
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
) -> None:
"""Test deleting a key."""
await valkey_cache.delete("test_key")
mock_glide_client.delete.assert_called_once_with(["test_key"])
@pytest.mark.asyncio
async def test_exists_returns_true_for_existing_key(
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
) -> None:
"""Test exists returns True for existing key."""
mock_glide_client.exists.return_value = 1
result = await valkey_cache.exists("existing_key")
assert result is True
mock_glide_client.exists.assert_called_once_with(["existing_key"])
@pytest.mark.asyncio
async def test_exists_returns_false_for_nonexistent_key(
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
) -> None:
"""Test exists returns False for non-existent key."""
mock_glide_client.exists.return_value = 0
result = await valkey_cache.exists("nonexistent_key")
assert result is False
mock_glide_client.exists.assert_called_once_with(["nonexistent_key"])
class TestValkeyCacheTTL:
"""Tests for ValkeyCache TTL functionality."""
@pytest.mark.asyncio
async def test_set_with_explicit_ttl(
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
) -> None:
"""Test setting a value with explicit TTL."""
await valkey_cache.set("ttl_key", "value", ttl=3600)
# Verify set was called with expiry
mock_glide_client.set.assert_called_once()
call_args = mock_glide_client.set.call_args
assert call_args[0][0] == "ttl_key"
assert call_args[0][1] == json.dumps("value")
assert "expiry" in call_args[1]
@pytest.mark.asyncio
async def test_set_with_default_ttl(
self, mock_glide_client: AsyncMock
) -> None:
"""Test setting a value with default TTL from constructor."""
cache = ValkeyCache(host="localhost", port=6379, default_ttl=1800)
async def mock_create_client() -> AsyncMock:
cache._client = mock_glide_client
return mock_glide_client
cache._get_client = mock_create_client # type: ignore[method-assign]
await cache.set("default_ttl_key", "value")
# Verify set was called with default TTL
mock_glide_client.set.assert_called_once()
call_args = mock_glide_client.set.call_args
assert "expiry" in call_args[1]
@pytest.mark.asyncio
async def test_set_without_ttl(
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
) -> None:
"""Test setting a value without TTL (no expiration)."""
await valkey_cache.set("no_ttl_key", "value")
# Verify set was called without expiry
mock_glide_client.set.assert_called_once()
call_args = mock_glide_client.set.call_args
assert call_args[0][0] == "no_ttl_key"
assert call_args[0][1] == json.dumps("value")
# Should not have expiry parameter
assert "expiry" not in call_args[1] or call_args[1].get("expiry") is None
@pytest.mark.asyncio
async def test_set_with_zero_ttl_no_expiration(
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
) -> None:
"""Test setting a value with TTL=0 means no expiration."""
await valkey_cache.set("zero_ttl_key", "value", ttl=0)
# Verify set was called without expiry
mock_glide_client.set.assert_called_once()
call_args = mock_glide_client.set.call_args
assert "expiry" not in call_args[1] or call_args[1].get("expiry") is None
@pytest.mark.asyncio
async def test_explicit_ttl_overrides_default(
self, mock_glide_client: AsyncMock
) -> None:
"""Test explicit TTL overrides default TTL."""
cache = ValkeyCache(host="localhost", port=6379, default_ttl=1800)
async def mock_create_client() -> AsyncMock:
cache._client = mock_glide_client
return mock_glide_client
cache._get_client = mock_create_client # type: ignore[method-assign]
await cache.set("override_key", "value", ttl=7200)
# Verify set was called with explicit TTL (7200), not default (1800)
mock_glide_client.set.assert_called_once()
call_args = mock_glide_client.set.call_args
assert "expiry" in call_args[1]
class TestValkeyCacheJSONSerialization:
"""Tests for ValkeyCache JSON serialization edge cases."""
@pytest.mark.asyncio
async def test_serialize_none_value(
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
) -> None:
"""Test serializing None value."""
mock_glide_client.get.return_value = json.dumps(None)
await valkey_cache.set("none_key", None)
result = await valkey_cache.get("none_key")
assert result is None
@pytest.mark.asyncio
async def test_serialize_boolean_values(
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
) -> None:
"""Test serializing boolean values."""
mock_glide_client.get.side_effect = [
json.dumps(True),
json.dumps(False),
]
await valkey_cache.set("true_key", True)
await valkey_cache.set("false_key", False)
result_true = await valkey_cache.get("true_key")
result_false = await valkey_cache.get("false_key")
assert result_true is True
assert result_false is False
@pytest.mark.asyncio
async def test_serialize_numeric_values(
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
) -> None:
"""Test serializing numeric values (int, float)."""
mock_glide_client.get.side_effect = [
json.dumps(42),
json.dumps(3.14159),
json.dumps(0),
json.dumps(-100),
]
await valkey_cache.set("int_key", 42)
await valkey_cache.set("float_key", 3.14159)
await valkey_cache.set("zero_key", 0)
await valkey_cache.set("negative_key", -100)
assert await valkey_cache.get("int_key") == 42
assert await valkey_cache.get("float_key") == 3.14159
assert await valkey_cache.get("zero_key") == 0
assert await valkey_cache.get("negative_key") == -100
@pytest.mark.asyncio
async def test_serialize_empty_collections(
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
) -> None:
"""Test serializing empty collections."""
mock_glide_client.get.side_effect = [
json.dumps([]),
json.dumps({}),
json.dumps(""),
]
await valkey_cache.set("empty_list", [])
await valkey_cache.set("empty_dict", {})
await valkey_cache.set("empty_string", "")
assert await valkey_cache.get("empty_list") == []
assert await valkey_cache.get("empty_dict") == {}
assert await valkey_cache.get("empty_string") == ""
@pytest.mark.asyncio
async def test_serialize_nested_structures(
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
) -> None:
"""Test serializing deeply nested structures."""
nested_data = {
"level1": {
"level2": {
"level3": [1, 2, {"level4": "deep"}]
}
},
"list": [{"a": 1}, {"b": 2}]
}
mock_glide_client.get.return_value = json.dumps(nested_data)
await valkey_cache.set("nested_key", nested_data)
result = await valkey_cache.get("nested_key")
assert result == nested_data
@pytest.mark.asyncio
async def test_deserialize_invalid_json_returns_none(
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
) -> None:
"""Test deserializing invalid JSON returns None and logs warning."""
mock_glide_client.get.return_value = "invalid json {{"
with patch("crewai.memory.storage.valkey_cache._logger") as mock_logger:
result = await valkey_cache.get("invalid_key")
assert result is None
mock_logger.warning.assert_called_once()
@pytest.mark.asyncio
async def test_serialize_unicode_strings(
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
) -> None:
"""Test serializing unicode strings."""
unicode_data = "Hello 世界 🌍 Привет"
mock_glide_client.get.return_value = json.dumps(unicode_data)
await valkey_cache.set("unicode_key", unicode_data)
result = await valkey_cache.get("unicode_key")
assert result == unicode_data
class TestValkeyCacheConnectionManagement:
"""Tests for ValkeyCache connection management."""
@pytest.mark.asyncio
async def test_lazy_client_initialization(self) -> None:
"""Test client is initialized lazily on first use."""
cache = ValkeyCache(host="localhost", port=6379)
# Client should be None initially
assert cache._client is None
# Mock GlideClient.create
with patch("crewai.memory.storage.valkey_cache.GlideClient") as mock_glide:
mock_client = AsyncMock()
mock_glide.create = AsyncMock(return_value=mock_client)
mock_client.get = AsyncMock(return_value=None)
# First operation should initialize client
await cache.get("test_key")
# Client should now be initialized
assert cache._client is not None
mock_glide.create.assert_called_once()
@pytest.mark.asyncio
async def test_client_reuse_across_operations(
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
) -> None:
"""Test client is reused across multiple operations."""
mock_glide_client.get.return_value = json.dumps("value")
mock_glide_client.exists.return_value = 1
# Perform multiple operations
await valkey_cache.get("key1")
await valkey_cache.set("key2", "value2")
await valkey_cache.exists("key3")
await valkey_cache.delete("key4")
# _get_client should return the same client instance
client1 = await valkey_cache._get_client()
client2 = await valkey_cache._get_client()
assert client1 is client2
@pytest.mark.asyncio
async def test_close_connection(
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
) -> None:
"""Test closing the client connection."""
# Initialize client
await valkey_cache._get_client()
assert valkey_cache._client is not None
# Close connection
await valkey_cache.close()
# Verify close was called and client is None
mock_glide_client.close.assert_called_once()
assert valkey_cache._client is None
@pytest.mark.asyncio
async def test_connection_error_raises_runtime_error(self) -> None:
"""Test connection error raises RuntimeError with descriptive message."""
cache = ValkeyCache(host="invalid-host", port=9999)
with patch("crewai.memory.storage.valkey_cache.GlideClient") as mock_glide:
mock_glide.create = AsyncMock(side_effect=Exception("Connection refused"))
with pytest.raises(RuntimeError) as exc_info:
await cache._get_client()
assert "Cannot connect to Valkey" in str(exc_info.value)
@pytest.mark.asyncio
async def test_authentication_with_password(self) -> None:
"""Test client initialization with password authentication."""
cache = ValkeyCache(
host="localhost",
port=6379,
password="secret_password"
)
with patch("crewai.memory.storage.valkey_cache.GlideClient") as mock_glide:
mock_client = AsyncMock()
mock_glide.create = AsyncMock(return_value=mock_client)
await cache._get_client()
# Verify GlideClient.create was called with credentials
mock_glide.create.assert_called_once()
config = mock_glide.create.call_args[0][0]
assert hasattr(config, "credentials")
class TestValkeyCacheEdgeCases:
"""Tests for ValkeyCache edge cases and error conditions."""
@pytest.mark.asyncio
async def test_set_with_special_characters_in_key(
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
) -> None:
"""Test setting values with special characters in key."""
special_keys = [
"key:with:colons",
"key/with/slashes",
"key-with-dashes",
"key_with_underscores",
"key.with.dots",
]
for key in special_keys:
await valkey_cache.set(key, "value")
mock_glide_client.set.assert_called()
@pytest.mark.asyncio
async def test_large_value_serialization(
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
) -> None:
"""Test serializing large values."""
large_list = list(range(10000))
mock_glide_client.get.return_value = json.dumps(large_list)
await valkey_cache.set("large_key", large_list)
result = await valkey_cache.get("large_key")
assert result == large_list
@pytest.mark.asyncio
async def test_concurrent_operations(
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
) -> None:
"""Test concurrent cache operations."""
import asyncio
mock_glide_client.get.return_value = json.dumps("value")
# Perform concurrent operations
tasks = [
valkey_cache.set(f"key{i}", f"value{i}")
for i in range(10)
]
await asyncio.gather(*tasks)
# Verify all operations completed
assert mock_glide_client.set.call_count == 10
@pytest.mark.asyncio
async def test_set_non_serializable_value_raises_type_error(
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
) -> None:
"""Test that non-JSON-serializable values raise TypeError."""
from datetime import datetime
with pytest.raises(TypeError, match="not JSON-serializable"):
await valkey_cache.set("bad_key", datetime.now())
# Verify set was never called on the client
mock_glide_client.set.assert_not_called()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,267 @@
"""Tests for ValkeyStorage error handling."""
from __future__ import annotations
import asyncio
import json
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4
import pytest
from crewai.memory.storage.valkey_storage import ValkeyStorage
from crewai.memory.types import MemoryRecord
@pytest.fixture
def mock_glide_client() -> AsyncMock:
"""Create a mock GlideClient for testing."""
client = AsyncMock()
client.hset = AsyncMock(return_value=1)
client.zrange = AsyncMock(return_value=[])
client.zadd = AsyncMock()
client.sadd = AsyncMock()
client.hgetall = AsyncMock(return_value={})
client.close = AsyncMock()
return client
@pytest.fixture
def valkey_storage(mock_glide_client: AsyncMock) -> ValkeyStorage:
"""Create a ValkeyStorage instance with mocked client."""
storage = ValkeyStorage(host="localhost", port=6379, db=0)
# Mock the client creation to return our mock
async def mock_create_client() -> AsyncMock:
storage._client = mock_glide_client
return mock_glide_client
storage._get_client = mock_create_client # type: ignore[method-assign]
return storage
class TestSerializationErrors:
"""Tests for serialization error handling."""
def test_serialization_error_raises_descriptive_exception(
self, valkey_storage: ValkeyStorage
) -> None:
"""Test that serialization errors raise descriptive ValueError."""
# Create a record with non-serializable metadata
record = MemoryRecord(
id="test-id",
content="test content",
scope="/test",
categories=["test"],
metadata={"bad_key": object()}, # Non-serializable object
importance=0.5,
created_at=datetime.now(),
last_accessed=datetime.now(),
embedding=[0.1, 0.2, 0.3],
)
# Should raise ValueError with descriptive message
with pytest.raises(ValueError, match="Failed to serialize record test-id"):
valkey_storage._record_to_dict(record)
def test_serialization_error_includes_cause(
self, valkey_storage: ValkeyStorage
) -> None:
"""Test that serialization error includes the original exception as cause."""
# Create a mock record that will fail during JSON serialization
# We need to bypass Pydantic validation, so we'll patch json.dumps
record = MemoryRecord(
id="test-id-2",
content="test content",
scope="/test",
categories=["valid"],
metadata={"key": "value"},
importance=0.5,
created_at=datetime.now(),
last_accessed=datetime.now(),
embedding=[0.1, 0.2, 0.3],
)
# Patch json.dumps to raise an error
with patch("json.dumps", side_effect=TypeError("Cannot serialize")):
with pytest.raises(ValueError) as exc_info:
valkey_storage._record_to_dict(record)
# Verify the exception has a cause
assert exc_info.value.__cause__ is not None
assert isinstance(exc_info.value.__cause__, TypeError)
class TestDeserializationErrors:
"""Tests for deserialization error handling."""
def test_deserialization_error_logs_and_returns_none(
self, valkey_storage: ValkeyStorage, caplog: pytest.LogCaptureFixture
) -> None:
"""Test that deserialization errors log error and return None."""
# Create malformed data (missing required fields)
malformed_data = {
"id": "test-id",
"content": "test content",
# Missing scope, categories, metadata, etc.
}
# Should return None and log error
result = valkey_storage._dict_to_record(malformed_data)
assert result is None
assert "Failed to deserialize record test-id" in caplog.text
def test_deserialization_with_invalid_json_categories_uses_tag_fallback(
self, valkey_storage: ValkeyStorage
) -> None:
"""Test that non-JSON categories fall back to TAG (comma-separated) parsing."""
# Create data with non-JSON categories string
data = {
"id": "test-id-json",
"content": "test content",
"scope": "/test",
"categories": "not valid json [", # Not JSON, treated as TAG format
"metadata": "{}",
"importance": "0.5",
"created_at": "2024-01-01T12:00:00",
"last_accessed": "2024-01-01T12:00:00",
"source": "",
"private": "false",
}
result = valkey_storage._dict_to_record(data)
# TAG fallback: comma-split produces the raw string as a single category
assert result is not None
assert result.id == "test-id-json"
assert result.categories == ["not valid json ["]
def test_deserialization_with_invalid_datetime_returns_none(
self, valkey_storage: ValkeyStorage, caplog: pytest.LogCaptureFixture
) -> None:
"""Test that invalid datetime format returns None."""
# Create data with invalid datetime
invalid_data = {
"id": "test-id-datetime",
"content": "test content",
"scope": "/test",
"categories": '["test"]',
"metadata": "{}",
"importance": "0.5",
"created_at": "not a datetime", # Invalid datetime
"last_accessed": "2024-01-01T12:00:00",
"source": "",
"private": "false",
}
result = valkey_storage._dict_to_record(invalid_data)
assert result is None
assert "Failed to deserialize record test-id-datetime" in caplog.text
def test_deserialization_with_invalid_float_returns_none(
self, valkey_storage: ValkeyStorage, caplog: pytest.LogCaptureFixture
) -> None:
"""Test that invalid float importance returns None."""
# Create data with invalid float
invalid_data = {
"id": "test-id-float",
"content": "test content",
"scope": "/test",
"categories": '["test"]',
"metadata": "{}",
"importance": "not a float", # Invalid float
"created_at": "2024-01-01T12:00:00",
"last_accessed": "2024-01-01T12:00:00",
"source": "",
"private": "false",
}
result = valkey_storage._dict_to_record(invalid_data)
assert result is None
assert "Failed to deserialize record test-id-float" in caplog.text
def test_deserialization_with_bytes_keys_uses_tag_fallback(
self, valkey_storage: ValkeyStorage
) -> None:
"""Test that deserialization handles bytes keys with non-JSON categories via TAG fallback."""
# Create data with bytes keys (as returned by Valkey)
bytes_data = {
b"id": b"test-id-bytes",
b"content": b"test content",
b"scope": b"/test",
b"categories": b"invalid json [", # Not JSON, treated as TAG format
b"metadata": b"{}",
b"importance": b"0.5",
b"created_at": b"2024-01-01T12:00:00",
b"last_accessed": b"2024-01-01T12:00:00",
}
result = valkey_storage._dict_to_record(bytes_data)
# TAG fallback: comma-split produces the raw string as a single category
assert result is not None
assert result.id == "test-id-bytes"
assert result.categories == ["invalid json ["]
class TestRetryBehaviorIntegration:
"""Integration tests demonstrating retry behavior patterns."""
@pytest.mark.asyncio
async def test_mock_client_operation_with_retry_pattern(
self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
) -> None:
"""Test demonstrating how retry would work with client operations."""
from glide import ClosingError
# Mock a client operation that fails once
mock_glide_client.hgetall.side_effect = [
ClosingError("Connection lost"),
{
b"id": b"test-id",
b"content": b"test content",
b"scope": b"/test",
b"categories": b'["test"]',
b"metadata": b"{}",
b"importance": b"0.5",
b"created_at": b"2024-01-01T12:00:00",
b"last_accessed": b"2024-01-01T12:00:00",
b"source": b"",
b"private": b"false",
b"embedding": b"",
},
]
# First call fails, second succeeds
with pytest.raises(ClosingError):
await mock_glide_client.hgetall("record:test-id")
# Second call succeeds
result = await mock_glide_client.hgetall("record:test-id")
assert result is not None
@pytest.mark.asyncio
async def test_serialization_error_not_retried(
self, valkey_storage: ValkeyStorage
) -> None:
"""Test that serialization errors are not retried (they're not connection errors)."""
# Create a record with non-serializable data
record = MemoryRecord(
id="test-id",
content="test content",
scope="/test",
categories=["test"],
metadata={"bad": object()},
importance=0.5,
created_at=datetime.now(),
last_accessed=datetime.now(),
embedding=[0.1, 0.2, 0.3],
)
# Serialization error should not be retried
with pytest.raises(ValueError, match="Failed to serialize"):
valkey_storage._record_to_dict(record)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,998 @@
"""Tests for ValkeyStorage vector search operation."""
from __future__ import annotations
import json
from datetime import datetime
from unittest.mock import AsyncMock, patch
from uuid import uuid4
import pytest
from crewai.memory.storage.valkey_storage import ValkeyStorage
from crewai.memory.types import MemoryRecord
@pytest.fixture
def mock_glide_client() -> AsyncMock:
"""Create a mock GlideClient for testing."""
client = AsyncMock()
client.hset = AsyncMock(return_value=1)
client.zrange = AsyncMock(return_value=[])
client.zadd = AsyncMock()
client.sadd = AsyncMock()
client.hgetall = AsyncMock(return_value={})
client.close = AsyncMock()
return client
@pytest.fixture
def valkey_storage(mock_glide_client: AsyncMock) -> ValkeyStorage:
"""Create a ValkeyStorage instance with mocked client."""
storage = ValkeyStorage(host="localhost", port=6379, db=0)
# Mock the client creation to return our mock
async def mock_create_client() -> AsyncMock:
storage._client = mock_glide_client
return mock_glide_client
storage._get_client = mock_create_client # type: ignore[method-assign]
return storage
def create_mock_ft_search_response(
    records: list[tuple[MemoryRecord, float]]
) -> list[int | dict[str, dict[str, str]]]:
    """Create a mock FT.SEARCH response in native dict format.

    Args:
        records: List of (MemoryRecord, score) tuples to include in response.

    Returns:
        Mock FT.SEARCH response in the native format:
        [total_count, {doc_key: {field: value, ...}, ...}]
    """
    if not records:
        return [0]
    documents: dict[str, dict[str, str]] = {}
    for rec, similarity in records:
        fields: dict[str, str] = {
            "id": rec.id,
            "content": rec.content,
            "scope": rec.scope,
            "categories": json.dumps(rec.categories),
            "metadata": json.dumps(rec.metadata),
            "importance": str(rec.importance),
            "created_at": rec.created_at.isoformat(),
            "last_accessed": rec.last_accessed.isoformat(),
            "source": rec.source or "",
            "private": "true" if rec.private else "false",
            # Valkey Search reports cosine distance, not similarity:
            # distance = 2 * (1 - similarity)
            "score": str(2.0 * (1.0 - similarity)),
        }
        if rec.embedding:
            fields["embedding"] = json.dumps(rec.embedding)
        documents[f"record:{rec.id}"] = fields
    return [len(records), documents]
class TestValkeyStorageVectorSearch:
    """Tests for ValkeyStorage vector search operation.

    Conventions used throughout:
    - ``ft.list`` is patched to report the index as existing (or not), and
      ``ft.search`` is patched to return a canned native-format response.
    - In ``mock_ft_search.call_args``, positional arg 2 is the query string
      and positional arg 3 is the FtSearchOptions object.
    """

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_with_no_filters_returns_all_records(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search with no filters returns all records."""
        # Create test records
        record1 = MemoryRecord(
            id="record-1",
            content="First test record",
            scope="/test",
            categories=["cat1"],
            metadata={"key": "value1"},
            importance=0.8,
            created_at=datetime(2024, 1, 1, 10, 0, 0),
            last_accessed=datetime(2024, 1, 1, 11, 0, 0),
            embedding=[0.1, 0.2, 0.3, 0.4],
        )
        record2 = MemoryRecord(
            id="record-2",
            content="Second test record",
            scope="/test",
            categories=["cat2"],
            metadata={"key": "value2"},
            importance=0.6,
            created_at=datetime(2024, 1, 2, 10, 0, 0),
            last_accessed=datetime(2024, 1, 2, 11, 0, 0),
            embedding=[0.2, 0.3, 0.4, 0.5],
        )
        # Mock FT.INFO to simulate index exists
        mock_ft_list.return_value = [b"memory_index"]
        # Mock FT.SEARCH to return both records
        mock_ft_search.return_value = create_mock_ft_search_response([
            (record1, 0.95),
            (record2, 0.85),
        ])
        # Perform search with no filters
        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(query_embedding, limit=10)
        # Verify ft.search was called
        mock_ft_search.assert_called_once()
        # Verify query contains only KNN part (no filters)
        call_args = mock_ft_search.call_args
        query = call_args[0][2]  # 3rd positional arg: query string
        assert "*=>[KNN 10 @embedding $BLOB AS score]" in query
        assert "@scope" not in query
        assert "@categories" not in query
        # Verify results
        assert len(results) == 2
        assert results[0][0].id == "record-1"
        assert results[0][1] == 0.95
        assert results[1][0].id == "record-2"
        assert results[1][1] == 0.85

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_with_scope_filter_only(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search with scope filter only."""
        record1 = MemoryRecord(
            id="record-1",
            content="Record in scope",
            scope="/agent/task",
            embedding=[0.1, 0.2, 0.3, 0.4],
        )
        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.9)])
        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(
            query_embedding,
            scope_prefix="/agent",
            limit=10
        )
        # Verify query contains scope filter
        call_args = mock_ft_search.call_args
        query = call_args[0][2]
        assert "(@scope:{/agent*})=>[KNN 10 @embedding $BLOB AS score]" in query
        # Verify results
        assert len(results) == 1
        assert results[0][0].id == "record-1"
        assert results[0][0].scope == "/agent/task"

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_with_category_filter_only(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search with category filter only."""
        record1 = MemoryRecord(
            id="record-1",
            content="Record with planning category",
            scope="/test",
            categories=["planning"],
            embedding=[0.1, 0.2, 0.3, 0.4],
        )
        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.88)])
        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(
            query_embedding,
            categories=["planning", "execution"],
            limit=10
        )
        # Verify query contains category filter with OR logic
        call_args = mock_ft_search.call_args
        query = call_args[0][2]
        assert "(@categories:{planning|execution})=>[KNN 10 @embedding $BLOB AS score]" in query
        # Verify results
        assert len(results) == 1
        assert results[0][0].id == "record-1"
        assert "planning" in results[0][0].categories

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_with_metadata_filter_only(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search with metadata filter only."""
        record1 = MemoryRecord(
            id="record-1",
            content="Record with metadata",
            scope="/test",
            metadata={"agent_id": "agent-1", "priority": "high"},
            embedding=[0.1, 0.2, 0.3, 0.4],
        )
        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.92)])
        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(
            query_embedding,
            metadata_filter={"agent_id": "agent-1", "priority": "high"},
            limit=10
        )
        # Verify query contains metadata filters (AND logic)
        call_args = mock_ft_search.call_args
        query = call_args[0][2]
        assert "@agent_id:{agent\\-1}" in query or "@agent_id:{agent-1}" in query
        assert "@priority:{high}" in query
        assert "=>[KNN 10 @embedding $BLOB AS score]" in query
        # Verify results
        assert len(results) == 1
        assert results[0][0].id == "record-1"
        assert results[0][0].metadata["agent_id"] == "agent-1"
        assert results[0][0].metadata["priority"] == "high"

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_with_combined_filters(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search with combined filters (scope + categories + metadata)."""
        record1 = MemoryRecord(
            id="record-1",
            content="Record matching all filters",
            scope="/agent/task",
            categories=["planning"],
            metadata={"agent_id": "agent-1"},
            embedding=[0.1, 0.2, 0.3, 0.4],
        )
        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.93)])
        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(
            query_embedding,
            scope_prefix="/agent",
            categories=["planning"],
            metadata_filter={"agent_id": "agent-1"},
            limit=10
        )
        # Verify query contains all filters
        call_args = mock_ft_search.call_args
        query = call_args[0][2]
        assert "@scope:{/agent*}" in query
        assert "@categories:{planning}" in query
        assert "@agent_id:{agent\\-1}" in query or "@agent_id:{agent-1}" in query
        assert "=>[KNN 10 @embedding $BLOB AS score]" in query
        # Verify results
        assert len(results) == 1
        assert results[0][0].id == "record-1"

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_respects_limit_parameter(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search respects limit parameter."""
        records = [
            (
                MemoryRecord(
                    id=f"record-{i}",
                    content=f"Record {i}",
                    scope="/test",
                    embedding=[0.1 * i, 0.2 * i, 0.3 * i, 0.4 * i],
                ),
                0.9 - (i * 0.1)
            )
            for i in range(1, 6)
        ]
        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = create_mock_ft_search_response(records[:3])
        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(query_embedding, limit=3)
        # Verify KNN limit in query
        call_args = mock_ft_search.call_args
        query = call_args[0][2]
        assert "=>[KNN 3 @embedding $BLOB AS score]" in query
        # Verify results respect limit
        assert len(results) == 3

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_respects_min_score_parameter(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search respects min_score parameter."""
        record1 = MemoryRecord(
            id="record-1",
            content="High score record",
            scope="/test",
            embedding=[0.1, 0.2, 0.3, 0.4],
        )
        record2 = MemoryRecord(
            id="record-2",
            content="Medium score record",
            scope="/test",
            embedding=[0.2, 0.3, 0.4, 0.5],
        )
        record3 = MemoryRecord(
            id="record-3",
            content="Low score record",
            scope="/test",
            embedding=[0.3, 0.4, 0.5, 0.6],
        )
        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = create_mock_ft_search_response([
            (record1, 0.95),
            (record2, 0.75),
            (record3, 0.55),
        ])
        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(
            query_embedding,
            limit=10,
            min_score=0.7
        )
        # Verify only records with score >= 0.7 are returned
        assert len(results) == 2
        assert results[0][0].id == "record-1"
        assert results[0][1] == 0.95
        assert results[1][0].id == "record-2"
        assert results[1][1] == 0.75

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_returns_results_ordered_by_descending_score(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search returns results ordered by descending score."""
        record1 = MemoryRecord(
            id="record-1",
            content="Medium score",
            scope="/test",
            embedding=[0.1, 0.2, 0.3, 0.4],
        )
        record2 = MemoryRecord(
            id="record-2",
            content="Highest score",
            scope="/test",
            embedding=[0.2, 0.3, 0.4, 0.5],
        )
        record3 = MemoryRecord(
            id="record-3",
            content="Lowest score",
            scope="/test",
            embedding=[0.3, 0.4, 0.5, 0.6],
        )
        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = create_mock_ft_search_response([
            (record1, 0.75),
            (record2, 0.95),
            (record3, 0.55),
        ])
        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(query_embedding, limit=10)
        # Verify results are ordered by descending score
        assert len(results) == 3
        assert results[0][0].id == "record-2"
        assert results[0][1] == 0.95
        assert results[1][0].id == "record-1"
        assert results[1][1] == 0.75
        assert results[2][0].id == "record-3"
        assert results[2][1] == 0.55
        # Verify scores are in descending order
        for i in range(len(results) - 1):
            assert results[i][1] >= results[i + 1][1]

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_with_empty_results(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search with no matching results."""
        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = [0]  # Total count = 0
        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(query_embedding, limit=10)
        # Verify empty results
        assert len(results) == 0

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_with_special_characters_in_scope(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search with special characters in scope prefix."""
        record1 = MemoryRecord(
            id="record-1",
            content="Record with special scope",
            scope="/agent:task-1",
            embedding=[0.1, 0.2, 0.3, 0.4],
        )
        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.9)])
        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(
            query_embedding,
            scope_prefix="/agent:task",
            limit=10
        )
        # Verify query contains escaped scope
        call_args = mock_ft_search.call_args
        query = call_args[0][2]
        assert "@scope:{/agent\\:task*}" in query or "@scope:{/agent:task*}" in query
        # Verify the record survives the round-trip (was previously unasserted)
        assert len(results) == 1
        assert results[0][0].id == "record-1"

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_with_special_characters_in_categories(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search with special characters in categories."""
        record1 = MemoryRecord(
            id="record-1",
            content="Record with special category",
            scope="/test",
            categories=["plan:execute"],
            embedding=[0.1, 0.2, 0.3, 0.4],
        )
        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.9)])
        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(
            query_embedding,
            categories=["plan:execute"],
            limit=10
        )
        # Verify query contains escaped category
        call_args = mock_ft_search.call_args
        query = call_args[0][2]
        assert "@categories:{plan\\:execute}" in query or "@categories:{plan:execute}" in query
        # Verify the record survives the round-trip (was previously unasserted)
        assert len(results) == 1
        assert results[0][0].id == "record-1"

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_with_numeric_metadata_values(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search with numeric metadata values."""
        record1 = MemoryRecord(
            id="record-1",
            content="Record with numeric metadata",
            scope="/test",
            metadata={"count": 42, "score": 3.14},
            embedding=[0.1, 0.2, 0.3, 0.4],
        )
        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.9)])
        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(
            query_embedding,
            metadata_filter={"count": 42, "score": 3.14},
            limit=10
        )
        # Verify query contains string-converted metadata values
        call_args = mock_ft_search.call_args
        query = call_args[0][2]
        assert "@count:{42}" in query
        assert "@score:{3" in query and "14}" in query
        # Verify the record survives the round-trip (was previously unasserted)
        assert len(results) == 1
        assert results[0][0].id == "record-1"

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_with_embedding_blob_parameter(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search passes embedding as BLOB parameter."""
        record1 = MemoryRecord(
            id="record-1",
            content="Test record",
            scope="/test",
            embedding=[0.1, 0.2, 0.3, 0.4],
        )
        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.9)])
        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(query_embedding, limit=10)
        # Verify ft.search was called with search options containing BLOB param
        call_args = mock_ft_search.call_args
        # The 4th positional arg is the FtSearchOptions
        search_options = call_args[0][3]
        # The options object should have params with BLOB
        assert search_options is not None
        # Verify the search still yields the record (was previously unasserted)
        assert len(results) == 1

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_results_sorted_by_score(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search results are sorted by score (descending) automatically."""
        record1 = MemoryRecord(
            id="record-1",
            content="Test record",
            scope="/test",
            embedding=[0.1, 0.2, 0.3, 0.4],
        )
        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.9)])
        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(query_embedding, limit=10)
        # Verify ft.search was called (results are auto-sorted by vector search)
        mock_ft_search.assert_called_once()
        # Verify the record and its score come back (was previously unasserted)
        assert len(results) == 1
        assert results[0][0].id == "record-1"
        assert results[0][1] == 0.9

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_with_return_fields(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search includes RETURN clause with all record fields."""
        record1 = MemoryRecord(
            id="record-1",
            content="Test record",
            scope="/test",
            embedding=[0.1, 0.2, 0.3, 0.4],
        )
        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.9)])
        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(query_embedding, limit=10)
        # Verify ft.search was called with search options containing return fields
        call_args = mock_ft_search.call_args
        search_options = call_args[0][3]
        assert search_options is not None
        # The FtSearchOptions should have return_fields set
        assert search_options.return_fields is not None
        assert len(search_options.return_fields) == 11  # All fields including score
        # Verify the search still yields the record (was previously unasserted)
        assert len(results) == 1

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.VectorFieldAttributesHnsw")
    @patch("crewai.memory.storage.valkey_storage.ft.create")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_handles_valkey_search_not_available(
        self, mock_ft_list: AsyncMock, mock_ft_create: AsyncMock,
        mock_vector_attrs: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search raises error when Valkey Search module is not available."""
        # Mock FT.INFO to fail (index doesn't exist)
        mock_ft_list.return_value = []
        # Mock FT.CREATE to fail (Search module not available)
        mock_ft_create.side_effect = Exception("ERR unknown command 'ft.create'")
        query_embedding = [0.1, 0.2, 0.3, 0.4]
        with pytest.raises(RuntimeError, match="Valkey Search module is not available"):
            await valkey_storage.asearch(query_embedding, limit=10)

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_handles_ft_search_error(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search handles FT.SEARCH errors gracefully."""
        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.side_effect = Exception("ERR unknown command 'FT.SEARCH'")
        query_embedding = [0.1, 0.2, 0.3, 0.4]
        with pytest.raises(RuntimeError, match="Valkey Search module is not available"):
            await valkey_storage.asearch(query_embedding, limit=10)

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_handles_malformed_ft_search_response(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search handles malformed FT.SEARCH response gracefully."""
        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = None  # Malformed response
        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(query_embedding, limit=10)
        # Verify empty results are returned (graceful handling)
        assert len(results) == 0

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_handles_missing_score_field(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search handles missing score field in results."""
        record1 = MemoryRecord(
            id="record-1",
            content="Test record",
            scope="/test",
            embedding=[0.1, 0.2, 0.3, 0.4],
        )
        # Create mock response without score field (dict format)
        docs = {
            f"record:{record1.id}": {
                "id": record1.id,
                "content": record1.content,
                "scope": record1.scope,
                "categories": str(record1.categories),
                "metadata": str(record1.metadata),
                "importance": str(record1.importance),
                "created_at": record1.created_at.isoformat(),
                "last_accessed": record1.last_accessed.isoformat(),
                "source": record1.source or "",
                "private": "false",
                # No score field
            }
        }
        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = [1, docs]
        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(query_embedding, limit=10)
        # Verify record is returned with default score of 0.0
        assert len(results) == 1
        assert results[0][0].id == "record-1"
        assert results[0][1] == 0.0

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_filters_out_records_with_deserialization_errors(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search filters out records that fail deserialization."""
        valid_record = MemoryRecord(
            id="valid-record",
            content="Valid record",
            scope="/test",
            embedding=[0.1, 0.2, 0.3, 0.4],
        )
        # Create mock response with one valid and one invalid record (dict format)
        docs = {
            f"record:{valid_record.id}": {
                "id": valid_record.id,
                "content": valid_record.content,
                "scope": valid_record.scope,
                "categories": str(valid_record.categories),
                "metadata": str(valid_record.metadata),
                "importance": str(valid_record.importance),
                "created_at": valid_record.created_at.isoformat(),
                "last_accessed": valid_record.last_accessed.isoformat(),
                "source": valid_record.source or "",
                "private": "false",
                "score": "0.1",
            },
            "record:invalid-record": {
                "id": "invalid-record",
                # Missing content, scope, and other required fields
                "score": "0.2",
            },
        }
        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = [2, docs]
        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(query_embedding, limit=10)
        # Verify only valid record is returned
        assert len(results) == 1
        assert results[0][0].id == "valid-record"

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_converts_cosine_distance_to_similarity(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search converts Valkey Search cosine distance to similarity score."""
        record1 = MemoryRecord(
            id="record-1",
            content="Test record",
            scope="/test",
            embedding=[0.1, 0.2, 0.3, 0.4],
        )
        # Create mock response with distance score (dict format)
        docs = {
            f"record:{record1.id}": {
                "id": record1.id,
                "content": record1.content,
                "scope": record1.scope,
                "categories": str(record1.categories),
                "metadata": str(record1.metadata),
                "importance": str(record1.importance),
                "created_at": record1.created_at.isoformat(),
                "last_accessed": record1.last_accessed.isoformat(),
                "source": record1.source or "",
                "private": "false",
                "score": "0.1",  # Distance = 0.1
            }
        }
        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = [1, docs]
        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(query_embedding, limit=10)
        # Verify similarity score is correctly converted
        assert len(results) == 1
        assert results[0][0].id == "record-1"
        # Distance 0.1 -> Similarity = 1 - (0.1 / 2) = 0.95
        assert abs(results[0][1] - 0.95) < 0.01

    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    def test_search_sync_wrapper(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test that sync search wrapper calls async implementation."""
        record1 = MemoryRecord(
            id="record-1",
            content="Test record",
            scope="/test",
            embedding=[0.1, 0.2, 0.3, 0.4],
        )
        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.9)])
        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = valkey_storage.search(query_embedding, limit=10)
        # Verify ft.search was called
        assert mock_ft_search.call_count >= 1
        # Verify results
        assert len(results) == 1
        assert results[0][0].id == "record-1"
        assert results[0][1] == 0.9

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_with_multiple_categories_uses_or_logic(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search with multiple categories uses OR logic."""
        record1 = MemoryRecord(
            id="record-1",
            content="Record with one matching category",
            scope="/test",
            categories=["planning"],
            embedding=[0.1, 0.2, 0.3, 0.4],
        )
        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.9)])
        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(
            query_embedding,
            categories=["planning", "execution", "review"],
            limit=10
        )
        # Verify query contains OR logic for categories
        call_args = mock_ft_search.call_args
        query = call_args[0][2]
        assert "@categories:{planning|execution|review}" in query
        # Verify record with only one matching category is returned
        assert len(results) == 1
        assert results[0][0].id == "record-1"

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_with_multiple_metadata_filters_uses_and_logic(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search with multiple metadata filters uses AND logic."""
        record1 = MemoryRecord(
            id="record-1",
            content="Record matching all metadata",
            scope="/test",
            metadata={"agent_id": "agent-1", "priority": "high", "status": "active"},
            embedding=[0.1, 0.2, 0.3, 0.4],
        )
        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.9)])
        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(
            query_embedding,
            metadata_filter={"agent_id": "agent-1", "priority": "high", "status": "active"},
            limit=10
        )
        # Verify query contains AND logic for metadata
        call_args = mock_ft_search.call_args
        query = call_args[0][2]
        assert "@agent_id:" in query
        assert "@priority:" in query
        assert "@status:" in query
        # Verify record matching all metadata is returned
        assert len(results) == 1
        assert results[0][0].id == "record-1"

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_with_zero_limit_returns_empty(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search with limit=0 returns empty results."""
        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = [0]
        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(query_embedding, limit=0)
        # Verify empty results
        assert len(results) == 0

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_with_min_score_one_filters_all(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search with min_score=1.0 filters out all non-perfect matches."""
        record1 = MemoryRecord(
            id="record-1",
            content="High score but not perfect",
            scope="/test",
            embedding=[0.1, 0.2, 0.3, 0.4],
        )
        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.99)])
        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(
            query_embedding,
            limit=10,
            min_score=1.0
        )
        # Verify all results are filtered out
        assert len(results) == 0

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_with_min_score_zero_returns_all(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search with min_score=0.0 returns all results."""
        record1 = MemoryRecord(
            id="record-1",
            content="High score",
            scope="/test",
            embedding=[0.1, 0.2, 0.3, 0.4],
        )
        record2 = MemoryRecord(
            id="record-2",
            content="Low score",
            scope="/test",
            embedding=[0.2, 0.3, 0.4, 0.5],
        )
        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = create_mock_ft_search_response([
            (record1, 0.95),
            (record2, 0.05),
        ])
        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(
            query_embedding,
            limit=10,
            min_score=0.0
        )
        # Verify all results are returned
        assert len(results) == 2
        assert results[0][0].id == "record-1"
        assert results[1][0].id == "record-2"

View File

@@ -0,0 +1,115 @@
"""Tests for embedding safety: bytes→float validators and async-safe embed_texts."""
from __future__ import annotations
import asyncio
import concurrent.futures
from unittest.mock import MagicMock
import numpy as np
import pytest
from crewai.memory.types import MemoryRecord, embed_text, embed_texts
class TestMemoryRecordEmbeddingValidator:
    """MemoryRecord's embedding validator should normalize any input to list[float] or None."""

    def test_none_embedding_stays_none(self) -> None:
        assert MemoryRecord(content="test", embedding=None).embedding is None

    def test_list_of_floats_passes_through(self) -> None:
        record = MemoryRecord(content="test", embedding=[0.1, 0.2, 0.3])
        assert record.embedding == [0.1, 0.2, 0.3]

    def test_bytes_converted_to_list_float(self) -> None:
        # Raw float32 bytes (as Valkey would return them) become a float list.
        raw_bytes = np.array([0.1, 0.2, 0.3], dtype=np.float32).tobytes()
        record = MemoryRecord(content="test", embedding=raw_bytes)
        assert record.embedding is not None
        assert len(record.embedding) == 3
        assert all(isinstance(value, float) for value in record.embedding)
        np.testing.assert_allclose(record.embedding, [0.1, 0.2, 0.3], atol=1e-6)

    def test_empty_bytes_becomes_none(self) -> None:
        assert MemoryRecord(content="test", embedding=b"").embedding is None

    def test_list_of_ints_converted_to_floats(self) -> None:
        record = MemoryRecord(content="test", embedding=[1, 2, 3])
        assert record.embedding == [1.0, 2.0, 3.0]
        assert all(isinstance(value, float) for value in record.embedding)

    def test_numpy_array_converted_to_list(self) -> None:
        source_array = np.array([0.5, 0.6], dtype=np.float32)
        record = MemoryRecord(content="test", embedding=source_array)
        assert record.embedding is not None
        assert isinstance(record.embedding, list)
        assert len(record.embedding) == 2
class TestEmbedTextsAsyncSafety:
    """embed_texts must behave the same whether called from sync or async code."""

    def test_embed_texts_sync_context(self) -> None:
        """embed_texts works in a normal sync context."""
        fake_embedder = MagicMock(return_value=[[0.1, 0.2], [0.3, 0.4]])
        vectors = embed_texts(fake_embedder, ["hello", "world"])
        assert len(vectors) == 2
        assert vectors[0] == [0.1, 0.2]
        fake_embedder.assert_called_once()

    def test_embed_texts_empty_input(self) -> None:
        fake_embedder = MagicMock()
        assert embed_texts(fake_embedder, []) == []
        fake_embedder.assert_not_called()

    def test_embed_texts_all_empty_strings(self) -> None:
        fake_embedder = MagicMock()
        assert embed_texts(fake_embedder, ["", " ", ""]) == [[], [], []]
        fake_embedder.assert_not_called()

    def test_embed_texts_skips_empty_preserves_positions(self) -> None:
        fake_embedder = MagicMock(return_value=[[0.1, 0.2]])
        # Blank entries are skipped but keep their slots in the output.
        assert embed_texts(fake_embedder, ["", "hello", ""]) == [[], [0.1, 0.2], []]
        fake_embedder.assert_called_once_with(["hello"])

    def test_embed_texts_in_async_context(self) -> None:
        """embed_texts uses thread pool when called from async context."""
        fake_embedder = MagicMock(return_value=[[0.1, 0.2]])

        async def _invoke() -> list[list[float]]:
            return embed_texts(fake_embedder, ["hello"])

        assert asyncio.run(_invoke()) == [[0.1, 0.2]]
        fake_embedder.assert_called_once()
class TestEmbedText:
    """embed_text returns one vector for one text, skipping blank input."""

    def test_empty_string_returns_empty(self) -> None:
        fake_embedder = MagicMock()
        assert embed_text(fake_embedder, "") == []
        fake_embedder.assert_not_called()

    def test_whitespace_only_returns_empty(self) -> None:
        fake_embedder = MagicMock()
        assert embed_text(fake_embedder, " ") == []
        fake_embedder.assert_not_called()

    def test_normal_text_returns_embedding(self) -> None:
        fake_embedder = MagicMock(return_value=[[0.1, 0.2, 0.3]])
        assert embed_text(fake_embedder, "hello") == [0.1, 0.2, 0.3]

    def test_numpy_array_result_converted(self) -> None:
        # Embedders returning numpy arrays should still yield a plain list.
        fake_embedder = MagicMock(return_value=[np.array([0.1, 0.2], dtype=np.float32)])
        vector = embed_text(fake_embedder, "hello")
        assert isinstance(vector, list)
        assert len(vector) == 2

View File

@@ -0,0 +1,125 @@
"""Tests for shared cache configuration helpers."""
from __future__ import annotations
import os
from unittest.mock import patch
import pytest
from crewai.utilities.cache_config import (
get_aiocache_config,
parse_cache_url,
use_valkey_cache,
)
class TestParseCacheUrl:
    """Tests for parse_cache_url()."""

    def test_returns_none_when_no_env_vars(self) -> None:
        """With neither VALKEY_URL nor REDIS_URL set there is nothing to parse."""
        with patch.dict(os.environ, {}, clear=True):
            assert parse_cache_url() is None

    def test_parses_valkey_url(self) -> None:
        """Host, port, and db are extracted from VALKEY_URL."""
        env = {"VALKEY_URL": "redis://myhost:6380/2"}
        with patch.dict(os.environ, env, clear=True):
            parsed = parse_cache_url()
        assert parsed is not None
        assert parsed["host"] == "myhost"
        assert parsed["port"] == 6380
        assert parsed["db"] == 2
        assert parsed["password"] is None

    def test_parses_redis_url(self) -> None:
        """REDIS_URL is honored when VALKEY_URL is absent."""
        env = {"REDIS_URL": "redis://localhost:6379/0"}
        with patch.dict(os.environ, env, clear=True):
            parsed = parse_cache_url()
        assert parsed is not None
        assert parsed["host"] == "localhost"
        assert parsed["port"] == 6379
        assert parsed["db"] == 0

    def test_valkey_url_takes_priority_over_redis_url(self) -> None:
        """When both URLs are present, VALKEY_URL wins."""
        env = {
            "VALKEY_URL": "redis://valkey-host:6380/1",
            "REDIS_URL": "redis://redis-host:6379/0",
        }
        with patch.dict(os.environ, env, clear=True):
            parsed = parse_cache_url()
        assert parsed is not None
        assert parsed["host"] == "valkey-host"
        assert parsed["port"] == 6380

    def test_parses_password(self) -> None:
        """A password embedded in the URL userinfo section is extracted."""
        env = {"VALKEY_URL": "redis://:s3cret@myhost:6379/0"}
        with patch.dict(os.environ, env, clear=True):
            parsed = parse_cache_url()
        assert parsed is not None
        assert parsed["password"] == "s3cret"

    def test_defaults_for_minimal_url(self) -> None:
        """A host-only URL falls back to port 6379, db 0, and no password."""
        env = {"VALKEY_URL": "redis://myhost"}
        with patch.dict(os.environ, env, clear=True):
            parsed = parse_cache_url()
        assert parsed is not None
        assert parsed["host"] == "myhost"
        assert parsed["port"] == 6379
        assert parsed["db"] == 0
        assert parsed["password"] is None

    def test_non_numeric_db_path_defaults_to_zero(self) -> None:
        """A non-numeric database path (e.g. /mydb) defaults to db 0 instead of raising."""
        env = {"VALKEY_URL": "redis://myhost:6379/mydb"}
        with patch.dict(os.environ, env, clear=True):
            parsed = parse_cache_url()
        assert parsed is not None
        assert parsed["db"] == 0
class TestGetAiocacheConfig:
    """Tests for get_aiocache_config()."""

    def test_returns_memory_cache_when_no_url(self) -> None:
        """Falls back to the in-memory cache when no cache URL is configured."""
        with patch.dict(os.environ, {}, clear=True):
            cfg = get_aiocache_config()
        assert cfg["default"]["cache"] == "aiocache.SimpleMemoryCache"

    def test_returns_redis_cache_when_url_set(self) -> None:
        """A configured VALKEY_URL yields a RedisCache config carrying its settings."""
        env = {"VALKEY_URL": "redis://myhost:6380/2"}
        with patch.dict(os.environ, env, clear=True):
            cfg = get_aiocache_config()
        default = cfg["default"]
        assert default["cache"] == "aiocache.RedisCache"
        assert default["endpoint"] == "myhost"
        assert default["port"] == 6380
        assert default["db"] == 2
class TestUseValkeyCache:
    """Tests for use_valkey_cache()."""

    def test_returns_false_when_not_set(self) -> None:
        """No VALKEY_URL in the environment means the Valkey cache is disabled."""
        with patch.dict(os.environ, {}, clear=True):
            assert use_valkey_cache() is False

    def test_returns_true_when_set(self) -> None:
        """Setting VALKEY_URL enables the Valkey cache."""
        env = {"VALKEY_URL": "redis://localhost:6379"}
        with patch.dict(os.environ, env, clear=True):
            assert use_valkey_cache() is True

    def test_returns_false_when_only_redis_url_set(self) -> None:
        """REDIS_URL alone does not enable the Valkey cache."""
        env = {"REDIS_URL": "redis://localhost:6379"}
        with patch.dict(os.environ, env, clear=True):
            assert use_valkey_cache() is False

View File

@@ -205,6 +205,8 @@ override-dependencies = [
"gitpython>=3.1.50,<4",
"langsmith>=0.7.31,<0.8",
"authlib>=1.6.11",
# scrapegraph-py 2.x removed Client class; pin until upstream fixes type ignores
"scrapegraph-py>=1.46.0,<2",
]
[tool.uv.workspace]

46
uv.lock generated
View File

@@ -13,7 +13,7 @@ resolution-markers = [
]
[options]
exclude-newer = "2026-05-08T16:33:02.834109Z"
exclude-newer = "2026-05-08T20:07:25.621408Z"
exclude-newer-span = "P3D"
[manifest]
@@ -38,6 +38,7 @@ overrides = [
{ name = "pypdf", specifier = ">=6.10.2,<7" },
{ name = "python-multipart", specifier = ">=0.0.27,<1" },
{ name = "rich", specifier = ">=13.7.1" },
{ name = "scrapegraph-py", specifier = ">=1.46.0,<2" },
{ name = "transformers", marker = "python_full_version >= '3.10'", specifier = ">=5.4.0" },
{ name = "urllib3", specifier = ">=2.7.0" },
{ name = "uv", specifier = ">=0.11.6,<1" },
@@ -1365,6 +1366,9 @@ qdrant-edge = [
tools = [
{ name = "crewai-tools" },
]
valkey = [
{ name = "valkey-glide" },
]
voyageai = [
{ name = "voyageai" },
]
@@ -1426,9 +1430,10 @@ requires-dist = [
{ name = "tokenizers", specifier = ">=0.21,<1" },
{ name = "tomli", specifier = "~=2.0.2" },
{ name = "tomli-w", specifier = "~=1.1.0" },
{ name = "valkey-glide", marker = "extra == 'valkey'", specifier = ">=1.3.0" },
{ name = "voyageai", marker = "extra == 'voyageai'", specifier = "~=0.3.5" },
]
provides-extras = ["a2a", "anthropic", "aws", "azure-ai-inference", "bedrock", "docling", "embeddings", "file-processing", "google-genai", "litellm", "mem0", "openpyxl", "pandas", "qdrant", "qdrant-edge", "tools", "voyageai", "watson"]
provides-extras = ["a2a", "anthropic", "aws", "azure-ai-inference", "bedrock", "docling", "embeddings", "file-processing", "google-genai", "litellm", "mem0", "openpyxl", "pandas", "qdrant", "qdrant-edge", "tools", "valkey", "voyageai", "watson"]
[[package]]
name = "crewai-cli"
@@ -9533,6 +9538,43 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/fa/6e/3e955517e22cbdd565f2f8b2e73d52528b14b8bcfdb04f62466b071de847/validators-0.35.0-py3-none-any.whl", hash = "sha256:e8c947097eae7892cb3d26868d637f79f47b4a0554bc6b80065dfe5aac3705dd", size = 44712, upload-time = "2025-05-01T05:42:04.203Z" },
]
[[package]]
name = "valkey-glide"
version = "2.0.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio" },
{ name = "protobuf" },
{ name = "typing-extensions", marker = "python_full_version < '3.11'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/32/35/fb0401c4bc7be748d937e95213786d21d9e56767b3ad816db5bad6f92c01/valkey_glide-2.0.1.tar.gz", hash = "sha256:4f9c62a88aedffd725cced7d28a9488b27e3f675d1a5294b4962624e97d346c4", size = 1026255, upload-time = "2025-06-20T01:08:15.861Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/44/a3/bf5ff3841538d0bb337371e073dc2c0e93f748f7f8b10a44806f36ab5fa1/valkey_glide-2.0.1-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:b3307934b76557b18ac559f327592cc09fc895fc653ba46010dd6d70fb6239dc", size = 5074638, upload-time = "2025-06-20T01:07:30.16Z" },
{ url = "https://files.pythonhosted.org/packages/0f/c4/20b66dced96bdca81aa294b39bc03018ed22628c52076752e8d1d3540a7d/valkey_glide-2.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6b83d34e2e723e97c41682479b0dce5882069066e808316292b363855992b449", size = 4750261, upload-time = "2025-06-20T01:07:32.452Z" },
{ url = "https://files.pythonhosted.org/packages/53/58/6440e66bde8963d86bc3c44d88f993059f2a9d7ebdb3256a695d035cff50/valkey_glide-2.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1baaf14d09d464ae645be5bdb5dc6b8a38b7eacf22f9dcb2907200c74fbdcdd3", size = 4767755, upload-time = "2025-06-20T01:07:33.86Z" },
{ url = "https://files.pythonhosted.org/packages/3b/69/dd5c350ce4d2cadde0d83beb601f05e1e62622895f268135e252e8bfc307/valkey_glide-2.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4427e7b4d54c9de289a35032c19d5956f94376f5d4335206c5ac4524cbd1c64a", size = 5094507, upload-time = "2025-06-20T01:07:35.349Z" },
{ url = "https://files.pythonhosted.org/packages/b5/dd/0dd6614e09123a5bd7273bf1159c958d1ea65e7decc2190b225d212e0cb9/valkey_glide-2.0.1-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:6379582d6fbd817697fb119274e37d397db450103cd15d4bd71e555e6d88fb6b", size = 5072939, upload-time = "2025-06-20T01:07:36.948Z" },
{ url = "https://files.pythonhosted.org/packages/c6/04/986188e407231a5f0bfaf31f31b68e3605ab66f4f4c656adfbb0345669d9/valkey_glide-2.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:0f1c0fe003026d8ae172369e0eb2337cbff16f41d4c085332487d6ca2e5282e6", size = 4750491, upload-time = "2025-06-20T01:07:38.659Z" },
{ url = "https://files.pythonhosted.org/packages/ac/fb/2f5cec71ae51c464502a892b6825426cd74a2c325827981726e557926c94/valkey_glide-2.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82c5f33598e50bcfec6fc924864931f3c6e30cd327a9c9562e1c7ac4e17e79fd", size = 4767597, upload-time = "2025-06-20T01:07:40.091Z" },
{ url = "https://files.pythonhosted.org/packages/3a/31/851a1a734fe5da5d520106fcfd824e4da09c3be8a0a2123bb4b1980db1ea/valkey_glide-2.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:79039a9dc23bb074680f171c12b36b3322357a0af85125534993e81a619dce21", size = 5094383, upload-time = "2025-06-20T01:07:41.329Z" },
{ url = "https://files.pythonhosted.org/packages/fc/6d/1e7b432cbc02fe63e7496b984b7fc830fb7de388c877b237e0579a6300fc/valkey_glide-2.0.1-cp312-cp312-macosx_10_7_x86_64.whl", hash = "sha256:f55ec8968b0fde364a5b3399be34b89dcb9068994b5cd384e20db0773ad12723", size = 5075024, upload-time = "2025-06-20T01:07:42.917Z" },
{ url = "https://files.pythonhosted.org/packages/ca/39/6e9f83970590d17d19f596e1b3a366d39077624888e3dd709309efc67690/valkey_glide-2.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:21598f49313912ad27dc700d7b13a3b4bfed7ed9dffad207235cac7d218f4966", size = 4748418, upload-time = "2025-06-20T01:07:44.64Z" },
{ url = "https://files.pythonhosted.org/packages/98/0e/91335c13dc8e7ceb95063234c16010b46e2dd874a2edef62dea155081647/valkey_glide-2.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f662285146529328e2b5a0a7047f699339b4e0d250eb1f252b15c9befa0dea05", size = 4767264, upload-time = "2025-06-20T01:07:46.185Z" },
{ url = "https://files.pythonhosted.org/packages/5f/94/ee4d9d441f83fec1464d9f4e52f7940bdd2aeb917589e6abd57498880876/valkey_glide-2.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3939aaa8411fcbba00cb1ff7c7ba73f388bb1deca919972f65cba7eda1d5fa95", size = 5093543, upload-time = "2025-06-20T01:07:47.345Z" },
{ url = "https://files.pythonhosted.org/packages/ed/7e/257a2e4b61ac29d5923f89bad5fe62be7b4a19e7bec78d191af3ce77aa39/valkey_glide-2.0.1-cp313-cp313-macosx_10_7_x86_64.whl", hash = "sha256:c49b53011a05b5820d0c660ee5c76574183b413a54faa33cf5c01ce77164d9c8", size = 5073114, upload-time = "2025-06-20T01:07:48.885Z" },
{ url = "https://files.pythonhosted.org/packages/20/14/a8a470679953980af7eac3ccb09638f2a76d4547116d48cbc69ae6f25080/valkey_glide-2.0.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:3a23572b83877537916ba36ad0a6b2fd96581534f0bc67ef8f8498bf4dbb2b40", size = 4747717, upload-time = "2025-06-20T01:07:50.092Z" },
{ url = "https://files.pythonhosted.org/packages/9f/49/f168dd0c778d9f6ff1be70d5d3bad7a86928fee563de7de5f4f575eddfd8/valkey_glide-2.0.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:943a2c4a5c38b8a6b53281201d5a4997ec454a6fdda72d27050eeb6aaef12afb", size = 4767128, upload-time = "2025-06-20T01:07:51.306Z" },
{ url = "https://files.pythonhosted.org/packages/43/be/68961b14ea133d1792ce50f6df1753848b5377c3e06a8dbe4e39188a549a/valkey_glide-2.0.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d770ec581acc59d5597e7ccaac37aee7e3b5e716a77a7fa44e2967db3a715f53", size = 5093522, upload-time = "2025-06-20T01:07:52.546Z" },
{ url = "https://files.pythonhosted.org/packages/51/2e/ad8595ffe84317385d52ceab8de1e9ef06a4da6b81ca8cd61b7961923de4/valkey_glide-2.0.1-pp310-pypy310_pp73-macosx_10_7_x86_64.whl", hash = "sha256:d4a9ccfe2b190c90622849dab62f9468acf76a282719a1245d272b649e7c12d1", size = 5074539, upload-time = "2025-06-20T01:07:59.87Z" },
{ url = "https://files.pythonhosted.org/packages/db/e5/2122541c7a64706f3631655209bb0b13723fb99db3c190d9a792b4e7d494/valkey_glide-2.0.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:9aa004077b82f64b23ea0d38d948b5116c23f7228dae3a5b4fcfa1799f8ff7de", size = 4753222, upload-time = "2025-06-20T01:08:01.376Z" },
{ url = "https://files.pythonhosted.org/packages/6c/13/cd9a20988a820ff61b127d3f850887b28bb734daf2c26d512d8e4c2e8e9e/valkey_glide-2.0.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:631a7a0e2045f7e5e3706e1903beeddf381a6529e318c27230798f4382579e4f", size = 4771530, upload-time = "2025-06-20T01:08:02.6Z" },
{ url = "https://files.pythonhosted.org/packages/c7/fc/047e89cc01b4cc71db1b6b8160d3b5d050097b408028022c002351238641/valkey_glide-2.0.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c5ed905fb62368c9bc6aef9df8d66269ef51f968dc527da4d7c956927382c1d", size = 5091242, upload-time = "2025-06-20T01:08:04.111Z" },
{ url = "https://files.pythonhosted.org/packages/1c/9e/68790c1a263f3a0094d67d0109be34631f6f79c2fbce5ced7e33a65ad363/valkey_glide-2.0.1-pp311-pypy311_pp73-macosx_10_7_x86_64.whl", hash = "sha256:53da3cc47c8d946ac76ecc4b468a469d3486778833a59162ea69aa7ce70cbb27", size = 5072793, upload-time = "2025-06-20T01:08:05.562Z" },
{ url = "https://files.pythonhosted.org/packages/1f/ae/a935af65ae4069d76c69f28f6bfb4533da8b89f7fc418beb7a1482cdd9ee/valkey_glide-2.0.1-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:e526a7d718cdd299d6b03091c12dcc15cd02ff22fe420f253341a4891c50824d", size = 4753435, upload-time = "2025-06-20T01:08:07.149Z" },
{ url = "https://files.pythonhosted.org/packages/3b/c2/c91d753a89dd87dce2fc8932cfbe174c7a1226c657b3cd64c063f21d4fe6/valkey_glide-2.0.1-pp311-pypy311_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8d3345ea2adf6f745733fa5157d8709bcf5ffbb2674391aeebd8f166a37cbc96", size = 4771401, upload-time = "2025-06-20T01:08:08.359Z" },
{ url = "https://files.pythonhosted.org/packages/00/fe/ad83cfc2ac87bf6bad2b75fa64fca5a6dd54568c1de551d36d369e07f948/valkey_glide-2.0.1-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d1c5fff0f12d2aa4277ddc335035b2c8e12bb11243c1a0f3c35071f4a8b11064", size = 5091360, upload-time = "2025-06-20T01:08:09.622Z" },
]
[[package]]
name = "vcrpy"
version = "7.0.0"