Compare commits


1 commit

Author: Iris Clawd
Commit: f3843432cd

docs: add Platform Tools CLI guide for crewai tool create/publish/install
Add documentation for the CLI commands referenced in the Create Tool
modal on the platform (crewai tool create, crewai tool publish,
crewai tool install). These commands manage tools on the CrewAI
platform registry — distinct from PyPI publishing.

Changes:
- New guide: docs/en/guides/tools/platform-tools-cli.mdx
  Full lifecycle: create → implement → publish → install
  Covers visibility flags (--public/--private/--force)
  Includes platform vs PyPI comparison
- Updated create-custom-tools.mdx tip to cross-reference both guides
- Added new page to docs.json navigation (all versions)

Resolves EPD-76

Co-authored-by: Diego Nogues <diego@crewai.com>
Date: 2026-05-13 14:32:45 +00:00
22 changed files with 1534 additions and 10064 deletions

File diff suppressed because it is too large.

View File

@@ -0,0 +1,139 @@
---
title: Platform Tools CLI
description: Create, publish, and install custom tools on the CrewAI platform using the CLI.
icon: terminal
mode: "wide"
---
## Overview
The CrewAI CLI provides commands to manage custom tools on the **CrewAI platform** — a hosted tool registry that lets you share tools within your organization and across the community without publishing to PyPI.
| Command | Purpose |
|---------|---------|
| `crewai tool create <handle>` | Scaffold a new tool project |
| `crewai tool publish` | Publish the tool to the CrewAI platform |
| `crewai tool install <handle>` | Install a platform tool into your crew project |
<Note type="info" title="Platform vs PyPI">
These commands manage tools on the **CrewAI platform registry**. If you want to publish a standalone Python package to PyPI instead, see the [Publish Custom Tools to PyPI](/en/guides/tools/publish-custom-tools) guide.
</Note>
## Prerequisites
- **CrewAI CLI** installed (`pip install crewai`)
- **Authenticated** with the platform — run `crewai login` first
---
## Step 1: Create a Tool Project
Scaffold a new tool project:
```bash
crewai tool create my_custom_tool
```
This generates a project structure with the boilerplate you need to start building your tool.
<Tip>
The `handle` is the unique identifier for your tool on the platform. Choose something descriptive and specific to what the tool does.
</Tip>
### Implement Your Tool
Edit the generated tool file to add your logic. The tool follows the standard CrewAI tools contract — you can subclass `BaseTool` or use the `@tool` decorator:
```python
from crewai.tools import BaseTool


class MyCustomTool(BaseTool):
    name: str = "My Custom Tool"
    description: str = "Description of what this tool does — be specific so agents know when to use it."

    def _run(self, argument: str) -> str:
        # Your tool logic here
        return "result"
```
For the full tools API reference (input schemas, caching, async support, error handling), see the [Create Custom Tools](/en/learn/create-custom-tools) guide.
---
## Step 2: Publish to the Platform
From your tool project directory, publish it to the CrewAI platform:
```bash
crewai tool publish
```
### Visibility Options
| Flag | Description |
|------|-------------|
| `--public` | Make the tool available to all platform users |
| `--private` | Restrict visibility to your organization |
| `--force` | Bypass Git remote validations |
```bash
# Publish as a public tool
crewai tool publish --public
# Publish privately (organization only)
crewai tool publish --private
```
---
## Step 3: Install a Platform Tool
To install a tool that's been published to the platform:
```bash
crewai tool install my_custom_tool
```
Once installed, you can use the tool in your crew like any other tool — assign it to an agent via the `tools` parameter.
---
## Full Lifecycle Example
```bash
# 1. Authenticate with the platform
crewai login
# 2. Scaffold a new tool
crewai tool create weather_lookup
# 3. Implement your logic in the generated project
cd weather_lookup
# ... edit the tool file ...
# 4. Publish to the platform
crewai tool publish --public
# 5. In another project, install and use it
crewai tool install weather_lookup
```
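As an illustration of step 3, the `weather_lookup` logic could start as a stdlib-only stub. Everything below — the function name and the canned data — is hypothetical; no real weather API is involved:

```python
# Hypothetical logic for the weather_lookup tool's _run method.
# Static data for illustration only; a real tool would call a weather API.
CITY_CONDITIONS = {
    "lisbon": "22C, sunny",
    "london": "14C, overcast",
}


def run_weather_lookup(city: str) -> str:
    """Return a canned conditions string, or a clear miss message."""
    city = city.strip()
    conditions = CITY_CONDITIONS.get(city.lower())
    if conditions is None:
        return f"No data for {city!r}"
    return f"{city.title()}: {conditions}"
```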
---
## Platform Tools vs PyPI Packages
| | Platform Tools | PyPI Packages |
|---|---|---|
| **Publish** | `crewai tool publish` | `uv build` + `uv publish` |
| **Registry** | CrewAI platform | PyPI |
| **Install** | `crewai tool install <handle>` | `pip install <package>` |
| **Auth** | `crewai login` | PyPI account + token |
| **Visibility** | `--public` / `--private` flags | Always public |
| **Guide** | This page | [Publish Custom Tools](/en/guides/tools/publish-custom-tools) |
---
## Related
- [Create Custom Tools](/en/learn/create-custom-tools) — Python API reference for building tools (BaseTool, @tool decorator)
- [Publish Custom Tools to PyPI](/en/guides/tools/publish-custom-tools) — package and distribute tools as standalone Python libraries

View File

@@ -12,7 +12,9 @@ incorporating the latest functionalities such as tool delegation, error handling
enabling agents to perform a wide range of actions.
<Tip>
**Want to publish your tool for the community?** If you're building a tool that others could benefit from, check out the [Publish Custom Tools](/en/guides/tools/publish-custom-tools) guide to learn how to package and distribute your tool on PyPI.
**Want to publish your tool to the CrewAI platform?** Use the CLI to scaffold, publish, and share tools directly on the platform — see the [Platform Tools CLI](/en/guides/tools/platform-tools-cli) guide.
**Prefer publishing to PyPI?** Check out the [Publish Custom Tools](/en/guides/tools/publish-custom-tools) guide to package and distribute your tool as a standalone Python library.
</Tip>
### Subclassing `BaseTool`

View File

@@ -1,4 +1,4 @@
"""Cache for tracking uploaded files using aiocache or ValkeyCache."""
"""Cache for tracking uploaded files using aiocache."""
from __future__ import annotations
@@ -10,11 +10,10 @@ from dataclasses import dataclass
from datetime import datetime, timezone
import hashlib
import logging
from typing import TYPE_CHECKING, Any, Protocol
from typing import TYPE_CHECKING, Any
from aiocache import Cache # type: ignore[import-untyped]
from aiocache.serializers import PickleSerializer # type: ignore[import-untyped]
from crewai.utilities.cache_config import parse_cache_url
from crewai_files.core.constants import DEFAULT_MAX_CACHE_ENTRIES, DEFAULT_TTL_SECONDS
from crewai_files.uploaders.factory import ProviderType
@@ -52,33 +51,6 @@ class CachedUpload:
return False
return datetime.now(timezone.utc) >= self.expires_at
def to_dict(self) -> dict[str, Any]:
"""Serialize to a JSON-compatible dict."""
return {
"file_id": self.file_id,
"provider": self.provider,
"file_uri": self.file_uri,
"content_type": self.content_type,
"uploaded_at": self.uploaded_at.isoformat(),
"expires_at": self.expires_at.isoformat() if self.expires_at else None,
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> CachedUpload:
"""Deserialize from a dict."""
return cls(
file_id=data["file_id"],
provider=data["provider"],
file_uri=data.get("file_uri"),
content_type=data["content_type"],
uploaded_at=datetime.fromisoformat(data["uploaded_at"]),
expires_at=(
datetime.fromisoformat(data["expires_at"])
if data.get("expires_at")
else None
),
)
def _make_key(file_hash: str, provider: str) -> str:
"""Create a cache key from file hash and provider."""
@@ -86,7 +58,14 @@ def _make_key(file_hash: str, provider: str) -> str:
def _compute_file_hash_streaming(chunks: Iterator[bytes]) -> str:
"""Compute SHA-256 hash from streaming chunks."""
"""Compute SHA-256 hash from streaming chunks.
Args:
chunks: Iterator of byte chunks.
Returns:
Hexadecimal hash string.
"""
hasher = hashlib.sha256()
for chunk in chunks:
hasher.update(chunk)
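The streaming hasher shown in this hunk is easy to sanity-check against a one-shot hash using only the stdlib; `chunked` is a throwaway helper for the demo:

```python
import hashlib


def chunked(data: bytes, size: int):
    """Yield successive chunks, mimicking a streamed file read."""
    for i in range(0, len(data), size):
        yield data[i : i + size]


def sha256_streaming(chunks) -> str:
    # Same shape as _compute_file_hash_streaming above: feed chunks
    # into one hasher instead of loading the whole file into memory.
    hasher = hashlib.sha256()
    for chunk in chunks:
        hasher.update(chunk)
    return hasher.hexdigest()
```

Incremental updates produce the same digest as hashing the full buffer at once, which is what makes the streaming path safe for large files.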
@@ -94,7 +73,10 @@ def _compute_file_hash_streaming(chunks: Iterator[bytes]) -> str:
def _compute_file_hash(file: FileInput) -> str:
"""Compute SHA-256 hash of file content."""
"""Compute SHA-256 hash of file content.
Uses streaming for FilePath sources to avoid loading large files into memory.
"""
from crewai_files.core.sources import FilePath
source = file._file_source
@@ -104,73 +86,10 @@ def _compute_file_hash(file: FileInput) -> str:
return hashlib.sha256(content).hexdigest()
class CacheBackend(Protocol):
"""Protocol for cache backends used by UploadCache."""
async def get(self, key: str) -> CachedUpload | None: ...
async def set(self, key: str, value: CachedUpload, ttl: int) -> None: ...
async def delete(self, key: str) -> bool: ...
class AiocacheBackend:
"""Cache backend backed by aiocache (memory or Redis)."""
def __init__(self, cache: Cache) -> None: # type: ignore[no-any-unimported]
self._cache = cache
async def get(self, key: str) -> CachedUpload | None:
result = await self._cache.get(key)
if isinstance(result, CachedUpload):
return result
return None
async def set(self, key: str, value: CachedUpload, ttl: int) -> None:
await self._cache.set(key, value, ttl=ttl)
async def delete(self, key: str) -> bool:
result = await self._cache.delete(key)
return bool(result > 0 if isinstance(result, int) else result)
class ValkeyCacheBackend:
"""Cache backend backed by ValkeyCache (JSON serialization)."""
def __init__(
self,
host: str = "localhost",
port: int = 6379,
db: int = 0,
password: str | None = None,
default_ttl: int | None = None,
) -> None:
from crewai.memory.storage.valkey_cache import ValkeyCache
self._cache = ValkeyCache(
host=host, port=port, db=db, password=password, default_ttl=default_ttl
)
async def get(self, key: str) -> CachedUpload | None:
data = await self._cache.get(key)
if data is None:
return None
try:
return CachedUpload.from_dict(data)
except (KeyError, ValueError) as e:
logger.warning(f"Failed to deserialize cached upload: {e}")
return None
async def set(self, key: str, value: CachedUpload, ttl: int) -> None:
await self._cache.set(key, value.to_dict(), ttl=ttl)
async def delete(self, key: str) -> bool:
await self._cache.delete(key)
return True # ValkeyCache.delete is void
class UploadCache:
"""Async cache for tracking uploaded files.
"""Async cache for tracking uploaded files using aiocache.
Supports in-memory caching by default, with optional Redis or Valkey backend
Supports in-memory caching by default, with optional Redis backend
for distributed setups.
Attributes:
@@ -191,7 +110,7 @@ class UploadCache:
Args:
ttl: Default TTL in seconds.
namespace: Cache namespace.
cache_type: Backend type ("memory", "redis", or "valkey").
cache_type: Backend type ("memory" or "redis").
max_entries: Maximum cache entries (None for unlimited).
**cache_kwargs: Additional args for cache backend.
"""
@@ -201,39 +120,18 @@ class UploadCache:
self._provider_keys: dict[ProviderType, set[str]] = {}
self._key_access_order: list[str] = []
self._backend: CacheBackend = self._create_backend(
cache_type, namespace, ttl, **cache_kwargs
)
@staticmethod
def _create_backend(
cache_type: str,
namespace: str,
ttl: int,
**cache_kwargs: Any,
) -> CacheBackend:
"""Create the appropriate cache backend."""
if cache_type == "valkey":
conn = parse_cache_url() or {}
return ValkeyCacheBackend(
host=cache_kwargs.get("host", conn.get("host", "localhost")),
port=cache_kwargs.get("port", conn.get("port", 6379)),
db=cache_kwargs.get("db", conn.get("db", 0)),
password=cache_kwargs.get("password", conn.get("password")),
default_ttl=ttl,
)
if cache_type == "redis":
return AiocacheBackend(
Cache(
Cache.REDIS,
serializer=PickleSerializer(),
namespace=namespace,
**cache_kwargs,
)
self._cache = Cache(
Cache.REDIS,
serializer=PickleSerializer(),
namespace=namespace,
**cache_kwargs,
)
else:
self._cache = Cache(
serializer=PickleSerializer(),
namespace=namespace,
)
return AiocacheBackend(
Cache(serializer=PickleSerializer(), namespace=namespace)
)
def _track_key(self, provider: ProviderType, key: str) -> None:
"""Track a key for a provider (for cleanup) and access order."""
@@ -259,9 +157,11 @@ class UploadCache:
"""
if self.max_entries is None:
return 0
current_count = len(self)
if current_count < self.max_entries:
return 0
to_evict = max(1, self.max_entries // 10)
return await self._evict_oldest(to_evict)
@@ -276,24 +176,31 @@ class UploadCache:
"""
evicted = 0
keys_to_evict = self._key_access_order[:count]
for key in keys_to_evict:
await self._backend.delete(key)
await self._cache.delete(key)
self._key_access_order.remove(key)
for provider_keys in self._provider_keys.values():
provider_keys.discard(key)
evicted += 1
if evicted > 0:
logger.debug(f"Evicted {evicted} oldest cache entries")
return evicted
# ------------------------------------------------------------------
# Async public API
# ------------------------------------------------------------------
return evicted
async def aget(
self, file: FileInput, provider: ProviderType
) -> CachedUpload | None:
"""Get a cached upload for a file."""
"""Get a cached upload for a file.
Args:
file: The file to look up.
provider: The provider name.
Returns:
Cached upload if found and not expired, None otherwise.
"""
file_hash = _compute_file_hash(file)
return await self.aget_by_hash(file_hash, provider)
@@ -310,14 +217,17 @@ class UploadCache:
Cached upload if found and not expired, None otherwise.
"""
key = _make_key(file_hash, provider)
result = await self._backend.get(key)
result = await self._cache.get(key)
if result is None:
return None
if result.is_expired():
await self._backend.delete(key)
self._untrack_key(provider, key)
return None
return result
if isinstance(result, CachedUpload):
if result.is_expired():
await self._cache.delete(key)
self._untrack_key(provider, key)
return None
return result
return None
async def aset(
self,
@@ -327,7 +237,18 @@ class UploadCache:
file_uri: str | None = None,
expires_at: datetime | None = None,
) -> CachedUpload:
"""Cache an uploaded file."""
"""Cache an uploaded file.
Args:
file: The file that was uploaded.
provider: The provider name.
file_id: Provider-specific file identifier.
file_uri: Optional URI for accessing the file.
expires_at: When the upload expires.
Returns:
The created cache entry.
"""
file_hash = _compute_file_hash(file)
return await self.aset_by_hash(
file_hash=file_hash,
@@ -361,6 +282,7 @@ class UploadCache:
The created cache entry.
"""
await self._evict_if_needed()
key = _make_key(file_hash, provider)
now = datetime.now(timezone.utc)
@@ -377,7 +299,7 @@ class UploadCache:
if expires_at is not None:
ttl = max(0, int((expires_at - now).total_seconds()))
await self._backend.set(key, cached, ttl=ttl)
await self._cache.set(key, cached, ttl=ttl)
self._track_key(provider, key)
logger.debug(f"Cached upload: {file_id} for provider {provider}")
return cached
@@ -394,7 +316,9 @@ class UploadCache:
"""
file_hash = _compute_file_hash(file)
key = _make_key(file_hash, provider)
removed = await self._backend.delete(key)
result = await self._cache.delete(key)
removed = bool(result > 0 if isinstance(result, int) else result)
if removed:
self._untrack_key(provider, key)
return removed
@@ -411,10 +335,11 @@ class UploadCache:
"""
if provider not in self._provider_keys:
return False
for key in list(self._provider_keys[provider]):
cached = await self._backend.get(key)
if cached is not None and cached.file_id == file_id:
await self._backend.delete(key)
cached = await self._cache.get(key)
if isinstance(cached, CachedUpload) and cached.file_id == file_id:
await self._cache.delete(key)
self._untrack_key(provider, key)
return True
return False
@@ -426,13 +351,17 @@ class UploadCache:
Number of entries removed.
"""
removed = 0
for provider, keys in list(self._provider_keys.items()):
for key in list(keys):
cached = await self._backend.get(key)
if cached is None or cached.is_expired():
await self._backend.delete(key)
cached = await self._cache.get(key)
if cached is None or (
isinstance(cached, CachedUpload) and cached.is_expired()
):
await self._cache.delete(key)
self._untrack_key(provider, key)
removed += 1
if removed > 0:
logger.debug(f"Cleared {removed} expired cache entries")
return removed
@@ -444,12 +373,9 @@ class UploadCache:
Number of entries cleared.
"""
count = sum(len(keys) for keys in self._provider_keys.values())
# Delete all tracked keys individually (works for all backends)
for keys in self._provider_keys.values():
for key in keys:
await self._backend.delete(key)
await self._cache.clear(namespace=self.namespace)
self._provider_keys.clear()
self._key_access_order.clear()
if count > 0:
logger.debug(f"Cleared {count} cache entries")
return count
@@ -465,17 +391,14 @@ class UploadCache:
"""
if provider not in self._provider_keys:
return []
results: list[CachedUpload] = []
for key in list(self._provider_keys[provider]):
cached = await self._backend.get(key)
if cached is not None and not cached.is_expired():
cached = await self._cache.get(key)
if isinstance(cached, CachedUpload) and not cached.is_expired():
results.append(cached)
return results
# ------------------------------------------------------------------
# Sync wrappers
# ------------------------------------------------------------------
@staticmethod
def _run_sync(coro: Any) -> Any:
"""Run an async coroutine from sync context without blocking event loop."""
@@ -566,7 +489,11 @@ class UploadCache:
return sum(len(keys) for keys in self._provider_keys.values())
def get_providers(self) -> builtins.set[ProviderType]:
"""Get all provider names that have cached entries."""
"""Get all provider names that have cached entries.
Returns:
Set of provider names.
"""
return builtins.set(self._provider_keys.keys())
@@ -579,7 +506,17 @@ def get_upload_cache(
cache_type: str = "memory",
**cache_kwargs: Any,
) -> UploadCache:
"""Get or create the default upload cache."""
"""Get or create the default upload cache.
Args:
ttl: Default TTL in seconds.
namespace: Cache namespace.
cache_type: Backend type ("memory" or "redis").
**cache_kwargs: Additional args for cache backend.
Returns:
The upload cache instance.
"""
global _default_cache
if _default_cache is None:
_default_cache = UploadCache(

View File

@@ -110,9 +110,6 @@ file-processing = [
qdrant-edge = [
"qdrant-edge-py>=0.6.0",
]
valkey = [
"valkey-glide>=1.3.0",
]
[tool.uv]

View File

@@ -13,12 +13,8 @@ from types import MethodType
from typing import TYPE_CHECKING
from a2a.client.errors import A2AClientHTTPError
from a2a.types import (
AgentCapabilities,
AgentCard,
AgentSkill,
)
from aiocache import cached, caches # type: ignore[import-untyped]
from a2a.types import AgentCapabilities, AgentCard, AgentSkill
from aiocache import cached # type: ignore[import-untyped]
from aiocache.serializers import PickleSerializer # type: ignore[import-untyped]
import httpx
@@ -36,7 +32,6 @@ from crewai.events.types.a2a_events import (
A2AAuthenticationFailedEvent,
A2AConnectionErrorEvent,
)
from crewai.utilities.cache_config import get_aiocache_config
if TYPE_CHECKING:
@@ -45,18 +40,6 @@ if TYPE_CHECKING:
from crewai.task import Task
_cache_configured = False
def _ensure_cache_configured() -> None:
"""Configure aiocache on first use (lazy initialization)."""
global _cache_configured
if _cache_configured:
return
caches.set_config(get_aiocache_config())
_cache_configured = True
def _get_tls_verify(auth: ClientAuthScheme | None) -> ssl.SSLContext | bool | str:
"""Get TLS verify parameter from auth scheme.
@@ -208,7 +191,6 @@ async def afetch_agent_card(
else:
auth_hash = _auth_store.compute_key("none", "")
_auth_store.set(auth_hash, auth)
_ensure_cache_configured()
agent_card: AgentCard = await _afetch_agent_card_cached(
endpoint, auth_hash, timeout
)

View File

@@ -9,8 +9,9 @@ from datetime import datetime
from functools import wraps
import json
import logging
import threading
import os
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast
from urllib.parse import urlparse
from a2a.server.agent_execution import RequestContext
from a2a.server.events import EventQueue
@@ -37,6 +38,7 @@ from a2a.utils import (
from a2a.utils.errors import ServerError
from aiocache import SimpleMemoryCache, caches # type: ignore[import-untyped]
from pydantic import BaseModel
from typing_extensions import TypedDict
from crewai.a2a.utils.agent_card import _get_server_config
from crewai.a2a.utils.content_type import validate_message_parts
@@ -48,18 +50,12 @@ from crewai.events.types.a2a_events import (
A2AServerTaskStartedEvent,
)
from crewai.task import Task
from crewai.utilities.cache_config import (
get_aiocache_config,
parse_cache_url,
use_valkey_cache,
)
from crewai.utilities.pydantic_schema_utils import create_model_from_schema
if TYPE_CHECKING:
from crewai.a2a.extensions.server import ExtensionContext, ServerExtensionRegistry
from crewai.agent import Agent
from crewai.memory.storage.valkey_cache import ValkeyCache
logger = logging.getLogger(__name__)
@@ -68,61 +64,52 @@ P = ParamSpec("P")
T = TypeVar("T")
# ---------------------------------------------------------------------------
# Lazy cache initialisation
# ---------------------------------------------------------------------------
class RedisCacheConfig(TypedDict, total=False):
"""Configuration for aiocache Redis backend."""
_task_cache: ValkeyCache | None = None
_lazy_init_complete = False
_cache_init_lock = threading.Lock()
# Cancellation polling interval in seconds.
_CANCEL_POLL_INTERVAL = 0.1
# Configure aiocache at import time (matches upstream behaviour).
# This is safe — it only touches aiocache, no optional dependencies.
# The Valkey path is deferred to _ensure_task_cache() to avoid importing
# valkey-glide at module level (it may not be installed).
if not use_valkey_cache():
caches.set_config(get_aiocache_config())
cache: str
endpoint: str
port: int
db: int
password: str
def _ensure_task_cache() -> None:
"""Initialise the Valkey task cache on first use (thread-safe).
def _parse_redis_url(url: str) -> RedisCacheConfig:
"""Parse a Redis URL into aiocache configuration.
For the aiocache path, configuration happens at module level above.
This function only needs to run for the Valkey path.
Args:
url: Redis connection URL (e.g., redis://localhost:6379/0).
Returns:
Configuration dict for aiocache.RedisCache.
"""
global _task_cache, _lazy_init_complete
if _lazy_init_complete:
return
parsed = urlparse(url)
config: RedisCacheConfig = {
"cache": "aiocache.RedisCache",
"endpoint": parsed.hostname or "localhost",
"port": parsed.port or 6379,
}
if parsed.path and parsed.path != "/":
try:
config["db"] = int(parsed.path.lstrip("/"))
except ValueError:
pass
if parsed.password:
config["password"] = parsed.password
return config
with _cache_init_lock:
if _lazy_init_complete:
return
if use_valkey_cache():
from crewai.memory.storage.valkey_cache import ValkeyCache
_redis_url = os.environ.get("REDIS_URL")
conn = parse_cache_url() or {}
try:
_task_cache = ValkeyCache(
host=conn.get("host", "localhost"),
port=conn.get("port", 6379),
db=conn.get("db", 0),
password=conn.get("password"),
default_ttl=3600,
)
except Exception as e:
logger.error(
"Failed to initialize ValkeyCache for task cancellation, "
"falling back to aiocache",
extra={"error": str(e)},
)
caches.set_config(get_aiocache_config())
_task_cache = None
_lazy_init_complete = True
caches.set_config(
{
"default": _parse_redis_url(_redis_url)
if _redis_url
else {
"cache": "aiocache.SimpleMemoryCache",
}
}
)
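The URL parsing this hunk introduces can be sketched with the stdlib alone; `parse_redis_url` below mirrors the function added above, producing the same aiocache-style config keys:

```python
from urllib.parse import urlparse


def parse_redis_url(url: str) -> dict:
    """Parse a Redis URL (e.g. redis://localhost:6379/0) into cache config."""
    parsed = urlparse(url)
    config: dict = {
        "cache": "aiocache.RedisCache",
        "endpoint": parsed.hostname or "localhost",
        "port": parsed.port or 6379,
    }
    # The path component carries the database number, e.g. "/2" -> db 2.
    if parsed.path and parsed.path != "/":
        try:
            config["db"] = int(parsed.path.lstrip("/"))
        except ValueError:
            pass  # non-numeric path: fall back to the server default db
    if parsed.password:
        config["password"] = parsed.password
    return config
```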
def cancellable(
@@ -143,8 +130,6 @@ def cancellable(
@wraps(fn)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
"""Wrap function with cancellation monitoring."""
_ensure_task_cache()
context: RequestContext | None = None
for arg in args:
if isinstance(arg, RequestContext):
@@ -157,34 +142,19 @@ def cancellable(
return await fn(*args, **kwargs)
task_id = context.task_id
cache = caches.get("default")
async def poll_for_cancel_valkey() -> bool:
"""Poll ValkeyCache for cancellation flag."""
while True:
if _task_cache is not None and await _task_cache.get(
f"cancel:{task_id}"
):
return True
await asyncio.sleep(_CANCEL_POLL_INTERVAL)
async def poll_for_cancel_aiocache() -> bool:
"""Poll aiocache for cancellation flag."""
cache = caches.get("default")
async def poll_for_cancel() -> bool:
"""Poll cache for cancellation flag."""
while True:
if await cache.get(f"cancel:{task_id}"):
return True
await asyncio.sleep(_CANCEL_POLL_INTERVAL)
await asyncio.sleep(0.1)
async def watch_for_cancel() -> bool:
"""Watch for cancellation events via pub/sub or polling."""
if _task_cache is not None:
# ValkeyCache: use polling (pub/sub not implemented yet)
return await poll_for_cancel_valkey()
# aiocache: use pub/sub if Redis, otherwise poll
cache = caches.get("default")
if isinstance(cache, SimpleMemoryCache):
return await poll_for_cancel_aiocache()
return await poll_for_cancel()
try:
client = cache.client
@@ -198,7 +168,7 @@ def cancellable(
"Cancel watcher Redis error, falling back to polling",
extra={"task_id": task_id, "error": str(e)},
)
return await poll_for_cancel_aiocache()
return await poll_for_cancel()
return False
execute_task = asyncio.create_task(fn(*args, **kwargs))
@@ -220,12 +190,7 @@ def cancellable(
cancel_watch.cancel()
return execute_task.result()
finally:
# Clean up cancellation flag
if _task_cache is not None:
await _task_cache.delete(f"cancel:{task_id}")
else:
cache = caches.get("default")
await cache.delete(f"cancel:{task_id}")
await cache.delete(f"cancel:{task_id}")
return wrapper
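The polling half of the cancellation watcher can be sketched with an in-memory dict standing in for the cache; all names below are hypothetical:

```python
import asyncio


async def poll_for_cancel(cache: dict, task_id: str, interval: float = 0.01) -> bool:
    """Poll the cache for a cancellation flag, as the wrapper above does."""
    while True:
        if cache.get(f"cancel:{task_id}"):
            return True
        await asyncio.sleep(interval)


async def demo() -> bool:
    cache: dict = {}
    poller = asyncio.create_task(poll_for_cancel(cache, "t1"))
    await asyncio.sleep(0.03)      # simulate work before anyone cancels
    cache["cancel:t1"] = True      # another coroutine requests cancellation
    return await poller
```

Polling is the universal fallback; the Redis path in the wrapper above replaces it with pub/sub so cancellation is observed without a sleep loop.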
@@ -510,8 +475,6 @@ async def cancel(
if task_id is None or context_id is None:
raise ServerError(InvalidParamsError(message="task_id and context_id required"))
_ensure_task_cache()
if context.current_task and context.current_task.status.state in (
TaskState.completed,
TaskState.failed,
@@ -519,16 +482,11 @@ async def cancel(
):
return context.current_task
if _task_cache is not None:
# Use ValkeyCache
await _task_cache.set(f"cancel:{task_id}", True, ttl=3600)
# Note: pub/sub not implemented for ValkeyCache yet, relies on polling
else:
# Use aiocache
cache = caches.get("default")
await cache.set(f"cancel:{task_id}", True, ttl=3600)
if not isinstance(cache, SimpleMemoryCache):
await cache.client.publish(f"cancel:{task_id}", "cancel")
cache = caches.get("default")
await cache.set(f"cancel:{task_id}", True, ttl=3600)
if not isinstance(cache, SimpleMemoryCache):
await cache.client.publish(f"cancel:{task_id}", "cancel")
await event_queue.enqueue_event(
TaskStatusUpdateEvent(

View File

@@ -18,7 +18,7 @@ import math
from typing import Any
from uuid import uuid4
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field
from crewai.flow.flow import Flow, listen, start
from crewai.memory.analyze import (
@@ -68,27 +68,6 @@ class ItemState(BaseModel):
plan: ConsolidationPlan | None = None
result_record: MemoryRecord | None = None
@field_validator("similar_records", "result_record", mode="before")
@classmethod
def ensure_embedding_is_list(cls, v: Any) -> Any:
"""Ensure MemoryRecord embeddings are list[float], not bytes.
Delegates to MemoryRecord.validate_embedding for consistent behavior
(e.g. empty bytes → None).
"""
if v is None:
return None
if isinstance(v, list):
for record in v:
if isinstance(record, MemoryRecord) and isinstance(
record.embedding, bytes
):
record.embedding = MemoryRecord.validate_embedding(record.embedding)
return v
if isinstance(v, MemoryRecord) and isinstance(v.embedding, bytes):
v.embedding = MemoryRecord.validate_embedding(v.embedding)
return v
class EncodingState(BaseModel):
"""Batch-level state for the encoding flow."""

View File

@@ -1,198 +0,0 @@
"""Valkey-based cache implementation for CrewAI.
This module provides a simple cache interface using Valkey-GLIDE client
for caching operations with optional TTL support. It replaces Redis usage
in A2A communication, file uploads, and agent card caching.
"""
from __future__ import annotations
import json
import logging
from typing import Any
from glide import GlideClient, GlideClientConfiguration, NodeAddress
_logger = logging.getLogger(__name__)
class ValkeyCache:
"""Simple cache interface using Valkey-GLIDE client.
Provides get/set/delete/exists operations for caching with optional TTL.
Uses JSON serialization for complex values and lazy client initialization.
Example:
>>> cache = ValkeyCache(host="localhost", port=6379)
>>> await cache.set("key", {"data": "value"}, ttl=3600)
>>> value = await cache.get("key")
>>> await cache.delete("key")
"""
def __init__(
self,
host: str = "localhost",
port: int = 6379,
db: int = 0,
password: str | None = None,
default_ttl: int | None = None,
) -> None:
"""Initialize Valkey cache.
Args:
host: Valkey server hostname.
port: Valkey server port.
db: Database number to use.
password: Optional password for authentication.
default_ttl: Default TTL in seconds (None = no expiration).
"""
self._host = host
self._port = port
self._db = db
self._password = password
self._default_ttl = default_ttl
self._client: GlideClient | None = None
async def _get_client(self) -> GlideClient:
"""Get or create Valkey client (lazy initialization).
Returns:
Initialized GlideClient instance.
Raises:
RuntimeError: If connection to Valkey fails.
TimeoutError: If connection attempt times out (10 seconds).
"""
import asyncio
if self._client is None:
host = self._host
port = self._port
db = self._db
try:
from glide import ServerCredentials
config = GlideClientConfiguration(
addresses=[NodeAddress(host, port)],
database_id=db,
credentials=(
ServerCredentials(password=self._password)
if self._password
else None
),
)
# Add connection timeout (10 seconds)
try:
self._client = await asyncio.wait_for(
GlideClient.create(config), timeout=10.0
)
except asyncio.TimeoutError as e:
_logger.error("Connection timeout connecting to Valkey")
raise TimeoutError(
"Connection timeout to Valkey. "
"Ensure Valkey is running and accessible."
) from e
_logger.info("Valkey cache client initialized")
except (TimeoutError, RuntimeError):
raise
except Exception as e:
_logger.error(
"Failed to create Valkey cache client: %s", type(e).__name__
)
raise RuntimeError(
"Cannot connect to Valkey. Check connection settings."
) from e
return self._client
async def get(self, key: str) -> Any | None:
"""Get value from cache.
Args:
key: Cache key.
Returns:
Cached value (deserialized from JSON) or None if not found.
"""
client = await self._get_client()
value = await client.get(key)
if value is None:
return None
try:
return json.loads(value)
except json.JSONDecodeError:
_logger.warning(f"Failed to deserialize cached value for key: {key}")
return None
async def set(
self,
key: str,
value: Any,
ttl: int | None = None,
) -> None:
"""Set value in cache.
Args:
key: Cache key.
value: Value to cache (will be serialized to JSON).
ttl: TTL in seconds (None uses default_ttl, 0 = no expiration).
Raises:
TypeError: If value is not JSON-serializable.
"""
from glide import ExpirySet, ExpiryType
client = await self._get_client()
try:
serialized = json.dumps(value)
except (TypeError, ValueError) as e:
_logger.error("Cannot serialize value for key %r: %s", key, e)
raise TypeError(
f"Value for cache key {key!r} is not JSON-serializable: {e}"
) from e
ttl_to_use = ttl if ttl is not None else self._default_ttl
if ttl_to_use and ttl_to_use > 0:
# Set with expiration using SET command with EX option
await client.set(
key,
serialized,
expiry=ExpirySet(ExpiryType.SEC, ttl_to_use),
)
else:
await client.set(key, serialized)
async def delete(self, key: str) -> None:
"""Delete value from cache.
Args:
key: Cache key to delete.
"""
client = await self._get_client()
await client.delete([key])
async def exists(self, key: str) -> bool:
"""Check if key exists in cache.
Args:
key: Cache key to check.
Returns:
True if key exists, False otherwise.
"""
client = await self._get_client()
result = await client.exists([key])
return result > 0
async def close(self) -> None:
"""Close Valkey client connection."""
if self._client:
await self._client.close()
self._client = None
_logger.debug("Valkey cache client closed")

File diff suppressed because it is too large.

View File

@@ -2,17 +2,13 @@
from __future__ import annotations
import concurrent.futures
from datetime import datetime
import logging
from typing import Any
from uuid import uuid4
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field
_logger = logging.getLogger(__name__)
# When searching the vector store, we ask for more results than the caller
# requested so that post-search steps (composite scoring, deduplication,
# category filtering) have enough candidates to fill the final result set.
@@ -61,23 +57,6 @@ class MemoryRecord(BaseModel):
repr=False,
description="Vector embedding for semantic search. Excluded from serialization to save tokens.",
)
@field_validator("embedding", mode="before")
@classmethod
def validate_embedding(cls, v: Any) -> list[float] | None:
"""Ensure embedding is always list[float] or None, never bytes."""
if v is None:
return None
if isinstance(v, bytes):
# Convert bytes to list[float] if needed
import numpy as np
if len(v) == 0:
return None
arr = np.frombuffer(v, dtype=np.float32)
return [float(x) for x in arr]
return [float(x) for x in v]
source: str | None = Field(
default=None,
description=(
@@ -325,11 +304,7 @@ def embed_text(embedder: Any, text: str) -> list[float]:
"""
if not text or not text.strip():
return []
# Just call the embedder directly - the blocking issue needs to be fixed
# at a higher level (making Memory.recall() async)
result = embedder([text])
if not result:
return []
first = result[0]
@@ -340,27 +315,12 @@ def embed_text(embedder: Any, text: str) -> list[float]:
return list(first)
# Reusable thread pool for running embedder calls from sync context
# when an async event loop is already running. Uses max_workers=2 so
# a single slow/timed-out call doesn't block subsequent embeds.
_EMBED_POOL = concurrent.futures.ThreadPoolExecutor(max_workers=2)
def embed_texts(embedder: Any, texts: list[str]) -> list[list[float]]:
"""Embed multiple texts in a single API call.
The embedder already accepts ``list[str]``, so this just calls it once
with the full batch and normalises the output format.
When called from an async context, offloads the embedder to a thread pool
so the embedding work doesn't run on the event loop thread. The calling
thread still blocks on the result (unavoidable for a sync function), but
this prevents the embedder from starving the event loop's I/O callbacks.
The pool uses ``max_workers=2`` so a single timed-out call doesn't block
subsequent embeds.
Note: the proper long-term fix is making ``Memory.recall()`` async.
Args:
embedder: Callable that accepts a list of strings and returns embeddings.
texts: List of texts to embed.
@@ -368,8 +328,6 @@ def embed_texts(embedder: Any, texts: list[str]) -> list[list[float]]:
Returns:
List of embeddings, one per input text. Empty texts produce empty lists.
"""
import asyncio
if not texts:
return []
# Filter out empty texts, remembering their positions
@@ -379,28 +337,7 @@ def embed_texts(embedder: Any, texts: list[str]) -> list[list[float]]:
if not valid:
return [[] for _ in texts]
texts_to_embed = [t for _, t in valid]
# Check if we're in an async context
result: Any
try:
asyncio.get_running_loop()
# We're in an async context but this is a sync function.
# Offload to thread pool so the embedder doesn't run on the
# event loop thread. The .result() call blocks this thread
# (acceptable — callers like Memory.recall() are sync).
try:
result = _EMBED_POOL.submit(embedder, texts_to_embed).result(timeout=30)
except concurrent.futures.TimeoutError:
_logger.warning(
"Embedder timed out after 30s, returning empty embeddings. "
"The worker thread may still be running."
)
return [[] for _ in texts]
except RuntimeError:
# Not in async context, run directly
result = embedder(texts_to_embed)
result = embedder([t for _, t in valid])
embeddings: list[list[float]] = [[] for _ in texts]
for (orig_idx, _), emb in zip(valid, result, strict=False):
if hasattr(emb, "tolist"):
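The pattern being removed here — detect a running event loop from sync code and offload the blocking embedder call to a small thread pool with a timeout — can be sketched in isolation. `run_blocking` is an illustrative name, not part of the codebase:

```python
import asyncio
import concurrent.futures

# max_workers=2 so one stuck call doesn't block the next one.
_POOL = concurrent.futures.ThreadPoolExecutor(max_workers=2)

def run_blocking(fn, *args, timeout: float = 30.0):
    """Call fn(*args) directly, unless an event loop is running in this
    thread — then submit to the pool so fn doesn't starve the loop."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        return fn(*args)  # no loop: plain synchronous call
    try:
        return _POOL.submit(fn, *args).result(timeout=timeout)
    except concurrent.futures.TimeoutError:
        return None  # caller decides how to degrade

print(run_blocking(lambda xs: [len(x) for x in xs], ["a", "bb"]))  # [1, 2]
```

The calling thread still blocks on `.result()`; as the removed comment notes, the real fix is making the caller async.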

View File

@@ -2,11 +2,9 @@
from __future__ import annotations
import concurrent.futures
from concurrent.futures import Future, ThreadPoolExecutor
import contextvars
from datetime import datetime
import logging
import threading
import time
from typing import TYPE_CHECKING, Annotated, Any, Literal
@@ -38,9 +36,6 @@ from crewai.rag.embeddings.factory import build_embedder
from crewai.rag.embeddings.providers.openai.types import OpenAIProviderSpec
_logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
@@ -216,18 +211,6 @@ class Memory(BaseModel):
from crewai.memory.storage.lancedb_storage import LanceDBStorage
self._storage = LanceDBStorage()
elif self.storage == "valkey":
from crewai.memory.storage.valkey_storage import ValkeyStorage
from crewai.utilities.cache_config import parse_cache_url
conn = parse_cache_url() or {}
self._storage = ValkeyStorage(
host=conn.get("host", "localhost"),
port=conn.get("port", 6379),
db=conn.get("db", 0),
password=conn.get("password"),
use_tls=conn.get("use_tls", False),
)
else:
from crewai.memory.storage.lancedb_storage import LanceDBStorage
@@ -333,60 +316,16 @@ class Memory(BaseModel):
except Exception: # noqa: S110
pass # swallow everything during shutdown
def drain_writes(self, timeout_per_save: float = 60.0) -> None:
def drain_writes(self) -> None:
"""Block until all pending background saves have completed.
Called automatically by ``recall()`` and should be called by the
crew at shutdown to ensure no saves are lost.
Args:
timeout_per_save: Maximum seconds to wait per save operation.
Default 60s. If a save times out, logs warning
but continues to avoid blocking crew completion.
"""
with self._pending_lock:
pending = list(self._pending_saves)
if pending:
_logger.debug(
"[DRAIN_WRITES] Waiting for %d pending saves...", len(pending)
)
failed_saves = 0
for i, future in enumerate(pending):
try:
_logger.debug(
"[DRAIN_WRITES] Waiting for save %d/%d...", i + 1, len(pending)
)
future.result(timeout=timeout_per_save)
_logger.debug(
"[DRAIN_WRITES] Save %d/%d completed", i + 1, len(pending)
)
except (TimeoutError, concurrent.futures.TimeoutError): # noqa: PERF203
failed_saves += 1
_logger.warning(
"[DRAIN_WRITES] Save %d/%d timed out after %ss. "
"This save will be abandoned. Consider increasing timeout or checking "
"LLM/embedder performance.",
i + 1,
len(pending),
timeout_per_save,
)
# Don't raise - just log and continue to avoid blocking crew completion
except Exception as e:
failed_saves += 1
_logger.error(
"[DRAIN_WRITES] Save %d/%d failed: %s", i + 1, len(pending), e
)
# Don't raise - just log and continue
if failed_saves > 0:
_logger.warning(
"[DRAIN_WRITES] %d/%d saves failed or timed out. "
"Some memories may not have been persisted.",
failed_saves,
len(pending),
)
for future in pending:
future.result() # blocks until done; re-raises exceptions
def close(self) -> None:
"""Drain pending saves, flush storage, and shut down the background thread pool."""

View File

@@ -1,78 +0,0 @@
"""Shared cache configuration helpers for Valkey/Redis URL parsing."""
from __future__ import annotations
import logging
import os
from typing import Any
from urllib.parse import urlparse
_logger = logging.getLogger(__name__)
def parse_cache_url() -> dict[str, Any] | None:
"""Parse VALKEY_URL or REDIS_URL from environment.
Priority: VALKEY_URL > REDIS_URL.
Returns:
Dict with host, port, db, password keys, or None if no URL is set.
"""
url = os.environ.get("VALKEY_URL") or os.environ.get("REDIS_URL")
if not url:
return None
parsed = urlparse(url)
return {
"host": parsed.hostname or "localhost",
"port": parsed.port or 6379,
"db": _parse_db_from_path(parsed.path),
"password": parsed.password,
"use_tls": parsed.scheme in ("rediss", "valkeys"),
}
def _parse_db_from_path(path: str | None) -> int:
"""Parse database number from URL path, defaulting to 0."""
if not path or path == "/":
return 0
try:
return int(path.lstrip("/"))
except ValueError:
_logger.warning(
"Invalid database number in URL path: %s, using default 0", path
)
return 0
def get_aiocache_config() -> dict[str, Any]:
"""Build an aiocache configuration dict from environment.
Uses VALKEY_URL or REDIS_URL (both are Redis-wire-compatible) to
configure ``aiocache.RedisCache``. Falls back to
``aiocache.SimpleMemoryCache`` when neither variable is set.
Returns:
Configuration dict suitable for ``aiocache.caches.set_config()``.
"""
conn = parse_cache_url()
if conn is not None:
return {
"default": {
"cache": "aiocache.RedisCache",
"endpoint": conn["host"],
"port": conn["port"],
"db": conn.get("db", 0),
"password": conn.get("password"),
}
}
return {
"default": {
"cache": "aiocache.SimpleMemoryCache",
}
}
def use_valkey_cache() -> bool:
"""Return True if VALKEY_URL is set in the environment."""
return bool(os.environ.get("VALKEY_URL"))
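The URL-parsing behavior of the deleted helper is easy to mirror with `urllib.parse` alone. A standalone sketch (hypothetical `parse_url`, no environment lookup) showing how scheme, path, and credentials map to connection fields:

```python
from urllib.parse import urlparse

def parse_url(url: str) -> dict:
    """Illustrative mirror of the removed parse_cache_url() mapping."""
    p = urlparse(url)
    try:
        db = int((p.path or "/").lstrip("/") or 0)  # "/2" -> 2, "/" -> 0
    except ValueError:
        db = 0  # invalid db segment falls back to 0
    return {
        "host": p.hostname or "localhost",
        "port": p.port or 6379,
        "db": db,
        "password": p.password,
        "use_tls": p.scheme in ("rediss", "valkeys"),  # TLS schemes
    }

print(parse_url("rediss://:s3cret@cache.internal:6380/2"))
```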

View File

@@ -1,511 +0,0 @@
"""Tests for ValkeyCache implementation."""
from __future__ import annotations
import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from crewai.memory.storage.valkey_cache import ValkeyCache
@pytest.fixture
def mock_glide_client() -> AsyncMock:
"""Create a mock GlideClient for testing."""
client = AsyncMock()
client.get = AsyncMock()
client.set = AsyncMock()
client.delete = AsyncMock()
client.exists = AsyncMock()
client.close = AsyncMock()
return client
@pytest.fixture
def valkey_cache(mock_glide_client: AsyncMock) -> ValkeyCache:
"""Create a ValkeyCache instance with mocked client."""
cache = ValkeyCache(host="localhost", port=6379, db=0)
# Mock the client creation to return our mock
async def mock_create_client() -> AsyncMock:
cache._client = mock_glide_client
return mock_glide_client
cache._get_client = mock_create_client # type: ignore[method-assign]
return cache
class TestValkeyCacheBasicOperations:
"""Tests for basic ValkeyCache operations (get/set/delete/exists)."""
@pytest.mark.asyncio
async def test_set_and_get_string_value(
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
) -> None:
"""Test setting and getting a string value."""
# Mock get to return serialized string
mock_glide_client.get.return_value = json.dumps("test_value")
# Set value
await valkey_cache.set("test_key", "test_value")
# Verify set was called
mock_glide_client.set.assert_called_once()
call_args = mock_glide_client.set.call_args
assert call_args[0][0] == "test_key"
assert call_args[0][1] == json.dumps("test_value")
# Get value
result = await valkey_cache.get("test_key")
# Verify get was called and result is correct
mock_glide_client.get.assert_called_once_with("test_key")
assert result == "test_value"
@pytest.mark.asyncio
async def test_set_and_get_dict_value(
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
) -> None:
"""Test setting and getting a dictionary value."""
test_dict = {"key1": "value1", "key2": 42, "key3": [1, 2, 3]}
mock_glide_client.get.return_value = json.dumps(test_dict)
# Set value
await valkey_cache.set("dict_key", test_dict)
# Verify set was called with serialized dict
mock_glide_client.set.assert_called_once()
call_args = mock_glide_client.set.call_args
assert call_args[0][0] == "dict_key"
assert call_args[0][1] == json.dumps(test_dict)
# Get value
result = await valkey_cache.get("dict_key")
# Verify result matches original dict
assert result == test_dict
@pytest.mark.asyncio
async def test_set_and_get_list_value(
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
) -> None:
"""Test setting and getting a list value."""
test_list = [1, "two", 3.0, {"nested": "dict"}]
mock_glide_client.get.return_value = json.dumps(test_list)
await valkey_cache.set("list_key", test_list)
result = await valkey_cache.get("list_key")
assert result == test_list
@pytest.mark.asyncio
async def test_get_nonexistent_key_returns_none(
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
) -> None:
"""Test getting a non-existent key returns None."""
mock_glide_client.get.return_value = None
result = await valkey_cache.get("nonexistent_key")
assert result is None
mock_glide_client.get.assert_called_once_with("nonexistent_key")
@pytest.mark.asyncio
async def test_delete_key(
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
) -> None:
"""Test deleting a key."""
await valkey_cache.delete("test_key")
mock_glide_client.delete.assert_called_once_with(["test_key"])
@pytest.mark.asyncio
async def test_exists_returns_true_for_existing_key(
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
) -> None:
"""Test exists returns True for existing key."""
mock_glide_client.exists.return_value = 1
result = await valkey_cache.exists("existing_key")
assert result is True
mock_glide_client.exists.assert_called_once_with(["existing_key"])
@pytest.mark.asyncio
async def test_exists_returns_false_for_nonexistent_key(
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
) -> None:
"""Test exists returns False for non-existent key."""
mock_glide_client.exists.return_value = 0
result = await valkey_cache.exists("nonexistent_key")
assert result is False
mock_glide_client.exists.assert_called_once_with(["nonexistent_key"])
class TestValkeyCacheTTL:
"""Tests for ValkeyCache TTL functionality."""
@pytest.mark.asyncio
async def test_set_with_explicit_ttl(
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
) -> None:
"""Test setting a value with explicit TTL."""
await valkey_cache.set("ttl_key", "value", ttl=3600)
# Verify set was called with expiry
mock_glide_client.set.assert_called_once()
call_args = mock_glide_client.set.call_args
assert call_args[0][0] == "ttl_key"
assert call_args[0][1] == json.dumps("value")
assert "expiry" in call_args[1]
@pytest.mark.asyncio
async def test_set_with_default_ttl(
self, mock_glide_client: AsyncMock
) -> None:
"""Test setting a value with default TTL from constructor."""
cache = ValkeyCache(host="localhost", port=6379, default_ttl=1800)
async def mock_create_client() -> AsyncMock:
cache._client = mock_glide_client
return mock_glide_client
cache._get_client = mock_create_client # type: ignore[method-assign]
await cache.set("default_ttl_key", "value")
# Verify set was called with default TTL
mock_glide_client.set.assert_called_once()
call_args = mock_glide_client.set.call_args
assert "expiry" in call_args[1]
@pytest.mark.asyncio
async def test_set_without_ttl(
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
) -> None:
"""Test setting a value without TTL (no expiration)."""
await valkey_cache.set("no_ttl_key", "value")
# Verify set was called without expiry
mock_glide_client.set.assert_called_once()
call_args = mock_glide_client.set.call_args
assert call_args[0][0] == "no_ttl_key"
assert call_args[0][1] == json.dumps("value")
# Should not have expiry parameter
assert "expiry" not in call_args[1] or call_args[1].get("expiry") is None
@pytest.mark.asyncio
async def test_set_with_zero_ttl_no_expiration(
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
) -> None:
"""Test setting a value with TTL=0 means no expiration."""
await valkey_cache.set("zero_ttl_key", "value", ttl=0)
# Verify set was called without expiry
mock_glide_client.set.assert_called_once()
call_args = mock_glide_client.set.call_args
assert "expiry" not in call_args[1] or call_args[1].get("expiry") is None
@pytest.mark.asyncio
async def test_explicit_ttl_overrides_default(
self, mock_glide_client: AsyncMock
) -> None:
"""Test explicit TTL overrides default TTL."""
cache = ValkeyCache(host="localhost", port=6379, default_ttl=1800)
async def mock_create_client() -> AsyncMock:
cache._client = mock_glide_client
return mock_glide_client
cache._get_client = mock_create_client # type: ignore[method-assign]
await cache.set("override_key", "value", ttl=7200)
# Verify set was called with explicit TTL (7200), not default (1800)
mock_glide_client.set.assert_called_once()
call_args = mock_glide_client.set.call_args
assert "expiry" in call_args[1]
class TestValkeyCacheJSONSerialization:
"""Tests for ValkeyCache JSON serialization edge cases."""
@pytest.mark.asyncio
async def test_serialize_none_value(
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
) -> None:
"""Test serializing None value."""
mock_glide_client.get.return_value = json.dumps(None)
await valkey_cache.set("none_key", None)
result = await valkey_cache.get("none_key")
assert result is None
@pytest.mark.asyncio
async def test_serialize_boolean_values(
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
) -> None:
"""Test serializing boolean values."""
mock_glide_client.get.side_effect = [
json.dumps(True),
json.dumps(False),
]
await valkey_cache.set("true_key", True)
await valkey_cache.set("false_key", False)
result_true = await valkey_cache.get("true_key")
result_false = await valkey_cache.get("false_key")
assert result_true is True
assert result_false is False
@pytest.mark.asyncio
async def test_serialize_numeric_values(
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
) -> None:
"""Test serializing numeric values (int, float)."""
mock_glide_client.get.side_effect = [
json.dumps(42),
json.dumps(3.14159),
json.dumps(0),
json.dumps(-100),
]
await valkey_cache.set("int_key", 42)
await valkey_cache.set("float_key", 3.14159)
await valkey_cache.set("zero_key", 0)
await valkey_cache.set("negative_key", -100)
assert await valkey_cache.get("int_key") == 42
assert await valkey_cache.get("float_key") == 3.14159
assert await valkey_cache.get("zero_key") == 0
assert await valkey_cache.get("negative_key") == -100
@pytest.mark.asyncio
async def test_serialize_empty_collections(
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
) -> None:
"""Test serializing empty collections."""
mock_glide_client.get.side_effect = [
json.dumps([]),
json.dumps({}),
json.dumps(""),
]
await valkey_cache.set("empty_list", [])
await valkey_cache.set("empty_dict", {})
await valkey_cache.set("empty_string", "")
assert await valkey_cache.get("empty_list") == []
assert await valkey_cache.get("empty_dict") == {}
assert await valkey_cache.get("empty_string") == ""
@pytest.mark.asyncio
async def test_serialize_nested_structures(
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
) -> None:
"""Test serializing deeply nested structures."""
nested_data = {
"level1": {
"level2": {
"level3": [1, 2, {"level4": "deep"}]
}
},
"list": [{"a": 1}, {"b": 2}]
}
mock_glide_client.get.return_value = json.dumps(nested_data)
await valkey_cache.set("nested_key", nested_data)
result = await valkey_cache.get("nested_key")
assert result == nested_data
@pytest.mark.asyncio
async def test_deserialize_invalid_json_returns_none(
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
) -> None:
"""Test deserializing invalid JSON returns None and logs warning."""
mock_glide_client.get.return_value = "invalid json {{"
with patch("crewai.memory.storage.valkey_cache._logger") as mock_logger:
result = await valkey_cache.get("invalid_key")
assert result is None
mock_logger.warning.assert_called_once()
@pytest.mark.asyncio
async def test_serialize_unicode_strings(
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
) -> None:
"""Test serializing unicode strings."""
unicode_data = "Hello 世界 🌍 Привет"
mock_glide_client.get.return_value = json.dumps(unicode_data)
await valkey_cache.set("unicode_key", unicode_data)
result = await valkey_cache.get("unicode_key")
assert result == unicode_data
class TestValkeyCacheConnectionManagement:
"""Tests for ValkeyCache connection management."""
@pytest.mark.asyncio
async def test_lazy_client_initialization(self) -> None:
"""Test client is initialized lazily on first use."""
cache = ValkeyCache(host="localhost", port=6379)
# Client should be None initially
assert cache._client is None
# Mock GlideClient.create
with patch("crewai.memory.storage.valkey_cache.GlideClient") as mock_glide:
mock_client = AsyncMock()
mock_glide.create = AsyncMock(return_value=mock_client)
mock_client.get = AsyncMock(return_value=None)
# First operation should initialize client
await cache.get("test_key")
# Client should now be initialized
assert cache._client is not None
mock_glide.create.assert_called_once()
@pytest.mark.asyncio
async def test_client_reuse_across_operations(
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
) -> None:
"""Test client is reused across multiple operations."""
mock_glide_client.get.return_value = json.dumps("value")
mock_glide_client.exists.return_value = 1
# Perform multiple operations
await valkey_cache.get("key1")
await valkey_cache.set("key2", "value2")
await valkey_cache.exists("key3")
await valkey_cache.delete("key4")
# _get_client should return the same client instance
client1 = await valkey_cache._get_client()
client2 = await valkey_cache._get_client()
assert client1 is client2
@pytest.mark.asyncio
async def test_close_connection(
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
) -> None:
"""Test closing the client connection."""
# Initialize client
await valkey_cache._get_client()
assert valkey_cache._client is not None
# Close connection
await valkey_cache.close()
# Verify close was called and client is None
mock_glide_client.close.assert_called_once()
assert valkey_cache._client is None
@pytest.mark.asyncio
async def test_connection_error_raises_runtime_error(self) -> None:
"""Test connection error raises RuntimeError with descriptive message."""
cache = ValkeyCache(host="invalid-host", port=9999)
with patch("crewai.memory.storage.valkey_cache.GlideClient") as mock_glide:
mock_glide.create = AsyncMock(side_effect=Exception("Connection refused"))
with pytest.raises(RuntimeError) as exc_info:
await cache._get_client()
assert "Cannot connect to Valkey" in str(exc_info.value)
@pytest.mark.asyncio
async def test_authentication_with_password(self) -> None:
"""Test client initialization with password authentication."""
cache = ValkeyCache(
host="localhost",
port=6379,
password="secret_password"
)
with patch("crewai.memory.storage.valkey_cache.GlideClient") as mock_glide:
mock_client = AsyncMock()
mock_glide.create = AsyncMock(return_value=mock_client)
await cache._get_client()
# Verify GlideClient.create was called with credentials
mock_glide.create.assert_called_once()
config = mock_glide.create.call_args[0][0]
assert hasattr(config, "credentials")
class TestValkeyCacheEdgeCases:
"""Tests for ValkeyCache edge cases and error conditions."""
@pytest.mark.asyncio
async def test_set_with_special_characters_in_key(
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
) -> None:
"""Test setting values with special characters in key."""
special_keys = [
"key:with:colons",
"key/with/slashes",
"key-with-dashes",
"key_with_underscores",
"key.with.dots",
]
for key in special_keys:
await valkey_cache.set(key, "value")
mock_glide_client.set.assert_called()
@pytest.mark.asyncio
async def test_large_value_serialization(
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
) -> None:
"""Test serializing large values."""
large_list = list(range(10000))
mock_glide_client.get.return_value = json.dumps(large_list)
await valkey_cache.set("large_key", large_list)
result = await valkey_cache.get("large_key")
assert result == large_list
@pytest.mark.asyncio
async def test_concurrent_operations(
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
) -> None:
"""Test concurrent cache operations."""
import asyncio
mock_glide_client.get.return_value = json.dumps("value")
# Perform concurrent operations
tasks = [
valkey_cache.set(f"key{i}", f"value{i}")
for i in range(10)
]
await asyncio.gather(*tasks)
# Verify all operations completed
assert mock_glide_client.set.call_count == 10
@pytest.mark.asyncio
async def test_set_non_serializable_value_raises_type_error(
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
) -> None:
"""Test that non-JSON-serializable values raise TypeError."""
from datetime import datetime
with pytest.raises(TypeError, match="not JSON-serializable"):
await valkey_cache.set("bad_key", datetime.now())
# Verify set was never called on the client
mock_glide_client.set.assert_not_called()

File diff suppressed because it is too large

View File

@@ -1,267 +0,0 @@
"""Tests for ValkeyStorage error handling."""
from __future__ import annotations
import asyncio
import json
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4
import pytest
from crewai.memory.storage.valkey_storage import ValkeyStorage
from crewai.memory.types import MemoryRecord
@pytest.fixture
def mock_glide_client() -> AsyncMock:
"""Create a mock GlideClient for testing."""
client = AsyncMock()
client.hset = AsyncMock(return_value=1)
client.zrange = AsyncMock(return_value=[])
client.zadd = AsyncMock()
client.sadd = AsyncMock()
client.hgetall = AsyncMock(return_value={})
client.close = AsyncMock()
return client
@pytest.fixture
def valkey_storage(mock_glide_client: AsyncMock) -> ValkeyStorage:
"""Create a ValkeyStorage instance with mocked client."""
storage = ValkeyStorage(host="localhost", port=6379, db=0)
# Mock the client creation to return our mock
async def mock_create_client() -> AsyncMock:
storage._client = mock_glide_client
return mock_glide_client
storage._get_client = mock_create_client # type: ignore[method-assign]
return storage
class TestSerializationErrors:
"""Tests for serialization error handling."""
def test_serialization_error_raises_descriptive_exception(
self, valkey_storage: ValkeyStorage
) -> None:
"""Test that serialization errors raise descriptive ValueError."""
# Create a record with non-serializable metadata
record = MemoryRecord(
id="test-id",
content="test content",
scope="/test",
categories=["test"],
metadata={"bad_key": object()}, # Non-serializable object
importance=0.5,
created_at=datetime.now(),
last_accessed=datetime.now(),
embedding=[0.1, 0.2, 0.3],
)
# Should raise ValueError with descriptive message
with pytest.raises(ValueError, match="Failed to serialize record test-id"):
valkey_storage._record_to_dict(record)
def test_serialization_error_includes_cause(
self, valkey_storage: ValkeyStorage
) -> None:
"""Test that serialization error includes the original exception as cause."""
# Create a mock record that will fail during JSON serialization
# We need to bypass Pydantic validation, so we'll patch json.dumps
record = MemoryRecord(
id="test-id-2",
content="test content",
scope="/test",
categories=["valid"],
metadata={"key": "value"},
importance=0.5,
created_at=datetime.now(),
last_accessed=datetime.now(),
embedding=[0.1, 0.2, 0.3],
)
# Patch json.dumps to raise an error
with patch("json.dumps", side_effect=TypeError("Cannot serialize")):
with pytest.raises(ValueError) as exc_info:
valkey_storage._record_to_dict(record)
# Verify the exception has a cause
assert exc_info.value.__cause__ is not None
assert isinstance(exc_info.value.__cause__, TypeError)
class TestDeserializationErrors:
"""Tests for deserialization error handling."""
def test_deserialization_error_logs_and_returns_none(
self, valkey_storage: ValkeyStorage, caplog: pytest.LogCaptureFixture
) -> None:
"""Test that deserialization errors log error and return None."""
# Create malformed data (missing required fields)
malformed_data = {
"id": "test-id",
"content": "test content",
# Missing scope, categories, metadata, etc.
}
# Should return None and log error
result = valkey_storage._dict_to_record(malformed_data)
assert result is None
assert "Failed to deserialize record test-id" in caplog.text
def test_deserialization_with_invalid_json_categories_uses_tag_fallback(
self, valkey_storage: ValkeyStorage
) -> None:
"""Test that non-JSON categories fall back to TAG (comma-separated) parsing."""
# Create data with non-JSON categories string
data = {
"id": "test-id-json",
"content": "test content",
"scope": "/test",
"categories": "not valid json [", # Not JSON, treated as TAG format
"metadata": "{}",
"importance": "0.5",
"created_at": "2024-01-01T12:00:00",
"last_accessed": "2024-01-01T12:00:00",
"source": "",
"private": "false",
}
result = valkey_storage._dict_to_record(data)
# TAG fallback: comma-split produces the raw string as a single category
assert result is not None
assert result.id == "test-id-json"
assert result.categories == ["not valid json ["]
def test_deserialization_with_invalid_datetime_returns_none(
self, valkey_storage: ValkeyStorage, caplog: pytest.LogCaptureFixture
) -> None:
"""Test that invalid datetime format returns None."""
# Create data with invalid datetime
invalid_data = {
"id": "test-id-datetime",
"content": "test content",
"scope": "/test",
"categories": '["test"]',
"metadata": "{}",
"importance": "0.5",
"created_at": "not a datetime", # Invalid datetime
"last_accessed": "2024-01-01T12:00:00",
"source": "",
"private": "false",
}
result = valkey_storage._dict_to_record(invalid_data)
assert result is None
assert "Failed to deserialize record test-id-datetime" in caplog.text
def test_deserialization_with_invalid_float_returns_none(
self, valkey_storage: ValkeyStorage, caplog: pytest.LogCaptureFixture
) -> None:
"""Test that invalid float importance returns None."""
# Create data with invalid float
invalid_data = {
"id": "test-id-float",
"content": "test content",
"scope": "/test",
"categories": '["test"]',
"metadata": "{}",
"importance": "not a float", # Invalid float
"created_at": "2024-01-01T12:00:00",
"last_accessed": "2024-01-01T12:00:00",
"source": "",
"private": "false",
}
result = valkey_storage._dict_to_record(invalid_data)
assert result is None
assert "Failed to deserialize record test-id-float" in caplog.text
def test_deserialization_with_bytes_keys_uses_tag_fallback(
self, valkey_storage: ValkeyStorage
) -> None:
"""Test that deserialization handles bytes keys with non-JSON categories via TAG fallback."""
# Create data with bytes keys (as returned by Valkey)
bytes_data = {
b"id": b"test-id-bytes",
b"content": b"test content",
b"scope": b"/test",
b"categories": b"invalid json [", # Not JSON, treated as TAG format
b"metadata": b"{}",
b"importance": b"0.5",
b"created_at": b"2024-01-01T12:00:00",
b"last_accessed": b"2024-01-01T12:00:00",
}
result = valkey_storage._dict_to_record(bytes_data)
# TAG fallback: comma-split produces the raw string as a single category
assert result is not None
assert result.id == "test-id-bytes"
assert result.categories == ["invalid json ["]
class TestRetryBehaviorIntegration:
"""Integration tests demonstrating retry behavior patterns."""
@pytest.mark.asyncio
async def test_mock_client_operation_with_retry_pattern(
self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
) -> None:
"""Test demonstrating how retry would work with client operations."""
from glide import ClosingError
# Mock a client operation that fails once
mock_glide_client.hgetall.side_effect = [
ClosingError("Connection lost"),
{
b"id": b"test-id",
b"content": b"test content",
b"scope": b"/test",
b"categories": b'["test"]',
b"metadata": b"{}",
b"importance": b"0.5",
b"created_at": b"2024-01-01T12:00:00",
b"last_accessed": b"2024-01-01T12:00:00",
b"source": b"",
b"private": b"false",
b"embedding": b"",
},
]
# First call fails, second succeeds
with pytest.raises(ClosingError):
await mock_glide_client.hgetall("record:test-id")
# Second call succeeds
result = await mock_glide_client.hgetall("record:test-id")
assert result is not None
@pytest.mark.asyncio
async def test_serialization_error_not_retried(
self, valkey_storage: ValkeyStorage
) -> None:
"""Test that serialization errors are not retried (they're not connection errors)."""
# Create a record with non-serializable data
record = MemoryRecord(
id="test-id",
content="test content",
scope="/test",
categories=["test"],
metadata={"bad": object()},
importance=0.5,
created_at=datetime.now(),
last_accessed=datetime.now(),
embedding=[0.1, 0.2, 0.3],
)
# Serialization error should not be retried
with pytest.raises(ValueError, match="Failed to serialize"):
valkey_storage._record_to_dict(record)
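The serialization tests above assert an exception-chaining pattern: the low-level `TypeError` is wrapped in a descriptive `ValueError` and preserved as `__cause__`. A minimal sketch of that pattern (illustrative `record_to_dict`, not the real `_record_to_dict`):

```python
import json

def record_to_dict(record_id: str, payload: dict) -> str:
    """Wrap the low-level error with record context; keep it as __cause__."""
    try:
        return json.dumps(payload)
    except TypeError as e:
        raise ValueError(f"Failed to serialize record {record_id}: {e}") from e

try:
    record_to_dict("r1", {"bad": object()})
except ValueError as e:
    print(type(e.__cause__).__name__)  # TypeError
```

`raise ... from e` is what lets the tests check both the descriptive message and the original exception type.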

File diff suppressed because it is too large

View File

@@ -1,998 +0,0 @@
"""Tests for ValkeyStorage vector search operation."""
from __future__ import annotations
import json
from datetime import datetime
from unittest.mock import AsyncMock, patch
from uuid import uuid4
import pytest
from crewai.memory.storage.valkey_storage import ValkeyStorage
from crewai.memory.types import MemoryRecord
@pytest.fixture
def mock_glide_client() -> AsyncMock:
"""Create a mock GlideClient for testing."""
client = AsyncMock()
client.hset = AsyncMock(return_value=1)
client.zrange = AsyncMock(return_value=[])
client.zadd = AsyncMock()
client.sadd = AsyncMock()
client.hgetall = AsyncMock(return_value={})
client.close = AsyncMock()
return client
@pytest.fixture
def valkey_storage(mock_glide_client: AsyncMock) -> ValkeyStorage:
"""Create a ValkeyStorage instance with mocked client."""
storage = ValkeyStorage(host="localhost", port=6379, db=0)
# Mock the client creation to return our mock
async def mock_create_client() -> AsyncMock:
storage._client = mock_glide_client
return mock_glide_client
storage._get_client = mock_create_client # type: ignore[method-assign]
return storage
def create_mock_ft_search_response(
records: list[tuple[MemoryRecord, float]]
) -> list[int | dict[str, dict[str, str]]]:
"""Create a mock FT.SEARCH response in native dict format.
Args:
records: List of (MemoryRecord, score) tuples to include in response.
Returns:
Mock FT.SEARCH response in the native format:
[total_count, {doc_key: {field: value, ...}, ...}]
"""
if not records:
return [0]
docs: dict[str, dict[str, str]] = {}
for record, score in records:
doc_key = f"record:{record.id}"
# Build field dict
fields: dict[str, str] = {}
fields["id"] = record.id
fields["content"] = record.content
fields["scope"] = record.scope
fields["categories"] = json.dumps(record.categories)
fields["metadata"] = json.dumps(record.metadata)
fields["importance"] = str(record.importance)
fields["created_at"] = record.created_at.isoformat()
fields["last_accessed"] = record.last_accessed.isoformat()
fields["source"] = record.source or ""
fields["private"] = "true" if record.private else "false"
# Add score (Valkey Search returns cosine distance, not similarity)
# Convert similarity to distance: distance = 2 * (1 - similarity)
distance = 2.0 * (1.0 - score)
fields["score"] = str(distance)
# Add embedding if present
if record.embedding:
fields["embedding"] = json.dumps(record.embedding)
docs[doc_key] = fields
return [len(records), docs]
class TestValkeyStorageVectorSearch:
"""Tests for ValkeyStorage vector search operation."""
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_with_no_filters_returns_all_records(
self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
) -> None:
"""Test search with no filters returns all records."""
# Create test records
record1 = MemoryRecord(
id="record-1",
content="First test record",
scope="/test",
categories=["cat1"],
metadata={"key": "value1"},
importance=0.8,
created_at=datetime(2024, 1, 1, 10, 0, 0),
last_accessed=datetime(2024, 1, 1, 11, 0, 0),
embedding=[0.1, 0.2, 0.3, 0.4],
)
record2 = MemoryRecord(
id="record-2",
content="Second test record",
scope="/test",
categories=["cat2"],
metadata={"key": "value2"},
importance=0.6,
created_at=datetime(2024, 1, 2, 10, 0, 0),
last_accessed=datetime(2024, 1, 2, 11, 0, 0),
embedding=[0.2, 0.3, 0.4, 0.5],
)
# Mock FT.INFO to simulate index exists
mock_ft_list.return_value = [b"memory_index"]
# Mock FT.SEARCH to return both records
mock_ft_search.return_value = create_mock_ft_search_response([
(record1, 0.95),
(record2, 0.85),
])
# Perform search with no filters
query_embedding = [0.1, 0.2, 0.3, 0.4]
results = await valkey_storage.asearch(query_embedding, limit=10)
# Verify ft.search was called
mock_ft_search.assert_called_once()
# Verify query contains only KNN part (no filters)
call_args = mock_ft_search.call_args
query = call_args[0][2] # 3rd positional arg: query string
assert "*=>[KNN 10 @embedding $BLOB AS score]" in query
assert "@scope" not in query
assert "@categories" not in query
# Verify results
assert len(results) == 2
assert results[0][0].id == "record-1"
assert results[0][1] == 0.95
assert results[1][0].id == "record-2"
assert results[1][1] == 0.85
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_with_scope_filter_only(
self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
) -> None:
"""Test search with scope filter only."""
record1 = MemoryRecord(
id="record-1",
content="Record in scope",
scope="/agent/task",
embedding=[0.1, 0.2, 0.3, 0.4],
)
mock_ft_list.return_value = [b"memory_index"]
mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.9)])
query_embedding = [0.1, 0.2, 0.3, 0.4]
results = await valkey_storage.asearch(
query_embedding,
scope_prefix="/agent",
limit=10
)
# Verify query contains scope filter
call_args = mock_ft_search.call_args
query = call_args[0][2]
assert "(@scope:{/agent*})=>[KNN 10 @embedding $BLOB AS score]" in query
# Verify results
assert len(results) == 1
assert results[0][0].id == "record-1"
assert results[0][0].scope == "/agent/task"
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_with_category_filter_only(
self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
) -> None:
"""Test search with category filter only."""
record1 = MemoryRecord(
id="record-1",
content="Record with planning category",
scope="/test",
categories=["planning"],
embedding=[0.1, 0.2, 0.3, 0.4],
)
mock_ft_list.return_value = [b"memory_index"]
mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.88)])
query_embedding = [0.1, 0.2, 0.3, 0.4]
results = await valkey_storage.asearch(
query_embedding,
categories=["planning", "execution"],
limit=10
)
# Verify query contains category filter with OR logic
call_args = mock_ft_search.call_args
query = call_args[0][2]
assert "(@categories:{planning|execution})=>[KNN 10 @embedding $BLOB AS score]" in query
# Verify results
assert len(results) == 1
assert results[0][0].id == "record-1"
assert "planning" in results[0][0].categories
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_with_metadata_filter_only(
self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
) -> None:
"""Test search with metadata filter only."""
record1 = MemoryRecord(
id="record-1",
content="Record with metadata",
scope="/test",
metadata={"agent_id": "agent-1", "priority": "high"},
embedding=[0.1, 0.2, 0.3, 0.4],
)
mock_ft_list.return_value = [b"memory_index"]
mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.92)])
query_embedding = [0.1, 0.2, 0.3, 0.4]
results = await valkey_storage.asearch(
query_embedding,
metadata_filter={"agent_id": "agent-1", "priority": "high"},
limit=10
)
# Verify query contains metadata filters (AND logic)
call_args = mock_ft_search.call_args
query = call_args[0][2]
assert "@agent_id:{agent\\-1}" in query or "@agent_id:{agent-1}" in query
assert "@priority:{high}" in query
assert "=>[KNN 10 @embedding $BLOB AS score]" in query
# Verify results
assert len(results) == 1
assert results[0][0].id == "record-1"
assert results[0][0].metadata["agent_id"] == "agent-1"
assert results[0][0].metadata["priority"] == "high"
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_with_combined_filters(
self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
) -> None:
"""Test search with combined filters (scope + categories + metadata)."""
record1 = MemoryRecord(
id="record-1",
content="Record matching all filters",
scope="/agent/task",
categories=["planning"],
metadata={"agent_id": "agent-1"},
embedding=[0.1, 0.2, 0.3, 0.4],
)
mock_ft_list.return_value = [b"memory_index"]
mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.93)])
query_embedding = [0.1, 0.2, 0.3, 0.4]
results = await valkey_storage.asearch(
query_embedding,
scope_prefix="/agent",
categories=["planning"],
metadata_filter={"agent_id": "agent-1"},
limit=10
)
# Verify query contains all filters
call_args = mock_ft_search.call_args
query = call_args[0][2]
assert "@scope:{/agent*}" in query
assert "@categories:{planning}" in query
assert "@agent_id:{agent\\-1}" in query or "@agent_id:{agent-1}" in query
assert "=>[KNN 10 @embedding $BLOB AS score]" in query
# Verify results
assert len(results) == 1
assert results[0][0].id == "record-1"
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_respects_limit_parameter(
self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
) -> None:
"""Test search respects limit parameter."""
records = [
(
MemoryRecord(
id=f"record-{i}",
content=f"Record {i}",
scope="/test",
embedding=[0.1 * i, 0.2 * i, 0.3 * i, 0.4 * i],
),
0.9 - (i * 0.1)
)
for i in range(1, 6)
]
mock_ft_list.return_value = [b"memory_index"]
mock_ft_search.return_value = create_mock_ft_search_response(records[:3])
query_embedding = [0.1, 0.2, 0.3, 0.4]
results = await valkey_storage.asearch(query_embedding, limit=3)
# Verify KNN limit in query
call_args = mock_ft_search.call_args
query = call_args[0][2]
assert "=>[KNN 3 @embedding $BLOB AS score]" in query
# Verify results respect limit
assert len(results) == 3
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_respects_min_score_parameter(
self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
) -> None:
"""Test search respects min_score parameter."""
record1 = MemoryRecord(
id="record-1",
content="High score record",
scope="/test",
embedding=[0.1, 0.2, 0.3, 0.4],
)
record2 = MemoryRecord(
id="record-2",
content="Medium score record",
scope="/test",
embedding=[0.2, 0.3, 0.4, 0.5],
)
record3 = MemoryRecord(
id="record-3",
content="Low score record",
scope="/test",
embedding=[0.3, 0.4, 0.5, 0.6],
)
mock_ft_list.return_value = [b"memory_index"]
mock_ft_search.return_value = create_mock_ft_search_response([
(record1, 0.95),
(record2, 0.75),
(record3, 0.55),
])
query_embedding = [0.1, 0.2, 0.3, 0.4]
results = await valkey_storage.asearch(
query_embedding,
limit=10,
min_score=0.7
)
# Verify only records with score >= 0.7 are returned
assert len(results) == 2
assert results[0][0].id == "record-1"
assert results[0][1] == 0.95
assert results[1][0].id == "record-2"
assert results[1][1] == 0.75
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_returns_results_ordered_by_descending_score(
self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
) -> None:
"""Test search returns results ordered by descending score."""
record1 = MemoryRecord(
id="record-1",
content="Medium score",
scope="/test",
embedding=[0.1, 0.2, 0.3, 0.4],
)
record2 = MemoryRecord(
id="record-2",
content="Highest score",
scope="/test",
embedding=[0.2, 0.3, 0.4, 0.5],
)
record3 = MemoryRecord(
id="record-3",
content="Lowest score",
scope="/test",
embedding=[0.3, 0.4, 0.5, 0.6],
)
mock_ft_list.return_value = [b"memory_index"]
mock_ft_search.return_value = create_mock_ft_search_response([
(record1, 0.75),
(record2, 0.95),
(record3, 0.55),
])
query_embedding = [0.1, 0.2, 0.3, 0.4]
results = await valkey_storage.asearch(query_embedding, limit=10)
# Verify results are ordered by descending score
assert len(results) == 3
assert results[0][0].id == "record-2"
assert results[0][1] == 0.95
assert results[1][0].id == "record-1"
assert results[1][1] == 0.75
assert results[2][0].id == "record-3"
assert results[2][1] == 0.55
# Verify scores are in descending order
for i in range(len(results) - 1):
assert results[i][1] >= results[i + 1][1]
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_with_empty_results(
self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
) -> None:
"""Test search with no matching results."""
mock_ft_list.return_value = [b"memory_index"]
mock_ft_search.return_value = [0] # Total count = 0
query_embedding = [0.1, 0.2, 0.3, 0.4]
results = await valkey_storage.asearch(query_embedding, limit=10)
# Verify empty results
assert len(results) == 0
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_with_special_characters_in_scope(
self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
) -> None:
"""Test search with special characters in scope prefix."""
record1 = MemoryRecord(
id="record-1",
content="Record with special scope",
scope="/agent:task-1",
embedding=[0.1, 0.2, 0.3, 0.4],
)
mock_ft_list.return_value = [b"memory_index"]
mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.9)])
query_embedding = [0.1, 0.2, 0.3, 0.4]
results = await valkey_storage.asearch(
query_embedding,
scope_prefix="/agent:task",
limit=10
)
# Verify query contains escaped scope
call_args = mock_ft_search.call_args
query = call_args[0][2]
assert "@scope:{/agent\\:task*}" in query or "@scope:{/agent:task*}" in query
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_with_special_characters_in_categories(
self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
) -> None:
"""Test search with special characters in categories."""
record1 = MemoryRecord(
id="record-1",
content="Record with special category",
scope="/test",
categories=["plan:execute"],
embedding=[0.1, 0.2, 0.3, 0.4],
)
mock_ft_list.return_value = [b"memory_index"]
mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.9)])
query_embedding = [0.1, 0.2, 0.3, 0.4]
results = await valkey_storage.asearch(
query_embedding,
categories=["plan:execute"],
limit=10
)
# Verify query contains escaped category
call_args = mock_ft_search.call_args
query = call_args[0][2]
assert "@categories:{plan\\:execute}" in query or "@categories:{plan:execute}" in query
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_with_numeric_metadata_values(
self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
) -> None:
"""Test search with numeric metadata values."""
record1 = MemoryRecord(
id="record-1",
content="Record with numeric metadata",
scope="/test",
metadata={"count": 42, "score": 3.14},
embedding=[0.1, 0.2, 0.3, 0.4],
)
mock_ft_list.return_value = [b"memory_index"]
mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.9)])
query_embedding = [0.1, 0.2, 0.3, 0.4]
results = await valkey_storage.asearch(
query_embedding,
metadata_filter={"count": 42, "score": 3.14},
limit=10
)
# Verify query contains string-converted metadata values
call_args = mock_ft_search.call_args
query = call_args[0][2]
assert "@count:{42}" in query
assert "@score:{3" in query and "14}" in query
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_with_embedding_blob_parameter(
self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
) -> None:
"""Test search passes embedding as BLOB parameter."""
record1 = MemoryRecord(
id="record-1",
content="Test record",
scope="/test",
embedding=[0.1, 0.2, 0.3, 0.4],
)
mock_ft_list.return_value = [b"memory_index"]
mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.9)])
query_embedding = [0.1, 0.2, 0.3, 0.4]
results = await valkey_storage.asearch(query_embedding, limit=10)
# Verify ft.search was called with search options containing BLOB param
call_args = mock_ft_search.call_args
# The 4th positional arg is the FtSearchOptions
search_options = call_args[0][3]
# The options object should have params with BLOB
assert search_options is not None
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_results_sorted_by_score(
self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
) -> None:
"""Test search results are sorted by score (descending) automatically."""
record1 = MemoryRecord(
id="record-1",
content="Test record",
scope="/test",
embedding=[0.1, 0.2, 0.3, 0.4],
)
mock_ft_list.return_value = [b"memory_index"]
mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.9)])
query_embedding = [0.1, 0.2, 0.3, 0.4]
results = await valkey_storage.asearch(query_embedding, limit=10)
# Verify ft.search was called (results are auto-sorted by vector search)
mock_ft_search.assert_called_once()
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_with_return_fields(
self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
) -> None:
"""Test search includes RETURN clause with all record fields."""
record1 = MemoryRecord(
id="record-1",
content="Test record",
scope="/test",
embedding=[0.1, 0.2, 0.3, 0.4],
)
mock_ft_list.return_value = [b"memory_index"]
mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.9)])
query_embedding = [0.1, 0.2, 0.3, 0.4]
results = await valkey_storage.asearch(query_embedding, limit=10)
# Verify ft.search was called with search options containing return fields
call_args = mock_ft_search.call_args
search_options = call_args[0][3]
assert search_options is not None
# The FtSearchOptions should have return_fields set
assert search_options.return_fields is not None
assert len(search_options.return_fields) == 11 # All fields including score
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.VectorFieldAttributesHnsw")
@patch("crewai.memory.storage.valkey_storage.ft.create")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_handles_valkey_search_not_available(
self, mock_ft_list: AsyncMock, mock_ft_create: AsyncMock,
mock_vector_attrs: AsyncMock,
valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
) -> None:
"""Test search raises error when Valkey Search module is not available."""
# Mock FT.INFO to fail (index doesn't exist)
mock_ft_list.return_value = []
# Mock FT.CREATE to fail (Search module not available)
mock_ft_create.side_effect = Exception("ERR unknown command 'ft.create'")
query_embedding = [0.1, 0.2, 0.3, 0.4]
with pytest.raises(RuntimeError, match="Valkey Search module is not available"):
await valkey_storage.asearch(query_embedding, limit=10)
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_handles_ft_search_error(
self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
) -> None:
"""Test search handles FT.SEARCH errors gracefully."""
mock_ft_list.return_value = [b"memory_index"]
mock_ft_search.side_effect = Exception("ERR unknown command 'FT.SEARCH'")
query_embedding = [0.1, 0.2, 0.3, 0.4]
with pytest.raises(RuntimeError, match="Valkey Search module is not available"):
await valkey_storage.asearch(query_embedding, limit=10)
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_handles_malformed_ft_search_response(
self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
) -> None:
"""Test search handles malformed FT.SEARCH response gracefully."""
mock_ft_list.return_value = [b"memory_index"]
mock_ft_search.return_value = None # Malformed response
query_embedding = [0.1, 0.2, 0.3, 0.4]
results = await valkey_storage.asearch(query_embedding, limit=10)
# Verify empty results are returned (graceful handling)
assert len(results) == 0
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_handles_missing_score_field(
self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
) -> None:
"""Test search handles missing score field in results."""
record1 = MemoryRecord(
id="record-1",
content="Test record",
scope="/test",
embedding=[0.1, 0.2, 0.3, 0.4],
)
# Create mock response without score field (dict format)
docs = {
f"record:{record1.id}": {
"id": record1.id,
"content": record1.content,
"scope": record1.scope,
"categories": str(record1.categories),
"metadata": str(record1.metadata),
"importance": str(record1.importance),
"created_at": record1.created_at.isoformat(),
"last_accessed": record1.last_accessed.isoformat(),
"source": record1.source or "",
"private": "false",
# No score field
}
}
mock_ft_list.return_value = [b"memory_index"]
mock_ft_search.return_value = [1, docs]
query_embedding = [0.1, 0.2, 0.3, 0.4]
results = await valkey_storage.asearch(query_embedding, limit=10)
# Verify record is returned with default score of 0.0
assert len(results) == 1
assert results[0][0].id == "record-1"
assert results[0][1] == 0.0
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_filters_out_records_with_deserialization_errors(
self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
) -> None:
"""Test search filters out records that fail deserialization."""
valid_record = MemoryRecord(
id="valid-record",
content="Valid record",
scope="/test",
embedding=[0.1, 0.2, 0.3, 0.4],
)
# Create mock response with one valid and one invalid record (dict format)
docs = {
f"record:{valid_record.id}": {
"id": valid_record.id,
"content": valid_record.content,
"scope": valid_record.scope,
"categories": str(valid_record.categories),
"metadata": str(valid_record.metadata),
"importance": str(valid_record.importance),
"created_at": valid_record.created_at.isoformat(),
"last_accessed": valid_record.last_accessed.isoformat(),
"source": valid_record.source or "",
"private": "false",
"score": "0.1",
},
"record:invalid-record": {
"id": "invalid-record",
# Missing content, scope, and other required fields
"score": "0.2",
},
}
mock_ft_list.return_value = [b"memory_index"]
mock_ft_search.return_value = [2, docs]
query_embedding = [0.1, 0.2, 0.3, 0.4]
results = await valkey_storage.asearch(query_embedding, limit=10)
# Verify only valid record is returned
assert len(results) == 1
assert results[0][0].id == "valid-record"
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_converts_cosine_distance_to_similarity(
self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
) -> None:
"""Test search converts Valkey Search cosine distance to similarity score."""
record1 = MemoryRecord(
id="record-1",
content="Test record",
scope="/test",
embedding=[0.1, 0.2, 0.3, 0.4],
)
# Create mock response with distance score (dict format)
docs = {
f"record:{record1.id}": {
"id": record1.id,
"content": record1.content,
"scope": record1.scope,
"categories": str(record1.categories),
"metadata": str(record1.metadata),
"importance": str(record1.importance),
"created_at": record1.created_at.isoformat(),
"last_accessed": record1.last_accessed.isoformat(),
"source": record1.source or "",
"private": "false",
"score": "0.1", # Distance = 0.1
}
}
mock_ft_list.return_value = [b"memory_index"]
mock_ft_search.return_value = [1, docs]
query_embedding = [0.1, 0.2, 0.3, 0.4]
results = await valkey_storage.asearch(query_embedding, limit=10)
# Verify similarity score is correctly converted
assert len(results) == 1
assert results[0][0].id == "record-1"
# Distance 0.1 -> Similarity = 1 - (0.1 / 2) = 0.95
assert abs(results[0][1] - 0.95) < 0.01
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
def test_search_sync_wrapper(
self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
) -> None:
"""Test that sync search wrapper calls async implementation."""
record1 = MemoryRecord(
id="record-1",
content="Test record",
scope="/test",
embedding=[0.1, 0.2, 0.3, 0.4],
)
mock_ft_list.return_value = [b"memory_index"]
mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.9)])
query_embedding = [0.1, 0.2, 0.3, 0.4]
results = valkey_storage.search(query_embedding, limit=10)
# Verify ft.search was called
assert mock_ft_search.call_count >= 1
# Verify results
assert len(results) == 1
assert results[0][0].id == "record-1"
assert results[0][1] == 0.9
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_with_multiple_categories_uses_or_logic(
self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
) -> None:
"""Test search with multiple categories uses OR logic."""
record1 = MemoryRecord(
id="record-1",
content="Record with one matching category",
scope="/test",
categories=["planning"],
embedding=[0.1, 0.2, 0.3, 0.4],
)
mock_ft_list.return_value = [b"memory_index"]
mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.9)])
query_embedding = [0.1, 0.2, 0.3, 0.4]
results = await valkey_storage.asearch(
query_embedding,
categories=["planning", "execution", "review"],
limit=10
)
# Verify query contains OR logic for categories
call_args = mock_ft_search.call_args
query = call_args[0][2]
assert "@categories:{planning|execution|review}" in query
# Verify record with only one matching category is returned
assert len(results) == 1
assert results[0][0].id == "record-1"
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_with_multiple_metadata_filters_uses_and_logic(
self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
) -> None:
"""Test search with multiple metadata filters uses AND logic."""
record1 = MemoryRecord(
id="record-1",
content="Record matching all metadata",
scope="/test",
metadata={"agent_id": "agent-1", "priority": "high", "status": "active"},
embedding=[0.1, 0.2, 0.3, 0.4],
)
mock_ft_list.return_value = [b"memory_index"]
mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.9)])
query_embedding = [0.1, 0.2, 0.3, 0.4]
results = await valkey_storage.asearch(
query_embedding,
metadata_filter={"agent_id": "agent-1", "priority": "high", "status": "active"},
limit=10
)
# Verify query contains AND logic for metadata
call_args = mock_ft_search.call_args
query = call_args[0][2]
assert "@agent_id:" in query
assert "@priority:" in query
assert "@status:" in query
# Verify record matching all metadata is returned
assert len(results) == 1
assert results[0][0].id == "record-1"
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_with_zero_limit_returns_empty(
self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
) -> None:
"""Test search with limit=0 returns empty results."""
mock_ft_list.return_value = [b"memory_index"]
mock_ft_search.return_value = [0]
query_embedding = [0.1, 0.2, 0.3, 0.4]
results = await valkey_storage.asearch(query_embedding, limit=0)
# Verify empty results
assert len(results) == 0
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_with_min_score_one_filters_all(
self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
) -> None:
"""Test search with min_score=1.0 filters out all non-perfect matches."""
record1 = MemoryRecord(
id="record-1",
content="High score but not perfect",
scope="/test",
embedding=[0.1, 0.2, 0.3, 0.4],
)
mock_ft_list.return_value = [b"memory_index"]
mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.99)])
query_embedding = [0.1, 0.2, 0.3, 0.4]
results = await valkey_storage.asearch(
query_embedding,
limit=10,
min_score=1.0
)
# Verify all results are filtered out
assert len(results) == 0
@pytest.mark.asyncio
@patch("crewai.memory.storage.valkey_storage.ft.search")
@patch("crewai.memory.storage.valkey_storage.ft.list")
async def test_search_with_min_score_zero_returns_all(
self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
) -> None:
"""Test search with min_score=0.0 returns all results."""
record1 = MemoryRecord(
id="record-1",
content="High score",
scope="/test",
embedding=[0.1, 0.2, 0.3, 0.4],
)
record2 = MemoryRecord(
id="record-2",
content="Low score",
scope="/test",
embedding=[0.2, 0.3, 0.4, 0.5],
)
mock_ft_list.return_value = [b"memory_index"]
mock_ft_search.return_value = create_mock_ft_search_response([
(record1, 0.95),
(record2, 0.05),
])
query_embedding = [0.1, 0.2, 0.3, 0.4]
results = await valkey_storage.asearch(
query_embedding,
limit=10,
min_score=0.0
)
# Verify all results are returned
assert len(results) == 2
assert results[0][0].id == "record-1"
assert results[1][0].id == "record-2"


@@ -1,115 +0,0 @@
"""Tests for embedding safety: bytes→float validators and async-safe embed_texts."""
from __future__ import annotations
import asyncio
import concurrent.futures
from unittest.mock import MagicMock
import numpy as np
import pytest
from crewai.memory.types import MemoryRecord, embed_text, embed_texts
class TestMemoryRecordEmbeddingValidator:
"""Tests for MemoryRecord.validate_embedding (bytes→list[float])."""
def test_none_embedding_stays_none(self) -> None:
r = MemoryRecord(content="test", embedding=None)
assert r.embedding is None
def test_list_of_floats_passes_through(self) -> None:
r = MemoryRecord(content="test", embedding=[0.1, 0.2, 0.3])
assert r.embedding == [0.1, 0.2, 0.3]
def test_bytes_converted_to_list_float(self) -> None:
arr = np.array([0.1, 0.2, 0.3], dtype=np.float32)
raw_bytes = arr.tobytes()
r = MemoryRecord(content="test", embedding=raw_bytes)
assert r.embedding is not None
assert len(r.embedding) == 3
assert all(isinstance(x, float) for x in r.embedding)
np.testing.assert_allclose(r.embedding, [0.1, 0.2, 0.3], atol=1e-6)
def test_empty_bytes_becomes_none(self) -> None:
r = MemoryRecord(content="test", embedding=b"")
assert r.embedding is None
def test_list_of_ints_converted_to_floats(self) -> None:
r = MemoryRecord(content="test", embedding=[1, 2, 3])
assert r.embedding == [1.0, 2.0, 3.0]
assert all(isinstance(x, float) for x in r.embedding)
def test_numpy_array_converted_to_list(self) -> None:
arr = np.array([0.5, 0.6], dtype=np.float32)
r = MemoryRecord(content="test", embedding=arr)
assert r.embedding is not None
assert isinstance(r.embedding, list)
assert len(r.embedding) == 2
class TestEmbedTextsAsyncSafety:
"""Tests for embed_texts running safely in async context."""
def test_embed_texts_sync_context(self) -> None:
"""embed_texts works in a normal sync context."""
embedder = MagicMock(return_value=[[0.1, 0.2], [0.3, 0.4]])
result = embed_texts(embedder, ["hello", "world"])
assert len(result) == 2
assert result[0] == [0.1, 0.2]
embedder.assert_called_once()
def test_embed_texts_empty_input(self) -> None:
embedder = MagicMock()
assert embed_texts(embedder, []) == []
embedder.assert_not_called()
def test_embed_texts_all_empty_strings(self) -> None:
embedder = MagicMock()
result = embed_texts(embedder, ["", " ", ""])
assert result == [[], [], []]
embedder.assert_not_called()
def test_embed_texts_skips_empty_preserves_positions(self) -> None:
embedder = MagicMock(return_value=[[0.1, 0.2]])
result = embed_texts(embedder, ["", "hello", ""])
assert result == [[], [0.1, 0.2], []]
embedder.assert_called_once_with(["hello"])
def test_embed_texts_in_async_context(self) -> None:
"""embed_texts uses thread pool when called from async context."""
embedder = MagicMock(return_value=[[0.1, 0.2]])
async def run() -> list[list[float]]:
return embed_texts(embedder, ["hello"])
result = asyncio.run(run())
assert result == [[0.1, 0.2]]
embedder.assert_called_once()
class TestEmbedText:
"""Tests for embed_text (single text)."""
def test_empty_string_returns_empty(self) -> None:
embedder = MagicMock()
assert embed_text(embedder, "") == []
embedder.assert_not_called()
def test_whitespace_only_returns_empty(self) -> None:
embedder = MagicMock()
assert embed_text(embedder, " ") == []
embedder.assert_not_called()
def test_normal_text_returns_embedding(self) -> None:
embedder = MagicMock(return_value=[[0.1, 0.2, 0.3]])
result = embed_text(embedder, "hello")
assert result == [0.1, 0.2, 0.3]
def test_numpy_array_result_converted(self) -> None:
arr = np.array([0.1, 0.2], dtype=np.float32)
embedder = MagicMock(return_value=[arr])
result = embed_text(embedder, "hello")
assert isinstance(result, list)
assert len(result) == 2


@@ -1,125 +0,0 @@
"""Tests for shared cache configuration helpers."""
from __future__ import annotations
import os
from unittest.mock import patch
import pytest
from crewai.utilities.cache_config import (
get_aiocache_config,
parse_cache_url,
use_valkey_cache,
)
class TestParseCacheUrl:
"""Tests for parse_cache_url()."""
def test_returns_none_when_no_env_vars(self) -> None:
with patch.dict(os.environ, {}, clear=True):
assert parse_cache_url() is None
def test_parses_valkey_url(self) -> None:
with patch.dict(
os.environ, {"VALKEY_URL": "redis://myhost:6380/2"}, clear=True
):
result = parse_cache_url()
assert result is not None
assert result["host"] == "myhost"
assert result["port"] == 6380
assert result["db"] == 2
assert result["password"] is None
def test_parses_redis_url(self) -> None:
with patch.dict(
os.environ, {"REDIS_URL": "redis://localhost:6379/0"}, clear=True
):
result = parse_cache_url()
assert result is not None
assert result["host"] == "localhost"
assert result["port"] == 6379
assert result["db"] == 0
def test_valkey_url_takes_priority_over_redis_url(self) -> None:
with patch.dict(
os.environ,
{
"VALKEY_URL": "redis://valkey-host:6380/1",
"REDIS_URL": "redis://redis-host:6379/0",
},
clear=True,
):
result = parse_cache_url()
assert result is not None
assert result["host"] == "valkey-host"
assert result["port"] == 6380
def test_parses_password(self) -> None:
with patch.dict(
os.environ,
{"VALKEY_URL": "redis://:s3cret@myhost:6379/0"},
clear=True,
):
result = parse_cache_url()
assert result is not None
assert result["password"] == "s3cret"
def test_defaults_for_minimal_url(self) -> None:
with patch.dict(
os.environ, {"VALKEY_URL": "redis://myhost"}, clear=True
):
result = parse_cache_url()
assert result is not None
assert result["host"] == "myhost"
assert result["port"] == 6379
assert result["db"] == 0
assert result["password"] is None
def test_non_numeric_db_path_defaults_to_zero(self) -> None:
with patch.dict(
os.environ, {"VALKEY_URL": "redis://myhost:6379/mydb"}, clear=True
):
result = parse_cache_url()
assert result is not None
assert result["db"] == 0
class TestGetAiocacheConfig:
"""Tests for get_aiocache_config()."""
def test_returns_memory_cache_when_no_url(self) -> None:
with patch.dict(os.environ, {}, clear=True):
config = get_aiocache_config()
assert config["default"]["cache"] == "aiocache.SimpleMemoryCache"
def test_returns_redis_cache_when_url_set(self) -> None:
with patch.dict(
os.environ, {"VALKEY_URL": "redis://myhost:6380/2"}, clear=True
):
config = get_aiocache_config()
assert config["default"]["cache"] == "aiocache.RedisCache"
assert config["default"]["endpoint"] == "myhost"
assert config["default"]["port"] == 6380
assert config["default"]["db"] == 2
class TestUseValkeyCache:
"""Tests for use_valkey_cache()."""
def test_returns_false_when_not_set(self) -> None:
with patch.dict(os.environ, {}, clear=True):
assert use_valkey_cache() is False
def test_returns_true_when_set(self) -> None:
with patch.dict(
os.environ, {"VALKEY_URL": "redis://localhost:6379"}, clear=True
):
assert use_valkey_cache() is True
def test_returns_false_when_only_redis_url_set(self) -> None:
with patch.dict(
os.environ, {"REDIS_URL": "redis://localhost:6379"}, clear=True
):
assert use_valkey_cache() is False


@@ -205,8 +205,6 @@ override-dependencies = [
"gitpython>=3.1.50,<4",
"langsmith>=0.7.31,<0.8",
"authlib>=1.6.11",
# scrapegraph-py 2.x removed Client class; pin until upstream fixes type ignores
"scrapegraph-py>=1.46.0,<2",
]
[tool.uv.workspace]

uv.lock generated

@@ -13,7 +13,7 @@ resolution-markers = [
]
[options]
exclude-newer = "2026-05-08T20:07:25.621408Z"
exclude-newer = "2026-05-08T16:33:02.834109Z"
exclude-newer-span = "P3D"
[manifest]
@@ -38,7 +38,6 @@ overrides = [
{ name = "pypdf", specifier = ">=6.10.2,<7" },
{ name = "python-multipart", specifier = ">=0.0.27,<1" },
{ name = "rich", specifier = ">=13.7.1" },
{ name = "scrapegraph-py", specifier = ">=1.46.0,<2" },
{ name = "transformers", marker = "python_full_version >= '3.10'", specifier = ">=5.4.0" },
{ name = "urllib3", specifier = ">=2.7.0" },
{ name = "uv", specifier = ">=0.11.6,<1" },
@@ -1366,9 +1365,6 @@ qdrant-edge = [
tools = [
{ name = "crewai-tools" },
]
valkey = [
{ name = "valkey-glide" },
]
voyageai = [
{ name = "voyageai" },
]
@@ -1430,10 +1426,9 @@ requires-dist = [
{ name = "tokenizers", specifier = ">=0.21,<1" },
{ name = "tomli", specifier = "~=2.0.2" },
{ name = "tomli-w", specifier = "~=1.1.0" },
{ name = "valkey-glide", marker = "extra == 'valkey'", specifier = ">=1.3.0" },
{ name = "voyageai", marker = "extra == 'voyageai'", specifier = "~=0.3.5" },
]
provides-extras = ["a2a", "anthropic", "aws", "azure-ai-inference", "bedrock", "docling", "embeddings", "file-processing", "google-genai", "litellm", "mem0", "openpyxl", "pandas", "qdrant", "qdrant-edge", "tools", "valkey", "voyageai", "watson"]
provides-extras = ["a2a", "anthropic", "aws", "azure-ai-inference", "bedrock", "docling", "embeddings", "file-processing", "google-genai", "litellm", "mem0", "openpyxl", "pandas", "qdrant", "qdrant-edge", "tools", "voyageai", "watson"]
[[package]]
name = "crewai-cli"
@@ -9538,43 +9533,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/fa/6e/3e955517e22cbdd565f2f8b2e73d52528b14b8bcfdb04f62466b071de847/validators-0.35.0-py3-none-any.whl", hash = "sha256:e8c947097eae7892cb3d26868d637f79f47b4a0554bc6b80065dfe5aac3705dd", size = 44712, upload-time = "2025-05-01T05:42:04.203Z" },
]
[[package]]
name = "valkey-glide"
version = "2.0.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio" },
{ name = "protobuf" },
{ name = "typing-extensions", marker = "python_full_version < '3.11'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/32/35/fb0401c4bc7be748d937e95213786d21d9e56767b3ad816db5bad6f92c01/valkey_glide-2.0.1.tar.gz", hash = "sha256:4f9c62a88aedffd725cced7d28a9488b27e3f675d1a5294b4962624e97d346c4", size = 1026255, upload-time = "2025-06-20T01:08:15.861Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/44/a3/bf5ff3841538d0bb337371e073dc2c0e93f748f7f8b10a44806f36ab5fa1/valkey_glide-2.0.1-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:b3307934b76557b18ac559f327592cc09fc895fc653ba46010dd6d70fb6239dc", size = 5074638, upload-time = "2025-06-20T01:07:30.16Z" },
{ url = "https://files.pythonhosted.org/packages/0f/c4/20b66dced96bdca81aa294b39bc03018ed22628c52076752e8d1d3540a7d/valkey_glide-2.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6b83d34e2e723e97c41682479b0dce5882069066e808316292b363855992b449", size = 4750261, upload-time = "2025-06-20T01:07:32.452Z" },
{ url = "https://files.pythonhosted.org/packages/53/58/6440e66bde8963d86bc3c44d88f993059f2a9d7ebdb3256a695d035cff50/valkey_glide-2.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1baaf14d09d464ae645be5bdb5dc6b8a38b7eacf22f9dcb2907200c74fbdcdd3", size = 4767755, upload-time = "2025-06-20T01:07:33.86Z" },
{ url = "https://files.pythonhosted.org/packages/3b/69/dd5c350ce4d2cadde0d83beb601f05e1e62622895f268135e252e8bfc307/valkey_glide-2.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4427e7b4d54c9de289a35032c19d5956f94376f5d4335206c5ac4524cbd1c64a", size = 5094507, upload-time = "2025-06-20T01:07:35.349Z" },
{ url = "https://files.pythonhosted.org/packages/b5/dd/0dd6614e09123a5bd7273bf1159c958d1ea65e7decc2190b225d212e0cb9/valkey_glide-2.0.1-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:6379582d6fbd817697fb119274e37d397db450103cd15d4bd71e555e6d88fb6b", size = 5072939, upload-time = "2025-06-20T01:07:36.948Z" },
{ url = "https://files.pythonhosted.org/packages/c6/04/986188e407231a5f0bfaf31f31b68e3605ab66f4f4c656adfbb0345669d9/valkey_glide-2.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:0f1c0fe003026d8ae172369e0eb2337cbff16f41d4c085332487d6ca2e5282e6", size = 4750491, upload-time = "2025-06-20T01:07:38.659Z" },
{ url = "https://files.pythonhosted.org/packages/ac/fb/2f5cec71ae51c464502a892b6825426cd74a2c325827981726e557926c94/valkey_glide-2.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82c5f33598e50bcfec6fc924864931f3c6e30cd327a9c9562e1c7ac4e17e79fd", size = 4767597, upload-time = "2025-06-20T01:07:40.091Z" },
{ url = "https://files.pythonhosted.org/packages/3a/31/851a1a734fe5da5d520106fcfd824e4da09c3be8a0a2123bb4b1980db1ea/valkey_glide-2.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:79039a9dc23bb074680f171c12b36b3322357a0af85125534993e81a619dce21", size = 5094383, upload-time = "2025-06-20T01:07:41.329Z" },
{ url = "https://files.pythonhosted.org/packages/fc/6d/1e7b432cbc02fe63e7496b984b7fc830fb7de388c877b237e0579a6300fc/valkey_glide-2.0.1-cp312-cp312-macosx_10_7_x86_64.whl", hash = "sha256:f55ec8968b0fde364a5b3399be34b89dcb9068994b5cd384e20db0773ad12723", size = 5075024, upload-time = "2025-06-20T01:07:42.917Z" },
{ url = "https://files.pythonhosted.org/packages/ca/39/6e9f83970590d17d19f596e1b3a366d39077624888e3dd709309efc67690/valkey_glide-2.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:21598f49313912ad27dc700d7b13a3b4bfed7ed9dffad207235cac7d218f4966", size = 4748418, upload-time = "2025-06-20T01:07:44.64Z" },
{ url = "https://files.pythonhosted.org/packages/98/0e/91335c13dc8e7ceb95063234c16010b46e2dd874a2edef62dea155081647/valkey_glide-2.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f662285146529328e2b5a0a7047f699339b4e0d250eb1f252b15c9befa0dea05", size = 4767264, upload-time = "2025-06-20T01:07:46.185Z" },
{ url = "https://files.pythonhosted.org/packages/5f/94/ee4d9d441f83fec1464d9f4e52f7940bdd2aeb917589e6abd57498880876/valkey_glide-2.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3939aaa8411fcbba00cb1ff7c7ba73f388bb1deca919972f65cba7eda1d5fa95", size = 5093543, upload-time = "2025-06-20T01:07:47.345Z" },
{ url = "https://files.pythonhosted.org/packages/ed/7e/257a2e4b61ac29d5923f89bad5fe62be7b4a19e7bec78d191af3ce77aa39/valkey_glide-2.0.1-cp313-cp313-macosx_10_7_x86_64.whl", hash = "sha256:c49b53011a05b5820d0c660ee5c76574183b413a54faa33cf5c01ce77164d9c8", size = 5073114, upload-time = "2025-06-20T01:07:48.885Z" },
{ url = "https://files.pythonhosted.org/packages/20/14/a8a470679953980af7eac3ccb09638f2a76d4547116d48cbc69ae6f25080/valkey_glide-2.0.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:3a23572b83877537916ba36ad0a6b2fd96581534f0bc67ef8f8498bf4dbb2b40", size = 4747717, upload-time = "2025-06-20T01:07:50.092Z" },
{ url = "https://files.pythonhosted.org/packages/9f/49/f168dd0c778d9f6ff1be70d5d3bad7a86928fee563de7de5f4f575eddfd8/valkey_glide-2.0.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:943a2c4a5c38b8a6b53281201d5a4997ec454a6fdda72d27050eeb6aaef12afb", size = 4767128, upload-time = "2025-06-20T01:07:51.306Z" },
{ url = "https://files.pythonhosted.org/packages/43/be/68961b14ea133d1792ce50f6df1753848b5377c3e06a8dbe4e39188a549a/valkey_glide-2.0.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d770ec581acc59d5597e7ccaac37aee7e3b5e716a77a7fa44e2967db3a715f53", size = 5093522, upload-time = "2025-06-20T01:07:52.546Z" },
{ url = "https://files.pythonhosted.org/packages/51/2e/ad8595ffe84317385d52ceab8de1e9ef06a4da6b81ca8cd61b7961923de4/valkey_glide-2.0.1-pp310-pypy310_pp73-macosx_10_7_x86_64.whl", hash = "sha256:d4a9ccfe2b190c90622849dab62f9468acf76a282719a1245d272b649e7c12d1", size = 5074539, upload-time = "2025-06-20T01:07:59.87Z" },
{ url = "https://files.pythonhosted.org/packages/db/e5/2122541c7a64706f3631655209bb0b13723fb99db3c190d9a792b4e7d494/valkey_glide-2.0.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:9aa004077b82f64b23ea0d38d948b5116c23f7228dae3a5b4fcfa1799f8ff7de", size = 4753222, upload-time = "2025-06-20T01:08:01.376Z" },
{ url = "https://files.pythonhosted.org/packages/6c/13/cd9a20988a820ff61b127d3f850887b28bb734daf2c26d512d8e4c2e8e9e/valkey_glide-2.0.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:631a7a0e2045f7e5e3706e1903beeddf381a6529e318c27230798f4382579e4f", size = 4771530, upload-time = "2025-06-20T01:08:02.6Z" },
{ url = "https://files.pythonhosted.org/packages/c7/fc/047e89cc01b4cc71db1b6b8160d3b5d050097b408028022c002351238641/valkey_glide-2.0.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c5ed905fb62368c9bc6aef9df8d66269ef51f968dc527da4d7c956927382c1d", size = 5091242, upload-time = "2025-06-20T01:08:04.111Z" },
{ url = "https://files.pythonhosted.org/packages/1c/9e/68790c1a263f3a0094d67d0109be34631f6f79c2fbce5ced7e33a65ad363/valkey_glide-2.0.1-pp311-pypy311_pp73-macosx_10_7_x86_64.whl", hash = "sha256:53da3cc47c8d946ac76ecc4b468a469d3486778833a59162ea69aa7ce70cbb27", size = 5072793, upload-time = "2025-06-20T01:08:05.562Z" },
{ url = "https://files.pythonhosted.org/packages/1f/ae/a935af65ae4069d76c69f28f6bfb4533da8b89f7fc418beb7a1482cdd9ee/valkey_glide-2.0.1-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:e526a7d718cdd299d6b03091c12dcc15cd02ff22fe420f253341a4891c50824d", size = 4753435, upload-time = "2025-06-20T01:08:07.149Z" },
{ url = "https://files.pythonhosted.org/packages/3b/c2/c91d753a89dd87dce2fc8932cfbe174c7a1226c657b3cd64c063f21d4fe6/valkey_glide-2.0.1-pp311-pypy311_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8d3345ea2adf6f745733fa5157d8709bcf5ffbb2674391aeebd8f166a37cbc96", size = 4771401, upload-time = "2025-06-20T01:08:08.359Z" },
{ url = "https://files.pythonhosted.org/packages/00/fe/ad83cfc2ac87bf6bad2b75fa64fca5a6dd54568c1de551d36d369e07f948/valkey_glide-2.0.1-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d1c5fff0f12d2aa4277ddc335035b2c8e12bb11243c1a0f3c35071f4a8b11064", size = 5091360, upload-time = "2025-06-20T01:08:09.622Z" },
]
[[package]]
name = "vcrpy"
version = "7.0.0"