mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-13 21:18:10 +00:00
Compare commits — 1 commit: f3843432cd (branch devin/valk, "docs/custo...")

docs/docs.json — 2445 lines changed (file diff suppressed because it is too large)

docs/en/guides/tools/platform-tools-cli.mdx — 139 lines, new file
@@ -0,0 +1,139 @@
---
title: Platform Tools CLI
description: Create, publish, and install custom tools on the CrewAI platform using the CLI.
icon: terminal
mode: "wide"
---

## Overview

The CrewAI CLI provides commands to manage custom tools on the **CrewAI platform** — a hosted tool registry that lets you share tools within your organization and across the community without publishing to PyPI.

| Command | Purpose |
|---------|---------|
| `crewai tool create <handle>` | Scaffold a new tool project |
| `crewai tool publish` | Publish the tool to the CrewAI platform |
| `crewai tool install <handle>` | Install a platform tool into your crew project |

<Note type="info" title="Platform vs PyPI">
These commands manage tools on the **CrewAI platform registry**. If you want to publish a standalone Python package to PyPI instead, see the [Publish Custom Tools to PyPI](/en/guides/tools/publish-custom-tools) guide.
</Note>

## Prerequisites

- **CrewAI CLI** installed (`pip install crewai`)
- **Authenticated** with the platform — run `crewai login` first

---

## Step 1: Create a Tool Project

Scaffold a new tool project:

```bash
crewai tool create my_custom_tool
```

This generates a project structure with the boilerplate you need to start building your tool.

<Tip>
The `handle` is the unique identifier for your tool on the platform. Choose something descriptive and specific to what the tool does.
</Tip>

### Implement Your Tool

Edit the generated tool file to add your logic. The tool follows the standard CrewAI tools contract — you can subclass `BaseTool` or use the `@tool` decorator:

```python
from crewai.tools import BaseTool


class MyCustomTool(BaseTool):
    name: str = "My Custom Tool"
    description: str = "Description of what this tool does — be specific so agents know when to use it."

    def _run(self, argument: str) -> str:
        # Your tool logic here
        return "result"
```

For the full tools API reference (input schemas, caching, async support, error handling), see the [Create Custom Tools](/en/learn/create-custom-tools) guide.

---

## Step 2: Publish to the Platform

From your tool project directory, publish it to the CrewAI platform:

```bash
crewai tool publish
```

### Visibility Options

| Flag | Description |
|------|-------------|
| `--public` | Make the tool available to all platform users |
| `--private` | Restrict visibility to your organization |
| `--force` | Bypass Git remote validations |

```bash
# Publish as a public tool
crewai tool publish --public

# Publish privately (organization only)
crewai tool publish --private
```

---

## Step 3: Install a Platform Tool

To install a tool that's been published to the platform:

```bash
crewai tool install my_custom_tool
```

Once installed, you can use the tool in your crew like any other tool — assign it to an agent via the `tools` parameter.

---

## Full Lifecycle Example

```bash
# 1. Authenticate with the platform
crewai login

# 2. Scaffold a new tool
crewai tool create weather_lookup

# 3. Implement your logic in the generated project
cd weather_lookup
# ... edit the tool file ...

# 4. Publish to the platform
crewai tool publish --public

# 5. In another project, install and use it
crewai tool install weather_lookup
```

---

## Platform Tools vs PyPI Packages

| | Platform Tools | PyPI Packages |
|---|---|---|
| **Publish** | `crewai tool publish` | `uv build` + `uv publish` |
| **Registry** | CrewAI platform | PyPI |
| **Install** | `crewai tool install <handle>` | `pip install <package>` |
| **Auth** | `crewai login` | PyPI account + token |
| **Visibility** | `--public` / `--private` flags | Always public |
| **Guide** | This page | [Publish Custom Tools](/en/guides/tools/publish-custom-tools) |

---

## Related

- [Create Custom Tools](/en/learn/create-custom-tools) — Python API reference for building tools (BaseTool, @tool decorator)
- [Publish Custom Tools to PyPI](/en/guides/tools/publish-custom-tools) — package and distribute tools as standalone Python libraries

@@ -12,7 +12,9 @@ incorporating the latest functionalities such as tool delegation, error handling
 enabling agents to perform a wide range of actions.

 <Tip>
-**Want to publish your tool for the community?** If you're building a tool that others could benefit from, check out the [Publish Custom Tools](/en/guides/tools/publish-custom-tools) guide to learn how to package and distribute your tool on PyPI.
+**Want to publish your tool to the CrewAI platform?** Use the CLI to scaffold, publish, and share tools directly on the platform — see the [Platform Tools CLI](/en/guides/tools/platform-tools-cli) guide.
+
+**Prefer publishing to PyPI?** Check out the [Publish Custom Tools](/en/guides/tools/publish-custom-tools) guide to package and distribute your tool as a standalone Python library.
 </Tip>

 ### Subclassing `BaseTool`

@@ -1,4 +1,4 @@
"""Cache for tracking uploaded files using aiocache or ValkeyCache."""
"""Cache for tracking uploaded files using aiocache."""

from __future__ import annotations

@@ -10,11 +10,10 @@ from dataclasses import dataclass
from datetime import datetime, timezone
import hashlib
import logging
from typing import TYPE_CHECKING, Any, Protocol
from typing import TYPE_CHECKING, Any

from aiocache import Cache  # type: ignore[import-untyped]
from aiocache.serializers import PickleSerializer  # type: ignore[import-untyped]
from crewai.utilities.cache_config import parse_cache_url

from crewai_files.core.constants import DEFAULT_MAX_CACHE_ENTRIES, DEFAULT_TTL_SECONDS
from crewai_files.uploaders.factory import ProviderType
@@ -52,33 +51,6 @@ class CachedUpload:
            return False
        return datetime.now(timezone.utc) >= self.expires_at

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a JSON-compatible dict."""
        return {
            "file_id": self.file_id,
            "provider": self.provider,
            "file_uri": self.file_uri,
            "content_type": self.content_type,
            "uploaded_at": self.uploaded_at.isoformat(),
            "expires_at": self.expires_at.isoformat() if self.expires_at else None,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> CachedUpload:
        """Deserialize from a dict."""
        return cls(
            file_id=data["file_id"],
            provider=data["provider"],
            file_uri=data.get("file_uri"),
            content_type=data["content_type"],
            uploaded_at=datetime.fromisoformat(data["uploaded_at"]),
            expires_at=(
                datetime.fromisoformat(data["expires_at"])
                if data.get("expires_at")
                else None
            ),
        )


def _make_key(file_hash: str, provider: str) -> str:
    """Create a cache key from file hash and provider."""
@@ -86,7 +58,14 @@ def _make_key(file_hash: str, provider: str) -> str:


def _compute_file_hash_streaming(chunks: Iterator[bytes]) -> str:
    """Compute SHA-256 hash from streaming chunks."""
    """Compute SHA-256 hash from streaming chunks.

    Args:
        chunks: Iterator of byte chunks.

    Returns:
        Hexadecimal hash string.
    """
    hasher = hashlib.sha256()
    for chunk in chunks:
        hasher.update(chunk)
@@ -94,7 +73,10 @@ def _compute_file_hash_streaming(chunks: Iterator[bytes]) -> str:


def _compute_file_hash(file: FileInput) -> str:
    """Compute SHA-256 hash of file content."""
    """Compute SHA-256 hash of file content.

    Uses streaming for FilePath sources to avoid loading large files into memory.
    """
    from crewai_files.core.sources import FilePath

    source = file._file_source
@@ -104,73 +86,10 @@
    return hashlib.sha256(content).hexdigest()

class CacheBackend(Protocol):
    """Protocol for cache backends used by UploadCache."""

    async def get(self, key: str) -> CachedUpload | None: ...
    async def set(self, key: str, value: CachedUpload, ttl: int) -> None: ...
    async def delete(self, key: str) -> bool: ...


class AiocacheBackend:
    """Cache backend backed by aiocache (memory or Redis)."""

    def __init__(self, cache: Cache) -> None:  # type: ignore[no-any-unimported]
        self._cache = cache

    async def get(self, key: str) -> CachedUpload | None:
        result = await self._cache.get(key)
        if isinstance(result, CachedUpload):
            return result
        return None

    async def set(self, key: str, value: CachedUpload, ttl: int) -> None:
        await self._cache.set(key, value, ttl=ttl)

    async def delete(self, key: str) -> bool:
        result = await self._cache.delete(key)
        return bool(result > 0 if isinstance(result, int) else result)


class ValkeyCacheBackend:
    """Cache backend backed by ValkeyCache (JSON serialization)."""

    def __init__(
        self,
        host: str = "localhost",
        port: int = 6379,
        db: int = 0,
        password: str | None = None,
        default_ttl: int | None = None,
    ) -> None:
        from crewai.memory.storage.valkey_cache import ValkeyCache

        self._cache = ValkeyCache(
            host=host, port=port, db=db, password=password, default_ttl=default_ttl
        )

    async def get(self, key: str) -> CachedUpload | None:
        data = await self._cache.get(key)
        if data is None:
            return None
        try:
            return CachedUpload.from_dict(data)
        except (KeyError, ValueError) as e:
            logger.warning(f"Failed to deserialize cached upload: {e}")
            return None

    async def set(self, key: str, value: CachedUpload, ttl: int) -> None:
        await self._cache.set(key, value.to_dict(), ttl=ttl)

    async def delete(self, key: str) -> bool:
        await self._cache.delete(key)
        return True  # ValkeyCache.delete is void


class UploadCache:
    """Async cache for tracking uploaded files.
    """Async cache for tracking uploaded files using aiocache.

    Supports in-memory caching by default, with optional Redis or Valkey backend
    Supports in-memory caching by default, with optional Redis backend
    for distributed setups.

    Attributes:
@@ -191,7 +110,7 @@ class UploadCache:
        Args:
            ttl: Default TTL in seconds.
            namespace: Cache namespace.
            cache_type: Backend type ("memory", "redis", or "valkey").
            cache_type: Backend type ("memory" or "redis").
            max_entries: Maximum cache entries (None for unlimited).
            **cache_kwargs: Additional args for cache backend.
        """
@@ -201,39 +120,18 @@ class UploadCache:
        self._provider_keys: dict[ProviderType, set[str]] = {}
        self._key_access_order: list[str] = []

        self._backend: CacheBackend = self._create_backend(
            cache_type, namespace, ttl, **cache_kwargs
        )

    @staticmethod
    def _create_backend(
        cache_type: str,
        namespace: str,
        ttl: int,
        **cache_kwargs: Any,
    ) -> CacheBackend:
        """Create the appropriate cache backend."""
        if cache_type == "valkey":
            conn = parse_cache_url() or {}
            return ValkeyCacheBackend(
                host=cache_kwargs.get("host", conn.get("host", "localhost")),
                port=cache_kwargs.get("port", conn.get("port", 6379)),
                db=cache_kwargs.get("db", conn.get("db", 0)),
                password=cache_kwargs.get("password", conn.get("password")),
                default_ttl=ttl,
            )
        if cache_type == "redis":
            return AiocacheBackend(
                Cache(
                    Cache.REDIS,
                    serializer=PickleSerializer(),
                    namespace=namespace,
                    **cache_kwargs,
                )
            self._cache = Cache(
                Cache.REDIS,
                serializer=PickleSerializer(),
                namespace=namespace,
                **cache_kwargs,
            )
        else:
            self._cache = Cache(
                serializer=PickleSerializer(),
                namespace=namespace,
            )
        return AiocacheBackend(
            Cache(serializer=PickleSerializer(), namespace=namespace)
        )

    def _track_key(self, provider: ProviderType, key: str) -> None:
        """Track a key for a provider (for cleanup) and access order."""
@@ -259,9 +157,11 @@ class UploadCache:
        """
        if self.max_entries is None:
            return 0

        current_count = len(self)
        if current_count < self.max_entries:
            return 0

        to_evict = max(1, self.max_entries // 10)
        return await self._evict_oldest(to_evict)

@@ -276,24 +176,31 @@ class UploadCache:
        """
        evicted = 0
        keys_to_evict = self._key_access_order[:count]

        for key in keys_to_evict:
            await self._backend.delete(key)
            await self._cache.delete(key)
            self._key_access_order.remove(key)
            for provider_keys in self._provider_keys.values():
                provider_keys.discard(key)
            evicted += 1

        if evicted > 0:
            logger.debug(f"Evicted {evicted} oldest cache entries")
        return evicted

    # ------------------------------------------------------------------
    # Async public API
    # ------------------------------------------------------------------
        return evicted

    async def aget(
        self, file: FileInput, provider: ProviderType
    ) -> CachedUpload | None:
        """Get a cached upload for a file."""
        """Get a cached upload for a file.

        Args:
            file: The file to look up.
            provider: The provider name.

        Returns:
            Cached upload if found and not expired, None otherwise.
        """
        file_hash = _compute_file_hash(file)
        return await self.aget_by_hash(file_hash, provider)

@@ -310,14 +217,17 @@ class UploadCache:
            Cached upload if found and not expired, None otherwise.
        """
        key = _make_key(file_hash, provider)
        result = await self._backend.get(key)
        result = await self._cache.get(key)

        if result is None:
            return None
        if result.is_expired():
            await self._backend.delete(key)
            self._untrack_key(provider, key)
            return None
        return result
        if isinstance(result, CachedUpload):
            if result.is_expired():
                await self._cache.delete(key)
                self._untrack_key(provider, key)
                return None
            return result
        return None

    async def aset(
        self,
@@ -327,7 +237,18 @@ class UploadCache:
        file_uri: str | None = None,
        expires_at: datetime | None = None,
    ) -> CachedUpload:
        """Cache an uploaded file."""
        """Cache an uploaded file.

        Args:
            file: The file that was uploaded.
            provider: The provider name.
            file_id: Provider-specific file identifier.
            file_uri: Optional URI for accessing the file.
            expires_at: When the upload expires.

        Returns:
            The created cache entry.
        """
        file_hash = _compute_file_hash(file)
        return await self.aset_by_hash(
            file_hash=file_hash,
@@ -361,6 +282,7 @@ class UploadCache:
            The created cache entry.
        """
        await self._evict_if_needed()

        key = _make_key(file_hash, provider)
        now = datetime.now(timezone.utc)

@@ -377,7 +299,7 @@ class UploadCache:
        if expires_at is not None:
            ttl = max(0, int((expires_at - now).total_seconds()))

        await self._backend.set(key, cached, ttl=ttl)
        await self._cache.set(key, cached, ttl=ttl)
        self._track_key(provider, key)
        logger.debug(f"Cached upload: {file_id} for provider {provider}")
        return cached
@@ -394,7 +316,9 @@ class UploadCache:
        """
        file_hash = _compute_file_hash(file)
        key = _make_key(file_hash, provider)
        removed = await self._backend.delete(key)

        result = await self._cache.delete(key)
        removed = bool(result > 0 if isinstance(result, int) else result)
        if removed:
            self._untrack_key(provider, key)
        return removed
@@ -411,10 +335,11 @@ class UploadCache:
        """
        if provider not in self._provider_keys:
            return False

        for key in list(self._provider_keys[provider]):
            cached = await self._backend.get(key)
            if cached is not None and cached.file_id == file_id:
                await self._backend.delete(key)
            cached = await self._cache.get(key)
            if isinstance(cached, CachedUpload) and cached.file_id == file_id:
                await self._cache.delete(key)
                self._untrack_key(provider, key)
                return True
        return False
@@ -426,13 +351,17 @@ class UploadCache:
            Number of entries removed.
        """
        removed = 0

        for provider, keys in list(self._provider_keys.items()):
            for key in list(keys):
                cached = await self._backend.get(key)
                if cached is None or cached.is_expired():
                    await self._backend.delete(key)
                cached = await self._cache.get(key)
                if cached is None or (
                    isinstance(cached, CachedUpload) and cached.is_expired()
                ):
                    await self._cache.delete(key)
                    self._untrack_key(provider, key)
                    removed += 1

        if removed > 0:
            logger.debug(f"Cleared {removed} expired cache entries")
        return removed
@@ -444,12 +373,9 @@ class UploadCache:
            Number of entries cleared.
        """
        count = sum(len(keys) for keys in self._provider_keys.values())
        # Delete all tracked keys individually (works for all backends)
        for keys in self._provider_keys.values():
            for key in keys:
                await self._backend.delete(key)
        await self._cache.clear(namespace=self.namespace)
        self._provider_keys.clear()
        self._key_access_order.clear()

        if count > 0:
            logger.debug(f"Cleared {count} cache entries")
        return count
@@ -465,17 +391,14 @@ class UploadCache:
        """
        if provider not in self._provider_keys:
            return []

        results: list[CachedUpload] = []
        for key in list(self._provider_keys[provider]):
            cached = await self._backend.get(key)
            if cached is not None and not cached.is_expired():
            cached = await self._cache.get(key)
            if isinstance(cached, CachedUpload) and not cached.is_expired():
                results.append(cached)
        return results

    # ------------------------------------------------------------------
    # Sync wrappers
    # ------------------------------------------------------------------

    @staticmethod
    def _run_sync(coro: Any) -> Any:
        """Run an async coroutine from sync context without blocking event loop."""
@@ -566,7 +489,11 @@ class UploadCache:
        return sum(len(keys) for keys in self._provider_keys.values())

    def get_providers(self) -> builtins.set[ProviderType]:
        """Get all provider names that have cached entries."""
        """Get all provider names that have cached entries.

        Returns:
            Set of provider names.
        """
        return builtins.set(self._provider_keys.keys())


@@ -579,7 +506,17 @@ def get_upload_cache(
    cache_type: str = "memory",
    **cache_kwargs: Any,
) -> UploadCache:
    """Get or create the default upload cache."""
    """Get or create the default upload cache.

    Args:
        ttl: Default TTL in seconds.
        namespace: Cache namespace.
        cache_type: Backend type ("memory" or "redis").
        **cache_kwargs: Additional args for cache backend.

    Returns:
        The upload cache instance.
    """
    global _default_cache
    if _default_cache is None:
        _default_cache = UploadCache(

@@ -110,9 +110,6 @@ file-processing = [
qdrant-edge = [
    "qdrant-edge-py>=0.6.0",
]
valkey = [
    "valkey-glide>=1.3.0",
]


[tool.uv]

@@ -13,12 +13,8 @@ from types import MethodType
from typing import TYPE_CHECKING

from a2a.client.errors import A2AClientHTTPError
from a2a.types import (
    AgentCapabilities,
    AgentCard,
    AgentSkill,
)
from aiocache import cached, caches  # type: ignore[import-untyped]
from a2a.types import AgentCapabilities, AgentCard, AgentSkill
from aiocache import cached  # type: ignore[import-untyped]
from aiocache.serializers import PickleSerializer  # type: ignore[import-untyped]
import httpx

@@ -36,7 +32,6 @@ from crewai.events.types.a2a_events import (
    A2AAuthenticationFailedEvent,
    A2AConnectionErrorEvent,
)
from crewai.utilities.cache_config import get_aiocache_config


if TYPE_CHECKING:
@@ -45,18 +40,6 @@ if TYPE_CHECKING:
    from crewai.task import Task


_cache_configured = False


def _ensure_cache_configured() -> None:
    """Configure aiocache on first use (lazy initialization)."""
    global _cache_configured
    if _cache_configured:
        return
    caches.set_config(get_aiocache_config())
    _cache_configured = True


def _get_tls_verify(auth: ClientAuthScheme | None) -> ssl.SSLContext | bool | str:
    """Get TLS verify parameter from auth scheme.

@@ -208,7 +191,6 @@ async def afetch_agent_card(
    else:
        auth_hash = _auth_store.compute_key("none", "")
    _auth_store.set(auth_hash, auth)
    _ensure_cache_configured()
    agent_card: AgentCard = await _afetch_agent_card_cached(
        endpoint, auth_hash, timeout
    )

@@ -9,8 +9,9 @@ from datetime import datetime
from functools import wraps
import json
import logging
import threading
import os
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast
from urllib.parse import urlparse

from a2a.server.agent_execution import RequestContext
from a2a.server.events import EventQueue
@@ -37,6 +38,7 @@ from a2a.utils import (
from a2a.utils.errors import ServerError
from aiocache import SimpleMemoryCache, caches  # type: ignore[import-untyped]
from pydantic import BaseModel
from typing_extensions import TypedDict

from crewai.a2a.utils.agent_card import _get_server_config
from crewai.a2a.utils.content_type import validate_message_parts
@@ -48,18 +50,12 @@ from crewai.events.types.a2a_events import (
    A2AServerTaskStartedEvent,
)
from crewai.task import Task
from crewai.utilities.cache_config import (
    get_aiocache_config,
    parse_cache_url,
    use_valkey_cache,
)
from crewai.utilities.pydantic_schema_utils import create_model_from_schema


if TYPE_CHECKING:
    from crewai.a2a.extensions.server import ExtensionContext, ServerExtensionRegistry
    from crewai.agent import Agent
    from crewai.memory.storage.valkey_cache import ValkeyCache


logger = logging.getLogger(__name__)
@@ -68,61 +64,52 @@ P = ParamSpec("P")
T = TypeVar("T")


# ---------------------------------------------------------------------------
# Lazy cache initialisation
# ---------------------------------------------------------------------------
class RedisCacheConfig(TypedDict, total=False):
    """Configuration for aiocache Redis backend."""

_task_cache: ValkeyCache | None = None
_lazy_init_complete = False
_cache_init_lock = threading.Lock()

# Cancellation polling interval in seconds.
_CANCEL_POLL_INTERVAL = 0.1

# Configure aiocache at import time (matches upstream behaviour).
# This is safe — it only touches aiocache, no optional dependencies.
# The Valkey path is deferred to _ensure_task_cache() to avoid importing
# valkey-glide at module level (it may not be installed).
if not use_valkey_cache():
    caches.set_config(get_aiocache_config())
    cache: str
    endpoint: str
    port: int
    db: int
    password: str


def _ensure_task_cache() -> None:
    """Initialise the Valkey task cache on first use (thread-safe).
def _parse_redis_url(url: str) -> RedisCacheConfig:
    """Parse a Redis URL into aiocache configuration.

    For the aiocache path, configuration happens at module level above.
    This function only needs to run for the Valkey path.
    Args:
        url: Redis connection URL (e.g., redis://localhost:6379/0).

    Returns:
        Configuration dict for aiocache.RedisCache.
    """
    global _task_cache, _lazy_init_complete
    if _lazy_init_complete:
        return
    parsed = urlparse(url)
    config: RedisCacheConfig = {
        "cache": "aiocache.RedisCache",
        "endpoint": parsed.hostname or "localhost",
        "port": parsed.port or 6379,
    }
    if parsed.path and parsed.path != "/":
        try:
            config["db"] = int(parsed.path.lstrip("/"))
        except ValueError:
            pass
    if parsed.password:
        config["password"] = parsed.password
    return config

    with _cache_init_lock:
        if _lazy_init_complete:
            return

        if use_valkey_cache():
            from crewai.memory.storage.valkey_cache import ValkeyCache
_redis_url = os.environ.get("REDIS_URL")

            conn = parse_cache_url() or {}
            try:
                _task_cache = ValkeyCache(
                    host=conn.get("host", "localhost"),
                    port=conn.get("port", 6379),
                    db=conn.get("db", 0),
                    password=conn.get("password"),
                    default_ttl=3600,
                )
            except Exception as e:
                logger.error(
                    "Failed to initialize ValkeyCache for task cancellation, "
                    "falling back to aiocache",
                    extra={"error": str(e)},
                )
                caches.set_config(get_aiocache_config())
                _task_cache = None

        _lazy_init_complete = True
caches.set_config(
    {
        "default": _parse_redis_url(_redis_url)
        if _redis_url
        else {
            "cache": "aiocache.SimpleMemoryCache",
        }
    }
)

def cancellable(
@@ -143,8 +130,6 @@ def cancellable(
    @wraps(fn)
    async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        """Wrap function with cancellation monitoring."""
        _ensure_task_cache()

        context: RequestContext | None = None
        for arg in args:
            if isinstance(arg, RequestContext):
@@ -157,34 +142,19 @@ def cancellable(
            return await fn(*args, **kwargs)

        task_id = context.task_id
        cache = caches.get("default")

        async def poll_for_cancel_valkey() -> bool:
            """Poll ValkeyCache for cancellation flag."""
            while True:
                if _task_cache is not None and await _task_cache.get(
                    f"cancel:{task_id}"
                ):
                    return True
                await asyncio.sleep(_CANCEL_POLL_INTERVAL)

        async def poll_for_cancel_aiocache() -> bool:
            """Poll aiocache for cancellation flag."""
            cache = caches.get("default")
        async def poll_for_cancel() -> bool:
            """Poll cache for cancellation flag."""
            while True:
                if await cache.get(f"cancel:{task_id}"):
                    return True
                await asyncio.sleep(_CANCEL_POLL_INTERVAL)
                await asyncio.sleep(0.1)

        async def watch_for_cancel() -> bool:
            """Watch for cancellation events via pub/sub or polling."""
            if _task_cache is not None:
                # ValkeyCache: use polling (pub/sub not implemented yet)
                return await poll_for_cancel_valkey()

            # aiocache: use pub/sub if Redis, otherwise poll
            cache = caches.get("default")
            if isinstance(cache, SimpleMemoryCache):
                return await poll_for_cancel_aiocache()
                return await poll_for_cancel()

            try:
                client = cache.client
@@ -198,7 +168,7 @@ def cancellable(
                    "Cancel watcher Redis error, falling back to polling",
                    extra={"task_id": task_id, "error": str(e)},
                )
                return await poll_for_cancel_aiocache()
                return await poll_for_cancel()
            return False

        execute_task = asyncio.create_task(fn(*args, **kwargs))
@@ -220,12 +190,7 @@ def cancellable(
            cancel_watch.cancel()
            return execute_task.result()
        finally:
            # Clean up cancellation flag
            if _task_cache is not None:
                await _task_cache.delete(f"cancel:{task_id}")
            else:
                cache = caches.get("default")
                await cache.delete(f"cancel:{task_id}")
            await cache.delete(f"cancel:{task_id}")

    return wrapper

@@ -510,8 +475,6 @@ async def cancel(
    if task_id is None or context_id is None:
        raise ServerError(InvalidParamsError(message="task_id and context_id required"))

    _ensure_task_cache()

    if context.current_task and context.current_task.status.state in (
        TaskState.completed,
        TaskState.failed,
@@ -519,16 +482,11 @@ async def cancel(
    ):
        return context.current_task

    if _task_cache is not None:
        # Use ValkeyCache
        await _task_cache.set(f"cancel:{task_id}", True, ttl=3600)
        # Note: pub/sub not implemented for ValkeyCache yet, relies on polling
    else:
        # Use aiocache
        cache = caches.get("default")
        await cache.set(f"cancel:{task_id}", True, ttl=3600)
        if not isinstance(cache, SimpleMemoryCache):
            await cache.client.publish(f"cancel:{task_id}", "cancel")
    cache = caches.get("default")

    await cache.set(f"cancel:{task_id}", True, ttl=3600)
    if not isinstance(cache, SimpleMemoryCache):
        await cache.client.publish(f"cancel:{task_id}", "cancel")

    await event_queue.enqueue_event(
        TaskStatusUpdateEvent(

@@ -18,7 +18,7 @@ import math
from typing import Any
from uuid import uuid4

from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field

from crewai.flow.flow import Flow, listen, start
from crewai.memory.analyze import (
@@ -68,27 +68,6 @@ class ItemState(BaseModel):
    plan: ConsolidationPlan | None = None
    result_record: MemoryRecord | None = None

    @field_validator("similar_records", "result_record", mode="before")
    @classmethod
    def ensure_embedding_is_list(cls, v: Any) -> Any:
        """Ensure MemoryRecord embeddings are list[float], not bytes.

        Delegates to MemoryRecord.validate_embedding for consistent behavior
        (e.g. empty bytes → None).
        """
        if v is None:
            return None
        if isinstance(v, list):
            for record in v:
                if isinstance(record, MemoryRecord) and isinstance(
                    record.embedding, bytes
                ):
                    record.embedding = MemoryRecord.validate_embedding(record.embedding)
            return v
        if isinstance(v, MemoryRecord) and isinstance(v.embedding, bytes):
            v.embedding = MemoryRecord.validate_embedding(v.embedding)
        return v


class EncodingState(BaseModel):
    """Batch-level state for the encoding flow."""
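The validator being removed above normalizes embeddings that arrive as raw bytes into `list[float]` (the underlying `MemoryRecord.validate_embedding` uses `np.frombuffer(v, dtype=np.float32)`). A stdlib-only sketch of that conversion using the `array` module, which assumes native little-endian float32 layout, matching what the numpy version decodes on typical platforms:

```python
from array import array
import struct


def validate_embedding(v):
    """bytes -> list[float] (float32), empty bytes -> None.

    Stdlib sketch of the numpy-based validator above; array("f") assumes
    the platform's native float layout is little-endian float32.
    """
    if v is None:
        return None
    if isinstance(v, bytes):
        if len(v) == 0:
            return None
        a = array("f")
        a.frombytes(v)
        return [float(x) for x in a]
    return [float(x) for x in v]


raw = struct.pack("<3f", 0.5, 1.5, 2.5)  # exactly representable in float32
floats = validate_embedding(raw)
print(floats)                       # → [0.5, 1.5, 2.5]
print(validate_embedding(b""))      # → None
print(validate_embedding([1, 2]))   # → [1.0, 2.0]
```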
@@ -1,198 +0,0 @@
"""Valkey-based cache implementation for CrewAI.

This module provides a simple cache interface using Valkey-GLIDE client
for caching operations with optional TTL support. It replaces Redis usage
in A2A communication, file uploads, and agent card caching.
"""

from __future__ import annotations

import json
import logging
from typing import Any

from glide import GlideClient, GlideClientConfiguration, NodeAddress


_logger = logging.getLogger(__name__)


class ValkeyCache:
    """Simple cache interface using Valkey-GLIDE client.

    Provides get/set/delete/exists operations for caching with optional TTL.
    Uses JSON serialization for complex values and lazy client initialization.

    Example:
        >>> cache = ValkeyCache(host="localhost", port=6379)
        >>> await cache.set("key", {"data": "value"}, ttl=3600)
        >>> value = await cache.get("key")
        >>> await cache.delete("key")
    """

    def __init__(
        self,
        host: str = "localhost",
        port: int = 6379,
        db: int = 0,
        password: str | None = None,
        default_ttl: int | None = None,
    ) -> None:
        """Initialize Valkey cache.

        Args:
            host: Valkey server hostname.
            port: Valkey server port.
            db: Database number to use.
            password: Optional password for authentication.
            default_ttl: Default TTL in seconds (None = no expiration).
        """
        self._host = host
        self._port = port
        self._db = db
        self._password = password
        self._default_ttl = default_ttl
        self._client: GlideClient | None = None

    async def _get_client(self) -> GlideClient:
        """Get or create Valkey client (lazy initialization).

        Returns:
            Initialized GlideClient instance.

        Raises:
            RuntimeError: If connection to Valkey fails.
            TimeoutError: If connection attempt times out (10 seconds).
        """
        import asyncio

        if self._client is None:
            host = self._host
            port = self._port
            db = self._db
            try:
                from glide import ServerCredentials

                config = GlideClientConfiguration(
                    addresses=[NodeAddress(host, port)],
                    database_id=db,
                    credentials=(
                        ServerCredentials(password=self._password)
                        if self._password
                        else None
                    ),
                )

                # Add connection timeout (10 seconds)
                try:
                    self._client = await asyncio.wait_for(
                        GlideClient.create(config), timeout=10.0
                    )
                except asyncio.TimeoutError as e:
                    _logger.error("Connection timeout connecting to Valkey")
                    raise TimeoutError(
                        "Connection timeout to Valkey. "
                        "Ensure Valkey is running and accessible."
                    ) from e

                _logger.info("Valkey cache client initialized")
            except (TimeoutError, RuntimeError):
                raise
            except Exception as e:
                _logger.error(
                    "Failed to create Valkey cache client: %s", type(e).__name__
                )
                raise RuntimeError(
                    "Cannot connect to Valkey. Check connection settings."
                ) from e

        return self._client

    async def get(self, key: str) -> Any | None:
        """Get value from cache.

        Args:
            key: Cache key.

        Returns:
            Cached value (deserialized from JSON) or None if not found.
        """
        client = await self._get_client()
        value = await client.get(key)

        if value is None:
            return None

        try:
            return json.loads(value)
        except json.JSONDecodeError:
            _logger.warning(f"Failed to deserialize cached value for key: {key}")
            return None

    async def set(
        self,
        key: str,
        value: Any,
        ttl: int | None = None,
    ) -> None:
        """Set value in cache.

        Args:
            key: Cache key.
            value: Value to cache (will be serialized to JSON).
            ttl: TTL in seconds (None uses default_ttl, 0 = no expiration).

        Raises:
            TypeError: If value is not JSON-serializable.
        """
        from glide import ExpirySet, ExpiryType

        client = await self._get_client()
        try:
            serialized = json.dumps(value)
        except (TypeError, ValueError) as e:
            _logger.error("Cannot serialize value for key %r: %s", key, e)
            raise TypeError(
                f"Value for cache key {key!r} is not JSON-serializable: {e}"
            ) from e

        ttl_to_use = ttl if ttl is not None else self._default_ttl

        if ttl_to_use and ttl_to_use > 0:
            # Set with expiration using SET command with EX option
            await client.set(
                key,
                serialized,
                expiry=ExpirySet(ExpiryType.SEC, ttl_to_use),
            )
        else:
            await client.set(key, serialized)

    async def delete(self, key: str) -> None:
        """Delete value from cache.

        Args:
            key: Cache key to delete.
        """
        client = await self._get_client()
        await client.delete([key])

    async def exists(self, key: str) -> bool:
        """Check if key exists in cache.

        Args:
            key: Cache key to check.

        Returns:
            True if key exists, False otherwise.
        """
        client = await self._get_client()
        result = await client.exists([key])
        return result > 0

    async def close(self) -> None:
        """Close Valkey client connection."""
        if self._client:
            await self._client.close()
            self._client = None
            _logger.debug("Valkey cache client closed")
File diff suppressed because it is too large
@@ -2,17 +2,13 @@

from __future__ import annotations

import concurrent.futures
from datetime import datetime
import logging
from typing import Any
from uuid import uuid4

from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field


_logger = logging.getLogger(__name__)

# When searching the vector store, we ask for more results than the caller
# requested so that post-search steps (composite scoring, deduplication,
# category filtering) have enough candidates to fill the final result set.
@@ -61,23 +57,6 @@ class MemoryRecord(BaseModel):
        repr=False,
        description="Vector embedding for semantic search. Excluded from serialization to save tokens.",
    )

    @field_validator("embedding", mode="before")
    @classmethod
    def validate_embedding(cls, v: Any) -> list[float] | None:
        """Ensure embedding is always list[float] or None, never bytes."""
        if v is None:
            return None
        if isinstance(v, bytes):
            # Convert bytes to list[float] if needed
            import numpy as np

            if len(v) == 0:
                return None
            arr = np.frombuffer(v, dtype=np.float32)
            return [float(x) for x in arr]
        return [float(x) for x in v]

    source: str | None = Field(
        default=None,
        description=(
@@ -325,11 +304,7 @@ def embed_text(embedder: Any, text: str) -> list[float]:
    """
    if not text or not text.strip():
        return []

    # Just call the embedder directly - the blocking issue needs to be fixed
    # at a higher level (making Memory.recall() async)
    result = embedder([text])

    if not result:
        return []
    first = result[0]
@@ -340,27 +315,12 @@ def embed_text(embedder: Any, text: str) -> list[float]:
    return list(first)


# Reusable thread pool for running embedder calls from sync context
# when an async event loop is already running. Uses max_workers=2 so
# a single slow/timed-out call doesn't block subsequent embeds.
_EMBED_POOL = concurrent.futures.ThreadPoolExecutor(max_workers=2)


def embed_texts(embedder: Any, texts: list[str]) -> list[list[float]]:
    """Embed multiple texts in a single API call.

    The embedder already accepts ``list[str]``, so this just calls it once
    with the full batch and normalises the output format.

    When called from an async context, offloads the embedder to a thread pool
    so the embedding work doesn't run on the event loop thread. The calling
    thread still blocks on the result (unavoidable for a sync function), but
    this prevents the embedder from starving the event loop's I/O callbacks.
    The pool uses ``max_workers=2`` so a single timed-out call doesn't block
    subsequent embeds.

    Note: the proper long-term fix is making ``Memory.recall()`` async.

    Args:
        embedder: Callable that accepts a list of strings and returns embeddings.
        texts: List of texts to embed.
@@ -368,8 +328,6 @@ def embed_texts(embedder: Any, texts: list[str]) -> list[list[float]]:
    Returns:
        List of embeddings, one per input text. Empty texts produce empty lists.
    """
    import asyncio

    if not texts:
        return []
    # Filter out empty texts, remembering their positions
@@ -379,28 +337,7 @@ def embed_texts(embedder: Any, texts: list[str]) -> list[list[float]]:
    if not valid:
        return [[] for _ in texts]

    texts_to_embed = [t for _, t in valid]

    # Check if we're in an async context
    result: Any
    try:
        asyncio.get_running_loop()
        # We're in an async context but this is a sync function.
        # Offload to thread pool so the embedder doesn't run on the
        # event loop thread. The .result() call blocks this thread
        # (acceptable — callers like Memory.recall() are sync).
        try:
            result = _EMBED_POOL.submit(embedder, texts_to_embed).result(timeout=30)
        except concurrent.futures.TimeoutError:
            _logger.warning(
                "Embedder timed out after 30s, returning empty embeddings. "
                "The worker thread may still be running."
            )
            return [[] for _ in texts]
    except RuntimeError:
        # Not in async context, run directly
        result = embedder(texts_to_embed)

    result = embedder([t for _, t in valid])
    embeddings: list[list[float]] = [[] for _ in texts]
    for (orig_idx, _), emb in zip(valid, result, strict=False):
        if hasattr(emb, "tolist"):
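The `_EMBED_POOL` dispatch being removed above has a reusable shape: a sync function that, when called with an event loop already running, offloads its blocking work to a small thread pool instead of running it on the loop thread. A self-contained sketch, where `run_blocking` and `fake_embedder` are illustrative names rather than the real API:

```python
import asyncio
import concurrent.futures

# Shared pool, as in _EMBED_POOL above; max_workers=2 so one stuck
# call does not block the next one.
_POOL = concurrent.futures.ThreadPoolExecutor(max_workers=2)


def run_blocking(fn, arg, timeout: float = 30.0):
    """Run a sync callable; offload to the pool if an event loop is running.

    Mirrors the embed_texts() dispatch: on the event loop thread we submit
    to the pool and block this thread on .result(); otherwise call directly.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        return fn(arg)  # not in async context, run directly
    try:
        return _POOL.submit(fn, arg).result(timeout=timeout)
    except concurrent.futures.TimeoutError:
        return None  # the real code logs and returns empty embeddings


def fake_embedder(texts):
    # Stand-in embedder: one-dimensional "embedding" per text
    return [[float(len(t))] for t in texts]


# Sync context: called directly
sync_result = run_blocking(fake_embedder, ["ab", "cde"])


# Async context: offloaded to the pool
async def main():
    return run_blocking(fake_embedder, ["ab", "cde"])


async_result = asyncio.run(main())
print(sync_result, async_result)  # → [[2.0], [3.0]] [[2.0], [3.0]]
```

Note the caveat from the removed docstring still applies: the calling thread blocks on `.result()`, so this protects the event loop's I/O callbacks but does not make the call itself non-blocking.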
@@ -2,11 +2,9 @@

from __future__ import annotations

import concurrent.futures
from concurrent.futures import Future, ThreadPoolExecutor
import contextvars
from datetime import datetime
import logging
import threading
import time
from typing import TYPE_CHECKING, Annotated, Any, Literal
@@ -38,9 +36,6 @@ from crewai.rag.embeddings.factory import build_embedder
from crewai.rag.embeddings.providers.openai.types import OpenAIProviderSpec


_logger = logging.getLogger(__name__)


if TYPE_CHECKING:
    from chromadb.utils.embedding_functions.openai_embedding_function import (
        OpenAIEmbeddingFunction,
@@ -216,18 +211,6 @@ class Memory(BaseModel):
            from crewai.memory.storage.lancedb_storage import LanceDBStorage

            self._storage = LanceDBStorage()
        elif self.storage == "valkey":
            from crewai.memory.storage.valkey_storage import ValkeyStorage
            from crewai.utilities.cache_config import parse_cache_url

            conn = parse_cache_url() or {}
            self._storage = ValkeyStorage(
                host=conn.get("host", "localhost"),
                port=conn.get("port", 6379),
                db=conn.get("db", 0),
                password=conn.get("password"),
                use_tls=conn.get("use_tls", False),
            )
        else:
            from crewai.memory.storage.lancedb_storage import LanceDBStorage

@@ -333,60 +316,16 @@ class Memory(BaseModel):
            except Exception:  # noqa: S110
                pass  # swallow everything during shutdown

    def drain_writes(self, timeout_per_save: float = 60.0) -> None:
    def drain_writes(self) -> None:
        """Block until all pending background saves have completed.

        Called automatically by ``recall()`` and should be called by the
        crew at shutdown to ensure no saves are lost.

        Args:
            timeout_per_save: Maximum seconds to wait per save operation.
                Default 60s. If a save times out, logs warning
                but continues to avoid blocking crew completion.
        """
        with self._pending_lock:
            pending = list(self._pending_saves)

        if pending:
            _logger.debug(
                "[DRAIN_WRITES] Waiting for %d pending saves...", len(pending)
            )

        failed_saves = 0
        for i, future in enumerate(pending):
            try:
                _logger.debug(
                    "[DRAIN_WRITES] Waiting for save %d/%d...", i + 1, len(pending)
                )
                future.result(timeout=timeout_per_save)
                _logger.debug(
                    "[DRAIN_WRITES] Save %d/%d completed", i + 1, len(pending)
                )
            except (TimeoutError, concurrent.futures.TimeoutError):  # noqa: PERF203
                failed_saves += 1
                _logger.warning(
                    "[DRAIN_WRITES] Save %d/%d timed out after %ss. "
                    "This save will be abandoned. Consider increasing timeout or checking "
                    "LLM/embedder performance.",
                    i + 1,
                    len(pending),
                    timeout_per_save,
                )
                # Don't raise - just log and continue to avoid blocking crew completion
            except Exception as e:
                failed_saves += 1
                _logger.error(
                    "[DRAIN_WRITES] Save %d/%d failed: %s", i + 1, len(pending), e
                )
                # Don't raise - just log and continue

        if failed_saves > 0:
            _logger.warning(
                "[DRAIN_WRITES] %d/%d saves failed or timed out. "
                "Some memories may not have been persisted.",
                failed_saves,
                len(pending),
            )
        for future in pending:
            future.result()  # blocks until done; re-raises exceptions

    def close(self) -> None:
        """Drain pending saves, flush storage, and shut down the background thread pool."""
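The removed `drain_writes` body demonstrates a general pattern worth keeping in mind: wait for each pending future with a per-item timeout and log-and-continue instead of raising, so one stuck save cannot block shutdown. A minimal sketch under made-up names (`drain`, the lambda tasks):

```python
import concurrent.futures
import time


def drain(pending, timeout_per_save: float = 60.0) -> int:
    """Wait for each future with a per-save timeout; return the failure count.

    Mirrors the removed drain_writes(): timeouts and exceptions are counted
    (the real code logs them) rather than re-raised.
    """
    failed = 0
    for future in pending:
        try:
            future.result(timeout=timeout_per_save)
        except concurrent.futures.TimeoutError:
            failed += 1  # abandoned; the worker thread may still be running
        except Exception:
            failed += 1  # save raised; logged and skipped in the real code
    return failed


pool = concurrent.futures.ThreadPoolExecutor(max_workers=2)
ok = pool.submit(lambda: "saved")               # completes immediately
slow = pool.submit(lambda: time.sleep(0.3))     # exceeds the timeout below
boom = pool.submit(lambda: 1 / 0)               # raises ZeroDivisionError

failed = drain([ok, slow, boom], timeout_per_save=0.05)
print(failed)  # → 2
```

The replacement code (`future.result()` with no timeout, re-raising exceptions) trades this robustness for simplicity: a hung save now blocks the drain, and the first failure propagates.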
@@ -1,78 +0,0 @@
"""Shared cache configuration helpers for Valkey/Redis URL parsing."""

from __future__ import annotations

import logging
import os
from typing import Any
from urllib.parse import urlparse


_logger = logging.getLogger(__name__)


def parse_cache_url() -> dict[str, Any] | None:
    """Parse VALKEY_URL or REDIS_URL from environment.

    Priority: VALKEY_URL > REDIS_URL.

    Returns:
        Dict with host, port, db, password keys, or None if no URL is set.
    """
    url = os.environ.get("VALKEY_URL") or os.environ.get("REDIS_URL")
    if not url:
        return None
    parsed = urlparse(url)
    return {
        "host": parsed.hostname or "localhost",
        "port": parsed.port or 6379,
        "db": _parse_db_from_path(parsed.path),
        "password": parsed.password,
        "use_tls": parsed.scheme in ("rediss", "valkeys"),
    }


def _parse_db_from_path(path: str | None) -> int:
    """Parse database number from URL path, defaulting to 0."""
    if not path or path == "/":
        return 0
    try:
        return int(path.lstrip("/"))
    except ValueError:
        _logger.warning(
            "Invalid database number in URL path: %s, using default 0", path
        )
        return 0


def get_aiocache_config() -> dict[str, Any]:
    """Build an aiocache configuration dict from environment.

    Uses VALKEY_URL or REDIS_URL (both are Redis-wire-compatible) to
    configure ``aiocache.RedisCache``. Falls back to
    ``aiocache.SimpleMemoryCache`` when neither variable is set.

    Returns:
        Configuration dict suitable for ``aiocache.caches.set_config()``.
    """
    conn = parse_cache_url()
    if conn is not None:
        return {
            "default": {
                "cache": "aiocache.RedisCache",
                "endpoint": conn["host"],
                "port": conn["port"],
                "db": conn.get("db", 0),
                "password": conn.get("password"),
            }
        }
    return {
        "default": {
            "cache": "aiocache.SimpleMemoryCache",
        }
    }


def use_valkey_cache() -> bool:
    """Return True if VALKEY_URL is set in the environment."""
    return bool(os.environ.get("VALKEY_URL"))
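The URL parsing in the removed `cache_config` module is plain `urllib.parse` work and is easy to verify in isolation. A stand-alone sketch that takes the URL as an argument instead of reading the environment; the host `cache.internal` and the password are made-up example values:

```python
from urllib.parse import urlparse


def parse_cache_url(url: str) -> dict:
    """Split a valkey:// or redis:// style URL into connection settings.

    Same fields as the removed helper: host, port, db, password, use_tls.
    """
    parsed = urlparse(url)
    path = parsed.path
    try:
        # Database number lives in the path, e.g. "/2"; default 0
        db = int(path.lstrip("/")) if path and path != "/" else 0
    except ValueError:
        db = 0
    return {
        "host": parsed.hostname or "localhost",
        "port": parsed.port or 6379,
        "db": db,
        "password": parsed.password,
        # rediss:// and valkeys:// are the TLS schemes
        "use_tls": parsed.scheme in ("rediss", "valkeys"),
    }


conn = parse_cache_url("valkeys://:s3cret@cache.internal:6380/2")
print(conn)
# → {'host': 'cache.internal', 'port': 6380, 'db': 2,
#    'password': 's3cret', 'use_tls': True}
```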
@@ -1,511 +0,0 @@
|
||||
"""Tests for ValkeyCache implementation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.memory.storage.valkey_cache import ValkeyCache
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_glide_client() -> AsyncMock:
|
||||
"""Create a mock GlideClient for testing."""
|
||||
client = AsyncMock()
|
||||
client.get = AsyncMock()
|
||||
client.set = AsyncMock()
|
||||
client.delete = AsyncMock()
|
||||
client.exists = AsyncMock()
|
||||
client.close = AsyncMock()
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def valkey_cache(mock_glide_client: AsyncMock) -> ValkeyCache:
|
||||
"""Create a ValkeyCache instance with mocked client."""
|
||||
cache = ValkeyCache(host="localhost", port=6379, db=0)
|
||||
|
||||
# Mock the client creation to return our mock
|
||||
async def mock_create_client() -> AsyncMock:
|
||||
cache._client = mock_glide_client
|
||||
return mock_glide_client
|
||||
|
||||
cache._get_client = mock_create_client # type: ignore[method-assign]
|
||||
return cache
|
||||
|
||||
|
||||
class TestValkeyCacheBasicOperations:
|
||||
"""Tests for basic ValkeyCache operations (get/set/delete/exists)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_and_get_string_value(
|
||||
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
|
||||
) -> None:
|
||||
"""Test setting and getting a string value."""
|
||||
# Mock get to return serialized string
|
||||
mock_glide_client.get.return_value = json.dumps("test_value")
|
||||
|
||||
# Set value
|
||||
await valkey_cache.set("test_key", "test_value")
|
||||
|
||||
# Verify set was called
|
||||
mock_glide_client.set.assert_called_once()
|
||||
call_args = mock_glide_client.set.call_args
|
||||
assert call_args[0][0] == "test_key"
|
||||
assert call_args[0][1] == json.dumps("test_value")
|
||||
|
||||
# Get value
|
||||
result = await valkey_cache.get("test_key")
|
||||
|
||||
# Verify get was called and result is correct
|
||||
mock_glide_client.get.assert_called_once_with("test_key")
|
||||
assert result == "test_value"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_and_get_dict_value(
|
||||
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
|
||||
) -> None:
|
||||
"""Test setting and getting a dictionary value."""
|
||||
test_dict = {"key1": "value1", "key2": 42, "key3": [1, 2, 3]}
|
||||
mock_glide_client.get.return_value = json.dumps(test_dict)
|
||||
|
||||
# Set value
|
||||
await valkey_cache.set("dict_key", test_dict)
|
||||
|
||||
# Verify set was called with serialized dict
|
||||
mock_glide_client.set.assert_called_once()
|
||||
call_args = mock_glide_client.set.call_args
|
||||
assert call_args[0][0] == "dict_key"
|
||||
assert call_args[0][1] == json.dumps(test_dict)
|
||||
|
||||
# Get value
|
||||
result = await valkey_cache.get("dict_key")
|
||||
|
||||
# Verify result matches original dict
|
||||
assert result == test_dict
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_and_get_list_value(
|
||||
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
|
||||
) -> None:
|
||||
"""Test setting and getting a list value."""
|
||||
test_list = [1, "two", 3.0, {"nested": "dict"}]
|
||||
mock_glide_client.get.return_value = json.dumps(test_list)
|
||||
|
||||
await valkey_cache.set("list_key", test_list)
|
||||
result = await valkey_cache.get("list_key")
|
||||
|
||||
assert result == test_list
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_nonexistent_key_returns_none(
|
||||
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
|
||||
) -> None:
|
||||
"""Test getting a non-existent key returns None."""
|
||||
mock_glide_client.get.return_value = None
|
||||
|
||||
result = await valkey_cache.get("nonexistent_key")
|
||||
|
||||
assert result is None
|
||||
mock_glide_client.get.assert_called_once_with("nonexistent_key")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_key(
|
||||
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
|
||||
) -> None:
|
||||
"""Test deleting a key."""
|
||||
await valkey_cache.delete("test_key")
|
||||
|
||||
mock_glide_client.delete.assert_called_once_with(["test_key"])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exists_returns_true_for_existing_key(
|
||||
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
|
||||
) -> None:
|
||||
"""Test exists returns True for existing key."""
|
||||
mock_glide_client.exists.return_value = 1
|
||||
|
||||
result = await valkey_cache.exists("existing_key")
|
||||
|
||||
assert result is True
|
||||
mock_glide_client.exists.assert_called_once_with(["existing_key"])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exists_returns_false_for_nonexistent_key(
|
||||
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
|
||||
) -> None:
|
||||
"""Test exists returns False for non-existent key."""
|
||||
mock_glide_client.exists.return_value = 0
|
||||
|
||||
result = await valkey_cache.exists("nonexistent_key")
|
||||
|
||||
assert result is False
|
||||
mock_glide_client.exists.assert_called_once_with(["nonexistent_key"])
|
||||
|
||||
|
||||
class TestValkeyCacheTTL:
|
||||
"""Tests for ValkeyCache TTL functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_with_explicit_ttl(
|
||||
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
|
||||
) -> None:
|
||||
"""Test setting a value with explicit TTL."""
|
||||
await valkey_cache.set("ttl_key", "value", ttl=3600)
|
||||
|
||||
# Verify set was called with expiry
|
||||
mock_glide_client.set.assert_called_once()
|
||||
call_args = mock_glide_client.set.call_args
|
||||
assert call_args[0][0] == "ttl_key"
|
||||
assert call_args[0][1] == json.dumps("value")
|
||||
assert "expiry" in call_args[1]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_with_default_ttl(
|
||||
self, mock_glide_client: AsyncMock
|
||||
) -> None:
|
||||
"""Test setting a value with default TTL from constructor."""
|
||||
cache = ValkeyCache(host="localhost", port=6379, default_ttl=1800)
|
||||
|
||||
async def mock_create_client() -> AsyncMock:
|
||||
cache._client = mock_glide_client
|
||||
return mock_glide_client
|
||||
|
||||
cache._get_client = mock_create_client # type: ignore[method-assign]
|
||||
|
||||
await cache.set("default_ttl_key", "value")
|
||||
|
||||
# Verify set was called with default TTL
|
||||
mock_glide_client.set.assert_called_once()
|
||||
call_args = mock_glide_client.set.call_args
|
||||
assert "expiry" in call_args[1]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_without_ttl(
|
||||
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
|
||||
) -> None:
|
||||
"""Test setting a value without TTL (no expiration)."""
|
||||
await valkey_cache.set("no_ttl_key", "value")
|
||||
|
||||
# Verify set was called without expiry
|
||||
mock_glide_client.set.assert_called_once()
|
||||
call_args = mock_glide_client.set.call_args
|
||||
assert call_args[0][0] == "no_ttl_key"
|
||||
assert call_args[0][1] == json.dumps("value")
|
||||
# Should not have expiry parameter
|
||||
assert "expiry" not in call_args[1] or call_args[1].get("expiry") is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_with_zero_ttl_no_expiration(
|
||||
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
|
||||
) -> None:
|
||||
"""Test setting a value with TTL=0 means no expiration."""
|
||||
await valkey_cache.set("zero_ttl_key", "value", ttl=0)
|
||||
|
||||
# Verify set was called without expiry
|
||||
mock_glide_client.set.assert_called_once()
|
||||
call_args = mock_glide_client.set.call_args
|
||||
assert "expiry" not in call_args[1] or call_args[1].get("expiry") is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_explicit_ttl_overrides_default(
|
||||
self, mock_glide_client: AsyncMock
|
||||
) -> None:
|
||||
"""Test explicit TTL overrides default TTL."""
|
||||
cache = ValkeyCache(host="localhost", port=6379, default_ttl=1800)
|
||||
|
||||
async def mock_create_client() -> AsyncMock:
|
||||
cache._client = mock_glide_client
|
||||
return mock_glide_client
|
||||
|
||||
cache._get_client = mock_create_client # type: ignore[method-assign]
|
||||
|
||||
await cache.set("override_key", "value", ttl=7200)
|
||||
|
||||
# Verify set was called with explicit TTL (7200), not default (1800)
|
||||
mock_glide_client.set.assert_called_once()
|
||||
call_args = mock_glide_client.set.call_args
|
||||
assert "expiry" in call_args[1]
|
||||
|
||||
|
||||
class TestValkeyCacheJSONSerialization:
|
||||
"""Tests for ValkeyCache JSON serialization edge cases."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_serialize_none_value(
|
||||
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
|
||||
) -> None:
|
||||
"""Test serializing None value."""
|
||||
mock_glide_client.get.return_value = json.dumps(None)
|
||||
|
||||
await valkey_cache.set("none_key", None)
|
||||
result = await valkey_cache.get("none_key")
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_serialize_boolean_values(
|
||||
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
|
||||
) -> None:
|
||||
"""Test serializing boolean values."""
|
||||
mock_glide_client.get.side_effect = [
|
||||
json.dumps(True),
|
||||
json.dumps(False),
|
||||
]
|
||||
|
||||
await valkey_cache.set("true_key", True)
|
||||
await valkey_cache.set("false_key", False)
|
||||
|
||||
result_true = await valkey_cache.get("true_key")
|
||||
result_false = await valkey_cache.get("false_key")
|
||||
|
||||
assert result_true is True
|
||||
assert result_false is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_serialize_numeric_values(
|
||||
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
|
||||
) -> None:
|
||||
"""Test serializing numeric values (int, float)."""
|
||||
mock_glide_client.get.side_effect = [
|
||||
json.dumps(42),
|
||||
json.dumps(3.14159),
|
||||
json.dumps(0),
|
||||
json.dumps(-100),
|
||||
]
|
||||
|
||||
await valkey_cache.set("int_key", 42)
|
||||
await valkey_cache.set("float_key", 3.14159)
|
||||
await valkey_cache.set("zero_key", 0)
|
||||
await valkey_cache.set("negative_key", -100)
|
||||
|
||||
assert await valkey_cache.get("int_key") == 42
|
||||
assert await valkey_cache.get("float_key") == 3.14159
|
||||
assert await valkey_cache.get("zero_key") == 0
|
||||
assert await valkey_cache.get("negative_key") == -100
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_serialize_empty_collections(
|
||||
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
|
||||
) -> None:
|
||||
"""Test serializing empty collections."""
|
||||
mock_glide_client.get.side_effect = [
|
||||
json.dumps([]),
|
||||
json.dumps({}),
|
||||
json.dumps(""),
|
||||
]
|
||||
|
||||
await valkey_cache.set("empty_list", [])
|
||||
await valkey_cache.set("empty_dict", {})
|
||||
await valkey_cache.set("empty_string", "")
|
||||
|
||||
assert await valkey_cache.get("empty_list") == []
|
||||
assert await valkey_cache.get("empty_dict") == {}
|
||||
assert await valkey_cache.get("empty_string") == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_serialize_nested_structures(
|
||||
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
|
||||
) -> None:
|
||||
"""Test serializing deeply nested structures."""
|
||||
nested_data = {
|
||||
"level1": {
|
||||
"level2": {
|
||||
"level3": [1, 2, {"level4": "deep"}]
|
||||
}
|
||||
},
|
||||
"list": [{"a": 1}, {"b": 2}]
|
||||
}
|
||||
mock_glide_client.get.return_value = json.dumps(nested_data)
|
||||
|
||||
await valkey_cache.set("nested_key", nested_data)
|
||||
result = await valkey_cache.get("nested_key")
|
||||
|
||||
assert result == nested_data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deserialize_invalid_json_returns_none(
|
||||
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
|
||||
) -> None:
|
||||
"""Test deserializing invalid JSON returns None and logs warning."""
|
||||
mock_glide_client.get.return_value = "invalid json {{"
|
||||
|
||||
with patch("crewai.memory.storage.valkey_cache._logger") as mock_logger:
|
||||
result = await valkey_cache.get("invalid_key")
|
||||
|
||||
assert result is None
|
||||
mock_logger.warning.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_serialize_unicode_strings(
|
||||
self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
|
||||
) -> None:
|
||||
"""Test serializing unicode strings."""
|
||||
unicode_data = "Hello 世界 🌍 Привет"
|
||||
mock_glide_client.get.return_value = json.dumps(unicode_data)
|
||||
|
||||
await valkey_cache.set("unicode_key", unicode_data)
|
||||
result = await valkey_cache.get("unicode_key")
|
||||
|
||||
assert result == unicode_data
|
||||
|
||||
|
||||
class TestValkeyCacheConnectionManagement:
    """Tests for ValkeyCache connection management."""

    @pytest.mark.asyncio
    async def test_lazy_client_initialization(self) -> None:
        """Test client is initialized lazily on first use."""
        cache = ValkeyCache(host="localhost", port=6379)

        # Client should be None initially
        assert cache._client is None

        # Mock GlideClient.create
        with patch("crewai.memory.storage.valkey_cache.GlideClient") as mock_glide:
            mock_client = AsyncMock()
            mock_glide.create = AsyncMock(return_value=mock_client)
            mock_client.get = AsyncMock(return_value=None)

            # First operation should initialize client
            await cache.get("test_key")

            # Client should now be initialized
            assert cache._client is not None
            mock_glide.create.assert_called_once()

    @pytest.mark.asyncio
    async def test_client_reuse_across_operations(
        self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
    ) -> None:
        """Test client is reused across multiple operations."""
        mock_glide_client.get.return_value = json.dumps("value")
        mock_glide_client.exists.return_value = 1

        # Perform multiple operations
        await valkey_cache.get("key1")
        await valkey_cache.set("key2", "value2")
        await valkey_cache.exists("key3")
        await valkey_cache.delete("key4")

        # _get_client should return the same client instance
        client1 = await valkey_cache._get_client()
        client2 = await valkey_cache._get_client()
        assert client1 is client2

    @pytest.mark.asyncio
    async def test_close_connection(
        self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
    ) -> None:
        """Test closing the client connection."""
        # Initialize client
        await valkey_cache._get_client()
        assert valkey_cache._client is not None

        # Close connection
        await valkey_cache.close()

        # Verify close was called and client is None
        mock_glide_client.close.assert_called_once()
        assert valkey_cache._client is None

    @pytest.mark.asyncio
    async def test_connection_error_raises_runtime_error(self) -> None:
        """Test connection error raises RuntimeError with descriptive message."""
        cache = ValkeyCache(host="invalid-host", port=9999)

        with patch("crewai.memory.storage.valkey_cache.GlideClient") as mock_glide:
            mock_glide.create = AsyncMock(side_effect=Exception("Connection refused"))

            with pytest.raises(RuntimeError) as exc_info:
                await cache._get_client()

            assert "Cannot connect to Valkey" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_authentication_with_password(self) -> None:
        """Test client initialization with password authentication."""
        cache = ValkeyCache(
            host="localhost",
            port=6379,
            password="secret_password"
        )

        with patch("crewai.memory.storage.valkey_cache.GlideClient") as mock_glide:
            mock_client = AsyncMock()
            mock_glide.create = AsyncMock(return_value=mock_client)

            await cache._get_client()

            # Verify GlideClient.create was called with credentials
            mock_glide.create.assert_called_once()
            config = mock_glide.create.call_args[0][0]
            assert hasattr(config, "credentials")

class TestValkeyCacheEdgeCases:
    """Tests for ValkeyCache edge cases and error conditions."""

    @pytest.mark.asyncio
    async def test_set_with_special_characters_in_key(
        self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
    ) -> None:
        """Test setting values with special characters in key."""
        special_keys = [
            "key:with:colons",
            "key/with/slashes",
            "key-with-dashes",
            "key_with_underscores",
            "key.with.dots",
        ]

        for key in special_keys:
            await valkey_cache.set(key, "value")
            mock_glide_client.set.assert_called()

    @pytest.mark.asyncio
    async def test_large_value_serialization(
        self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
    ) -> None:
        """Test serializing large values."""
        large_list = list(range(10000))
        mock_glide_client.get.return_value = json.dumps(large_list)

        await valkey_cache.set("large_key", large_list)
        result = await valkey_cache.get("large_key")

        assert result == large_list

    @pytest.mark.asyncio
    async def test_concurrent_operations(
        self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
    ) -> None:
        """Test concurrent cache operations."""
        import asyncio

        mock_glide_client.get.return_value = json.dumps("value")

        # Perform concurrent operations
        tasks = [
            valkey_cache.set(f"key{i}", f"value{i}")
            for i in range(10)
        ]
        await asyncio.gather(*tasks)

        # Verify all operations completed
        assert mock_glide_client.set.call_count == 10

    @pytest.mark.asyncio
    async def test_set_non_serializable_value_raises_type_error(
        self, valkey_cache: ValkeyCache, mock_glide_client: AsyncMock
    ) -> None:
        """Test that non-JSON-serializable values raise TypeError."""
        from datetime import datetime

        with pytest.raises(TypeError, match="not JSON-serializable"):
            await valkey_cache.set("bad_key", datetime.now())

        # Verify set was never called on the client
        mock_glide_client.set.assert_not_called()
File diff suppressed because it is too large
@@ -1,267 +0,0 @@
"""Tests for ValkeyStorage error handling."""

from __future__ import annotations

import asyncio
import json
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4

import pytest

from crewai.memory.storage.valkey_storage import ValkeyStorage
from crewai.memory.types import MemoryRecord


@pytest.fixture
def mock_glide_client() -> AsyncMock:
    """Create a mock GlideClient for testing."""
    client = AsyncMock()
    client.hset = AsyncMock(return_value=1)
    client.zrange = AsyncMock(return_value=[])
    client.zadd = AsyncMock()
    client.sadd = AsyncMock()
    client.hgetall = AsyncMock(return_value={})
    client.close = AsyncMock()
    return client


@pytest.fixture
def valkey_storage(mock_glide_client: AsyncMock) -> ValkeyStorage:
    """Create a ValkeyStorage instance with mocked client."""
    storage = ValkeyStorage(host="localhost", port=6379, db=0)

    # Mock the client creation to return our mock
    async def mock_create_client() -> AsyncMock:
        storage._client = mock_glide_client
        return mock_glide_client

    storage._get_client = mock_create_client  # type: ignore[method-assign]
    return storage

class TestSerializationErrors:
    """Tests for serialization error handling."""

    def test_serialization_error_raises_descriptive_exception(
        self, valkey_storage: ValkeyStorage
    ) -> None:
        """Test that serialization errors raise descriptive ValueError."""
        # Create a record with non-serializable metadata
        record = MemoryRecord(
            id="test-id",
            content="test content",
            scope="/test",
            categories=["test"],
            metadata={"bad_key": object()},  # Non-serializable object
            importance=0.5,
            created_at=datetime.now(),
            last_accessed=datetime.now(),
            embedding=[0.1, 0.2, 0.3],
        )

        # Should raise ValueError with descriptive message
        with pytest.raises(ValueError, match="Failed to serialize record test-id"):
            valkey_storage._record_to_dict(record)

    def test_serialization_error_includes_cause(
        self, valkey_storage: ValkeyStorage
    ) -> None:
        """Test that serialization error includes the original exception as cause."""
        # Create a mock record that will fail during JSON serialization
        # We need to bypass Pydantic validation, so we'll patch json.dumps
        record = MemoryRecord(
            id="test-id-2",
            content="test content",
            scope="/test",
            categories=["valid"],
            metadata={"key": "value"},
            importance=0.5,
            created_at=datetime.now(),
            last_accessed=datetime.now(),
            embedding=[0.1, 0.2, 0.3],
        )

        # Patch json.dumps to raise an error
        with patch("json.dumps", side_effect=TypeError("Cannot serialize")):
            with pytest.raises(ValueError) as exc_info:
                valkey_storage._record_to_dict(record)

            # Verify the exception has a cause
            assert exc_info.value.__cause__ is not None
            assert isinstance(exc_info.value.__cause__, TypeError)

class TestDeserializationErrors:
    """Tests for deserialization error handling."""

    def test_deserialization_error_logs_and_returns_none(
        self, valkey_storage: ValkeyStorage, caplog: pytest.LogCaptureFixture
    ) -> None:
        """Test that deserialization errors log error and return None."""
        # Create malformed data (missing required fields)
        malformed_data = {
            "id": "test-id",
            "content": "test content",
            # Missing scope, categories, metadata, etc.
        }

        # Should return None and log error
        result = valkey_storage._dict_to_record(malformed_data)

        assert result is None
        assert "Failed to deserialize record test-id" in caplog.text

    def test_deserialization_with_invalid_json_categories_uses_tag_fallback(
        self, valkey_storage: ValkeyStorage
    ) -> None:
        """Test that non-JSON categories fall back to TAG (comma-separated) parsing."""
        # Create data with non-JSON categories string
        data = {
            "id": "test-id-json",
            "content": "test content",
            "scope": "/test",
            "categories": "not valid json [",  # Not JSON, treated as TAG format
            "metadata": "{}",
            "importance": "0.5",
            "created_at": "2024-01-01T12:00:00",
            "last_accessed": "2024-01-01T12:00:00",
            "source": "",
            "private": "false",
        }

        result = valkey_storage._dict_to_record(data)

        # TAG fallback: comma-split produces the raw string as a single category
        assert result is not None
        assert result.id == "test-id-json"
        assert result.categories == ["not valid json ["]

    def test_deserialization_with_invalid_datetime_returns_none(
        self, valkey_storage: ValkeyStorage, caplog: pytest.LogCaptureFixture
    ) -> None:
        """Test that invalid datetime format returns None."""
        # Create data with invalid datetime
        invalid_data = {
            "id": "test-id-datetime",
            "content": "test content",
            "scope": "/test",
            "categories": '["test"]',
            "metadata": "{}",
            "importance": "0.5",
            "created_at": "not a datetime",  # Invalid datetime
            "last_accessed": "2024-01-01T12:00:00",
            "source": "",
            "private": "false",
        }

        result = valkey_storage._dict_to_record(invalid_data)

        assert result is None
        assert "Failed to deserialize record test-id-datetime" in caplog.text

    def test_deserialization_with_invalid_float_returns_none(
        self, valkey_storage: ValkeyStorage, caplog: pytest.LogCaptureFixture
    ) -> None:
        """Test that invalid float importance returns None."""
        # Create data with invalid float
        invalid_data = {
            "id": "test-id-float",
            "content": "test content",
            "scope": "/test",
            "categories": '["test"]',
            "metadata": "{}",
            "importance": "not a float",  # Invalid float
            "created_at": "2024-01-01T12:00:00",
            "last_accessed": "2024-01-01T12:00:00",
            "source": "",
            "private": "false",
        }

        result = valkey_storage._dict_to_record(invalid_data)

        assert result is None
        assert "Failed to deserialize record test-id-float" in caplog.text

    def test_deserialization_with_bytes_keys_uses_tag_fallback(
        self, valkey_storage: ValkeyStorage
    ) -> None:
        """Test that deserialization handles bytes keys with non-JSON categories via TAG fallback."""
        # Create data with bytes keys (as returned by Valkey)
        bytes_data = {
            b"id": b"test-id-bytes",
            b"content": b"test content",
            b"scope": b"/test",
            b"categories": b"invalid json [",  # Not JSON, treated as TAG format
            b"metadata": b"{}",
            b"importance": b"0.5",
            b"created_at": b"2024-01-01T12:00:00",
            b"last_accessed": b"2024-01-01T12:00:00",
        }

        result = valkey_storage._dict_to_record(bytes_data)

        # TAG fallback: comma-split produces the raw string as a single category
        assert result is not None
        assert result.id == "test-id-bytes"
        assert result.categories == ["invalid json ["]

class TestRetryBehaviorIntegration:
    """Integration tests demonstrating retry behavior patterns."""

    @pytest.mark.asyncio
    async def test_mock_client_operation_with_retry_pattern(
        self, valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test demonstrating how retry would work with client operations."""
        from glide import ClosingError

        # Mock a client operation that fails once
        mock_glide_client.hgetall.side_effect = [
            ClosingError("Connection lost"),
            {
                b"id": b"test-id",
                b"content": b"test content",
                b"scope": b"/test",
                b"categories": b'["test"]',
                b"metadata": b"{}",
                b"importance": b"0.5",
                b"created_at": b"2024-01-01T12:00:00",
                b"last_accessed": b"2024-01-01T12:00:00",
                b"source": b"",
                b"private": b"false",
                b"embedding": b"",
            },
        ]

        # First call fails, second succeeds
        with pytest.raises(ClosingError):
            await mock_glide_client.hgetall("record:test-id")

        # Second call succeeds
        result = await mock_glide_client.hgetall("record:test-id")
        assert result is not None

    @pytest.mark.asyncio
    async def test_serialization_error_not_retried(
        self, valkey_storage: ValkeyStorage
    ) -> None:
        """Test that serialization errors are not retried (they're not connection errors)."""
        # Create a record with non-serializable data
        record = MemoryRecord(
            id="test-id",
            content="test content",
            scope="/test",
            categories=["test"],
            metadata={"bad": object()},
            importance=0.5,
            created_at=datetime.now(),
            last_accessed=datetime.now(),
            embedding=[0.1, 0.2, 0.3],
        )

        # Serialization error should not be retried
        with pytest.raises(ValueError, match="Failed to serialize"):
            valkey_storage._record_to_dict(record)
@@ -1,998 +0,0 @@
"""Tests for ValkeyStorage vector search operation."""

from __future__ import annotations

import json
from datetime import datetime
from unittest.mock import AsyncMock, patch
from uuid import uuid4

import pytest

from crewai.memory.storage.valkey_storage import ValkeyStorage
from crewai.memory.types import MemoryRecord


@pytest.fixture
def mock_glide_client() -> AsyncMock:
    """Create a mock GlideClient for testing."""
    client = AsyncMock()
    client.hset = AsyncMock(return_value=1)
    client.zrange = AsyncMock(return_value=[])
    client.zadd = AsyncMock()
    client.sadd = AsyncMock()
    client.hgetall = AsyncMock(return_value={})
    client.close = AsyncMock()
    return client


@pytest.fixture
def valkey_storage(mock_glide_client: AsyncMock) -> ValkeyStorage:
    """Create a ValkeyStorage instance with mocked client."""
    storage = ValkeyStorage(host="localhost", port=6379, db=0)

    # Mock the client creation to return our mock
    async def mock_create_client() -> AsyncMock:
        storage._client = mock_glide_client
        return mock_glide_client

    storage._get_client = mock_create_client  # type: ignore[method-assign]
    return storage

def create_mock_ft_search_response(
    records: list[tuple[MemoryRecord, float]]
) -> list[int | dict[str, dict[str, str]]]:
    """Create a mock FT.SEARCH response in native dict format.

    Args:
        records: List of (MemoryRecord, score) tuples to include in response.

    Returns:
        Mock FT.SEARCH response in the native format:
        [total_count, {doc_key: {field: value, ...}, ...}]
    """
    if not records:
        return [0]

    docs: dict[str, dict[str, str]] = {}

    for record, score in records:
        doc_key = f"record:{record.id}"

        # Build field dict
        fields: dict[str, str] = {}
        fields["id"] = record.id
        fields["content"] = record.content
        fields["scope"] = record.scope
        fields["categories"] = json.dumps(record.categories)
        fields["metadata"] = json.dumps(record.metadata)
        fields["importance"] = str(record.importance)
        fields["created_at"] = record.created_at.isoformat()
        fields["last_accessed"] = record.last_accessed.isoformat()
        fields["source"] = record.source or ""
        fields["private"] = "true" if record.private else "false"

        # Add score (Valkey Search returns cosine distance, not similarity)
        # Convert similarity to distance: distance = 2 * (1 - similarity)
        distance = 2.0 * (1.0 - score)
        fields["score"] = str(distance)

        # Add embedding if present
        if record.embedding:
            fields["embedding"] = json.dumps(record.embedding)

        docs[doc_key] = fields

    return [len(records), docs]

class TestValkeyStorageVectorSearch:
    """Tests for ValkeyStorage vector search operation."""

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_with_no_filters_returns_all_records(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search with no filters returns all records."""
        # Create test records
        record1 = MemoryRecord(
            id="record-1",
            content="First test record",
            scope="/test",
            categories=["cat1"],
            metadata={"key": "value1"},
            importance=0.8,
            created_at=datetime(2024, 1, 1, 10, 0, 0),
            last_accessed=datetime(2024, 1, 1, 11, 0, 0),
            embedding=[0.1, 0.2, 0.3, 0.4],
        )
        record2 = MemoryRecord(
            id="record-2",
            content="Second test record",
            scope="/test",
            categories=["cat2"],
            metadata={"key": "value2"},
            importance=0.6,
            created_at=datetime(2024, 1, 2, 10, 0, 0),
            last_accessed=datetime(2024, 1, 2, 11, 0, 0),
            embedding=[0.2, 0.3, 0.4, 0.5],
        )

        # Mock FT.INFO to simulate index exists
        mock_ft_list.return_value = [b"memory_index"]
        # Mock FT.SEARCH to return both records
        mock_ft_search.return_value = create_mock_ft_search_response([
            (record1, 0.95),
            (record2, 0.85),
        ])

        # Perform search with no filters
        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(query_embedding, limit=10)

        # Verify ft.search was called
        mock_ft_search.assert_called_once()

        # Verify query contains only KNN part (no filters)
        call_args = mock_ft_search.call_args
        query = call_args[0][2]  # 3rd positional arg: query string
        assert "*=>[KNN 10 @embedding $BLOB AS score]" in query
        assert "@scope" not in query
        assert "@categories" not in query

        # Verify results
        assert len(results) == 2
        assert results[0][0].id == "record-1"
        assert results[0][1] == 0.95
        assert results[1][0].id == "record-2"
        assert results[1][1] == 0.85

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_with_scope_filter_only(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search with scope filter only."""
        record1 = MemoryRecord(
            id="record-1",
            content="Record in scope",
            scope="/agent/task",
            embedding=[0.1, 0.2, 0.3, 0.4],
        )

        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.9)])

        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(
            query_embedding,
            scope_prefix="/agent",
            limit=10
        )

        # Verify query contains scope filter
        call_args = mock_ft_search.call_args
        query = call_args[0][2]
        assert "(@scope:{/agent*})=>[KNN 10 @embedding $BLOB AS score]" in query

        # Verify results
        assert len(results) == 1
        assert results[0][0].id == "record-1"
        assert results[0][0].scope == "/agent/task"

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_with_category_filter_only(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search with category filter only."""
        record1 = MemoryRecord(
            id="record-1",
            content="Record with planning category",
            scope="/test",
            categories=["planning"],
            embedding=[0.1, 0.2, 0.3, 0.4],
        )

        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.88)])

        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(
            query_embedding,
            categories=["planning", "execution"],
            limit=10
        )

        # Verify query contains category filter with OR logic
        call_args = mock_ft_search.call_args
        query = call_args[0][2]
        assert "(@categories:{planning|execution})=>[KNN 10 @embedding $BLOB AS score]" in query

        # Verify results
        assert len(results) == 1
        assert results[0][0].id == "record-1"
        assert "planning" in results[0][0].categories

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_with_metadata_filter_only(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search with metadata filter only."""
        record1 = MemoryRecord(
            id="record-1",
            content="Record with metadata",
            scope="/test",
            metadata={"agent_id": "agent-1", "priority": "high"},
            embedding=[0.1, 0.2, 0.3, 0.4],
        )

        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.92)])

        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(
            query_embedding,
            metadata_filter={"agent_id": "agent-1", "priority": "high"},
            limit=10
        )

        # Verify query contains metadata filters (AND logic)
        call_args = mock_ft_search.call_args
        query = call_args[0][2]
        assert "@agent_id:{agent\\-1}" in query or "@agent_id:{agent-1}" in query
        assert "@priority:{high}" in query
        assert "=>[KNN 10 @embedding $BLOB AS score]" in query

        # Verify results
        assert len(results) == 1
        assert results[0][0].id == "record-1"
        assert results[0][0].metadata["agent_id"] == "agent-1"
        assert results[0][0].metadata["priority"] == "high"

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_with_combined_filters(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search with combined filters (scope + categories + metadata)."""
        record1 = MemoryRecord(
            id="record-1",
            content="Record matching all filters",
            scope="/agent/task",
            categories=["planning"],
            metadata={"agent_id": "agent-1"},
            embedding=[0.1, 0.2, 0.3, 0.4],
        )

        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.93)])

        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(
            query_embedding,
            scope_prefix="/agent",
            categories=["planning"],
            metadata_filter={"agent_id": "agent-1"},
            limit=10
        )

        # Verify query contains all filters
        call_args = mock_ft_search.call_args
        query = call_args[0][2]
        assert "@scope:{/agent*}" in query
        assert "@categories:{planning}" in query
        assert "@agent_id:{agent\\-1}" in query or "@agent_id:{agent-1}" in query
        assert "=>[KNN 10 @embedding $BLOB AS score]" in query

        # Verify results
        assert len(results) == 1
        assert results[0][0].id == "record-1"

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_respects_limit_parameter(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search respects limit parameter."""
        records = [
            (
                MemoryRecord(
                    id=f"record-{i}",
                    content=f"Record {i}",
                    scope="/test",
                    embedding=[0.1 * i, 0.2 * i, 0.3 * i, 0.4 * i],
                ),
                0.9 - (i * 0.1)
            )
            for i in range(1, 6)
        ]

        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = create_mock_ft_search_response(records[:3])

        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(query_embedding, limit=3)

        # Verify KNN limit in query
        call_args = mock_ft_search.call_args
        query = call_args[0][2]
        assert "=>[KNN 3 @embedding $BLOB AS score]" in query

        # Verify results respect limit
        assert len(results) == 3

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_respects_min_score_parameter(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search respects min_score parameter."""
        record1 = MemoryRecord(
            id="record-1",
            content="High score record",
            scope="/test",
            embedding=[0.1, 0.2, 0.3, 0.4],
        )
        record2 = MemoryRecord(
            id="record-2",
            content="Medium score record",
            scope="/test",
            embedding=[0.2, 0.3, 0.4, 0.5],
        )
        record3 = MemoryRecord(
            id="record-3",
            content="Low score record",
            scope="/test",
            embedding=[0.3, 0.4, 0.5, 0.6],
        )

        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = create_mock_ft_search_response([
            (record1, 0.95),
            (record2, 0.75),
            (record3, 0.55),
        ])

        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(
            query_embedding,
            limit=10,
            min_score=0.7
        )

        # Verify only records with score >= 0.7 are returned
        assert len(results) == 2
        assert results[0][0].id == "record-1"
        assert results[0][1] == 0.95
        assert results[1][0].id == "record-2"
        assert results[1][1] == 0.75

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_returns_results_ordered_by_descending_score(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search returns results ordered by descending score."""
        record1 = MemoryRecord(
            id="record-1",
            content="Medium score",
            scope="/test",
            embedding=[0.1, 0.2, 0.3, 0.4],
        )
        record2 = MemoryRecord(
            id="record-2",
            content="Highest score",
            scope="/test",
            embedding=[0.2, 0.3, 0.4, 0.5],
        )
        record3 = MemoryRecord(
            id="record-3",
            content="Lowest score",
            scope="/test",
            embedding=[0.3, 0.4, 0.5, 0.6],
        )

        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = create_mock_ft_search_response([
            (record1, 0.75),
            (record2, 0.95),
            (record3, 0.55),
        ])

        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(query_embedding, limit=10)

        # Verify results are ordered by descending score
        assert len(results) == 3
        assert results[0][0].id == "record-2"
        assert results[0][1] == 0.95
        assert results[1][0].id == "record-1"
        assert results[1][1] == 0.75
        assert results[2][0].id == "record-3"
        assert results[2][1] == 0.55

        # Verify scores are in descending order
        for i in range(len(results) - 1):
            assert results[i][1] >= results[i + 1][1]

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_with_empty_results(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search with no matching results."""
        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = [0]  # Total count = 0

        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(query_embedding, limit=10)

        # Verify empty results
        assert len(results) == 0

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_with_special_characters_in_scope(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search with special characters in scope prefix."""
        record1 = MemoryRecord(
            id="record-1",
            content="Record with special scope",
            scope="/agent:task-1",
            embedding=[0.1, 0.2, 0.3, 0.4],
        )

        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.9)])

        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(
            query_embedding,
            scope_prefix="/agent:task",
            limit=10
        )

        # Verify query contains escaped scope
        call_args = mock_ft_search.call_args
        query = call_args[0][2]
        assert "@scope:{/agent\\:task*}" in query or "@scope:{/agent:task*}" in query

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_with_special_characters_in_categories(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search with special characters in categories."""
        record1 = MemoryRecord(
            id="record-1",
            content="Record with special category",
            scope="/test",
            categories=["plan:execute"],
            embedding=[0.1, 0.2, 0.3, 0.4],
        )

        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.9)])

        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(
            query_embedding,
            categories=["plan:execute"],
            limit=10
        )

        # Verify query contains escaped category
        call_args = mock_ft_search.call_args
        query = call_args[0][2]
        assert "@categories:{plan\\:execute}" in query or "@categories:{plan:execute}" in query

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_with_numeric_metadata_values(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search with numeric metadata values."""
        record1 = MemoryRecord(
            id="record-1",
            content="Record with numeric metadata",
            scope="/test",
            metadata={"count": 42, "score": 3.14},
            embedding=[0.1, 0.2, 0.3, 0.4],
        )

        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.9)])

        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(
            query_embedding,
            metadata_filter={"count": 42, "score": 3.14},
            limit=10
        )

        # Verify query contains string-converted metadata values
        call_args = mock_ft_search.call_args
        query = call_args[0][2]
        assert "@count:{42}" in query
        assert "@score:{3" in query and "14}" in query

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_with_embedding_blob_parameter(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search passes embedding as BLOB parameter."""
        record1 = MemoryRecord(
            id="record-1",
            content="Test record",
            scope="/test",
            embedding=[0.1, 0.2, 0.3, 0.4],
        )

        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.9)])

        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(query_embedding, limit=10)

        # Verify ft.search was called with search options containing BLOB param
        call_args = mock_ft_search.call_args
        # The 4th positional arg is the FtSearchOptions
        search_options = call_args[0][3]
        # The options object should have params with BLOB
        assert search_options is not None

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_results_sorted_by_score(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search results are sorted by score (descending) automatically."""
        record1 = MemoryRecord(
            id="record-1",
            content="Test record",
            scope="/test",
            embedding=[0.1, 0.2, 0.3, 0.4],
        )

        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.9)])

        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(query_embedding, limit=10)

        # Verify ft.search was called (results are auto-sorted by vector search)
        mock_ft_search.assert_called_once()

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_with_return_fields(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search includes RETURN clause with all record fields."""
        record1 = MemoryRecord(
            id="record-1",
            content="Test record",
            scope="/test",
            embedding=[0.1, 0.2, 0.3, 0.4],
        )

        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.9)])

        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(query_embedding, limit=10)

        # Verify ft.search was called with search options containing return fields
        call_args = mock_ft_search.call_args
        search_options = call_args[0][3]
        assert search_options is not None
        # The FtSearchOptions should have return_fields set
        assert search_options.return_fields is not None
        assert len(search_options.return_fields) == 11  # All fields including score

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.VectorFieldAttributesHnsw")
    @patch("crewai.memory.storage.valkey_storage.ft.create")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_handles_valkey_search_not_available(
        self, mock_ft_list: AsyncMock, mock_ft_create: AsyncMock,
        mock_vector_attrs: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search raises error when Valkey Search module is not available."""
        # Mock FT.INFO to fail (index doesn't exist)
        mock_ft_list.return_value = []
        # Mock FT.CREATE to fail (Search module not available)
        mock_ft_create.side_effect = Exception("ERR unknown command 'ft.create'")

        query_embedding = [0.1, 0.2, 0.3, 0.4]

        with pytest.raises(RuntimeError, match="Valkey Search module is not available"):
            await valkey_storage.asearch(query_embedding, limit=10)

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_handles_ft_search_error(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search handles FT.SEARCH errors gracefully."""
        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.side_effect = Exception("ERR unknown command 'FT.SEARCH'")

        query_embedding = [0.1, 0.2, 0.3, 0.4]

        with pytest.raises(RuntimeError, match="Valkey Search module is not available"):
            await valkey_storage.asearch(query_embedding, limit=10)

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_handles_malformed_ft_search_response(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search handles malformed FT.SEARCH response gracefully."""
        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = None  # Malformed response

        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(query_embedding, limit=10)

        # Verify empty results are returned (graceful handling)
        assert len(results) == 0

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_handles_missing_score_field(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search handles missing score field in results."""
        record1 = MemoryRecord(
            id="record-1",
            content="Test record",
            scope="/test",
            embedding=[0.1, 0.2, 0.3, 0.4],
        )

        # Create mock response without score field (dict format)
        docs = {
            f"record:{record1.id}": {
                "id": record1.id,
                "content": record1.content,
                "scope": record1.scope,
                "categories": str(record1.categories),
                "metadata": str(record1.metadata),
                "importance": str(record1.importance),
                "created_at": record1.created_at.isoformat(),
                "last_accessed": record1.last_accessed.isoformat(),
                "source": record1.source or "",
                "private": "false",
                # No score field
            }
        }

        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = [1, docs]

        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(query_embedding, limit=10)

        # Verify record is returned with default score of 0.0
        assert len(results) == 1
        assert results[0][0].id == "record-1"
        assert results[0][1] == 0.0

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_filters_out_records_with_deserialization_errors(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search filters out records that fail deserialization."""
        valid_record = MemoryRecord(
            id="valid-record",
            content="Valid record",
            scope="/test",
            embedding=[0.1, 0.2, 0.3, 0.4],
        )

        # Create mock response with one valid and one invalid record (dict format)
        docs = {
            f"record:{valid_record.id}": {
                "id": valid_record.id,
                "content": valid_record.content,
                "scope": valid_record.scope,
                "categories": str(valid_record.categories),
                "metadata": str(valid_record.metadata),
                "importance": str(valid_record.importance),
                "created_at": valid_record.created_at.isoformat(),
                "last_accessed": valid_record.last_accessed.isoformat(),
                "source": valid_record.source or "",
                "private": "false",
                "score": "0.1",
            },
            "record:invalid-record": {
                "id": "invalid-record",
                # Missing content, scope, and other required fields
                "score": "0.2",
            },
        }

        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = [2, docs]

        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(query_embedding, limit=10)

        # Verify only valid record is returned
        assert len(results) == 1
        assert results[0][0].id == "valid-record"

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_converts_cosine_distance_to_similarity(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search converts Valkey Search cosine distance to similarity score."""
        record1 = MemoryRecord(
            id="record-1",
            content="Test record",
            scope="/test",
            embedding=[0.1, 0.2, 0.3, 0.4],
        )

        # Create mock response with distance score (dict format)
        docs = {
            f"record:{record1.id}": {
                "id": record1.id,
                "content": record1.content,
                "scope": record1.scope,
                "categories": str(record1.categories),
                "metadata": str(record1.metadata),
                "importance": str(record1.importance),
                "created_at": record1.created_at.isoformat(),
                "last_accessed": record1.last_accessed.isoformat(),
                "source": record1.source or "",
                "private": "false",
                "score": "0.1",  # Distance = 0.1
            }
        }

        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = [1, docs]

        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(query_embedding, limit=10)

        # Verify similarity score is correctly converted
        assert len(results) == 1
        assert results[0][0].id == "record-1"
        # Distance 0.1 -> Similarity = 1 - (0.1 / 2) = 0.95
        assert abs(results[0][1] - 0.95) < 0.01

    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    def test_search_sync_wrapper(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test that sync search wrapper calls async implementation."""
        record1 = MemoryRecord(
            id="record-1",
            content="Test record",
            scope="/test",
            embedding=[0.1, 0.2, 0.3, 0.4],
        )

        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.9)])

        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = valkey_storage.search(query_embedding, limit=10)

        # Verify ft.search was called
        assert mock_ft_search.call_count >= 1

        # Verify results
        assert len(results) == 1
        assert results[0][0].id == "record-1"
        assert results[0][1] == 0.9

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_with_multiple_categories_uses_or_logic(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search with multiple categories uses OR logic."""
        record1 = MemoryRecord(
            id="record-1",
            content="Record with one matching category",
            scope="/test",
            categories=["planning"],
            embedding=[0.1, 0.2, 0.3, 0.4],
        )

        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.9)])

        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(
            query_embedding,
            categories=["planning", "execution", "review"],
            limit=10
        )

        # Verify query contains OR logic for categories
        call_args = mock_ft_search.call_args
        query = call_args[0][2]
        assert "@categories:{planning|execution|review}" in query

        # Verify record with only one matching category is returned
        assert len(results) == 1
        assert results[0][0].id == "record-1"

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_with_multiple_metadata_filters_uses_and_logic(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search with multiple metadata filters uses AND logic."""
        record1 = MemoryRecord(
            id="record-1",
            content="Record matching all metadata",
            scope="/test",
            metadata={"agent_id": "agent-1", "priority": "high", "status": "active"},
            embedding=[0.1, 0.2, 0.3, 0.4],
        )

        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.9)])

        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(
            query_embedding,
            metadata_filter={"agent_id": "agent-1", "priority": "high", "status": "active"},
            limit=10
        )

        # Verify query contains AND logic for metadata
        call_args = mock_ft_search.call_args
        query = call_args[0][2]
        assert "@agent_id:" in query
        assert "@priority:" in query
        assert "@status:" in query

        # Verify record matching all metadata is returned
        assert len(results) == 1
        assert results[0][0].id == "record-1"

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_with_zero_limit_returns_empty(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search with limit=0 returns empty results."""
        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = [0]

        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(query_embedding, limit=0)

        # Verify empty results
        assert len(results) == 0

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_with_min_score_one_filters_all(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search with min_score=1.0 filters out all non-perfect matches."""
        record1 = MemoryRecord(
            id="record-1",
            content="High score but not perfect",
            scope="/test",
            embedding=[0.1, 0.2, 0.3, 0.4],
        )

        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = create_mock_ft_search_response([(record1, 0.99)])

        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(
            query_embedding,
            limit=10,
            min_score=1.0
        )

        # Verify all results are filtered out
        assert len(results) == 0

    @pytest.mark.asyncio
    @patch("crewai.memory.storage.valkey_storage.ft.search")
    @patch("crewai.memory.storage.valkey_storage.ft.list")
    async def test_search_with_min_score_zero_returns_all(
        self, mock_ft_list: AsyncMock, mock_ft_search: AsyncMock,
        valkey_storage: ValkeyStorage, mock_glide_client: AsyncMock
    ) -> None:
        """Test search with min_score=0.0 returns all results."""
        record1 = MemoryRecord(
            id="record-1",
            content="High score",
            scope="/test",
            embedding=[0.1, 0.2, 0.3, 0.4],
        )
        record2 = MemoryRecord(
            id="record-2",
            content="Low score",
            scope="/test",
            embedding=[0.2, 0.3, 0.4, 0.5],
        )

        mock_ft_list.return_value = [b"memory_index"]
        mock_ft_search.return_value = create_mock_ft_search_response([
            (record1, 0.95),
            (record2, 0.05),
        ])

        query_embedding = [0.1, 0.2, 0.3, 0.4]
        results = await valkey_storage.asearch(
            query_embedding,
            limit=10,
            min_score=0.0
        )

        # Verify all results are returned
        assert len(results) == 2
        assert results[0][0].id == "record-1"
        assert results[1][0].id == "record-2"
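The score assertions in these tests all follow one conversion, spelled out only in the `test_search_converts_cosine_distance_to_similarity` comment. A minimal sketch of that mapping (an assumption drawn from the test comments, not taken from the storage implementation itself):

```python
def distance_to_similarity(distance: float) -> float:
    # Valkey Search reports cosine *distance* in [0, 2]; the tests expect it
    # mapped to a similarity in [0, 1] via: similarity = 1 - distance / 2.
    return 1.0 - distance / 2.0


# Matches the test's "Distance 0.1 -> Similarity = 0.95" comment.
assert abs(distance_to_similarity(0.1) - 0.95) < 1e-9
assert distance_to_similarity(0.0) == 1.0  # identical vectors
assert distance_to_similarity(2.0) == 0.0  # opposite vectors
```

Under this mapping, a raw `"score"` field of `"0.1"` yields the `0.95` similarity the test asserts within `0.01`.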
@@ -1,115 +0,0 @@
"""Tests for embedding safety: bytes→float validators and async-safe embed_texts."""

from __future__ import annotations

import asyncio
import concurrent.futures
from unittest.mock import MagicMock

import numpy as np
import pytest

from crewai.memory.types import MemoryRecord, embed_text, embed_texts


class TestMemoryRecordEmbeddingValidator:
    """Tests for MemoryRecord.validate_embedding (bytes→list[float])."""

    def test_none_embedding_stays_none(self) -> None:
        r = MemoryRecord(content="test", embedding=None)
        assert r.embedding is None

    def test_list_of_floats_passes_through(self) -> None:
        r = MemoryRecord(content="test", embedding=[0.1, 0.2, 0.3])
        assert r.embedding == [0.1, 0.2, 0.3]

    def test_bytes_converted_to_list_float(self) -> None:
        arr = np.array([0.1, 0.2, 0.3], dtype=np.float32)
        raw_bytes = arr.tobytes()
        r = MemoryRecord(content="test", embedding=raw_bytes)
        assert r.embedding is not None
        assert len(r.embedding) == 3
        assert all(isinstance(x, float) for x in r.embedding)
        np.testing.assert_allclose(r.embedding, [0.1, 0.2, 0.3], atol=1e-6)

    def test_empty_bytes_becomes_none(self) -> None:
        r = MemoryRecord(content="test", embedding=b"")
        assert r.embedding is None

    def test_list_of_ints_converted_to_floats(self) -> None:
        r = MemoryRecord(content="test", embedding=[1, 2, 3])
        assert r.embedding == [1.0, 2.0, 3.0]
        assert all(isinstance(x, float) for x in r.embedding)

    def test_numpy_array_converted_to_list(self) -> None:
        arr = np.array([0.5, 0.6], dtype=np.float32)
        r = MemoryRecord(content="test", embedding=arr)
        assert r.embedding is not None
        assert isinstance(r.embedding, list)
        assert len(r.embedding) == 2


class TestEmbedTextsAsyncSafety:
    """Tests for embed_texts running safely in async context."""

    def test_embed_texts_sync_context(self) -> None:
        """embed_texts works in a normal sync context."""
        embedder = MagicMock(return_value=[[0.1, 0.2], [0.3, 0.4]])
        result = embed_texts(embedder, ["hello", "world"])
        assert len(result) == 2
        assert result[0] == [0.1, 0.2]
        embedder.assert_called_once()

    def test_embed_texts_empty_input(self) -> None:
        embedder = MagicMock()
        assert embed_texts(embedder, []) == []
        embedder.assert_not_called()

    def test_embed_texts_all_empty_strings(self) -> None:
        embedder = MagicMock()
        result = embed_texts(embedder, ["", " ", ""])
        assert result == [[], [], []]
        embedder.assert_not_called()

    def test_embed_texts_skips_empty_preserves_positions(self) -> None:
        embedder = MagicMock(return_value=[[0.1, 0.2]])
        result = embed_texts(embedder, ["", "hello", ""])
        assert result == [[], [0.1, 0.2], []]
        embedder.assert_called_once_with(["hello"])

    def test_embed_texts_in_async_context(self) -> None:
        """embed_texts uses thread pool when called from async context."""
        embedder = MagicMock(return_value=[[0.1, 0.2]])

        async def run() -> list[list[float]]:
            return embed_texts(embedder, ["hello"])

        result = asyncio.run(run())
        assert result == [[0.1, 0.2]]
        embedder.assert_called_once()


class TestEmbedText:
    """Tests for embed_text (single text)."""

    def test_empty_string_returns_empty(self) -> None:
        embedder = MagicMock()
        assert embed_text(embedder, "") == []
        embedder.assert_not_called()

    def test_whitespace_only_returns_empty(self) -> None:
        embedder = MagicMock()
        assert embed_text(embedder, " ") == []
        embedder.assert_not_called()

    def test_normal_text_returns_embedding(self) -> None:
        embedder = MagicMock(return_value=[[0.1, 0.2, 0.3]])
        result = embed_text(embedder, "hello")
        assert result == [0.1, 0.2, 0.3]

    def test_numpy_array_result_converted(self) -> None:
        arr = np.array([0.1, 0.2], dtype=np.float32)
        embedder = MagicMock(return_value=[arr])
        result = embed_text(embedder, "hello")
        assert isinstance(result, list)
        assert len(result) == 2
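The position-preserving behavior asserted in `test_embed_texts_skips_empty_preserves_positions` can be sketched as follows; `embed_texts_sketch` is a hypothetical stand-in written from the test expectations, not the actual `crewai.memory.types.embed_texts`:

```python
from __future__ import annotations

from typing import Callable


def embed_texts_sketch(
    embedder: Callable[[list[str]], list[list[float]]],
    texts: list[str],
) -> list[list[float]]:
    # Blank texts get [] placeholders and are never sent to the embedder,
    # while non-blank texts keep their original positions in the output.
    non_empty = [(i, t) for i, t in enumerate(texts) if t.strip()]
    out: list[list[float]] = [[] for _ in texts]
    if non_empty:
        vectors = embedder([t for _, t in non_empty])
        for (i, _), vec in zip(non_empty, vectors):
            out[i] = [float(x) for x in vec]
    return out


# Mirrors the test: empties are skipped, positions preserved.
assert embed_texts_sketch(lambda ts: [[0.1, 0.2]], ["", "hello", ""]) == [[], [0.1, 0.2], []]
```

Calling the embedder once with only the non-empty batch is what lets the mock's `assert_called_once_with(["hello"])` pass in the test above.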
@@ -1,125 +0,0 @@
"""Tests for shared cache configuration helpers."""

from __future__ import annotations

import os
from unittest.mock import patch

import pytest

from crewai.utilities.cache_config import (
    get_aiocache_config,
    parse_cache_url,
    use_valkey_cache,
)


class TestParseCacheUrl:
    """Tests for parse_cache_url()."""

    def test_returns_none_when_no_env_vars(self) -> None:
        with patch.dict(os.environ, {}, clear=True):
            assert parse_cache_url() is None

    def test_parses_valkey_url(self) -> None:
        with patch.dict(
            os.environ, {"VALKEY_URL": "redis://myhost:6380/2"}, clear=True
        ):
            result = parse_cache_url()
            assert result is not None
            assert result["host"] == "myhost"
            assert result["port"] == 6380
            assert result["db"] == 2
            assert result["password"] is None

    def test_parses_redis_url(self) -> None:
        with patch.dict(
            os.environ, {"REDIS_URL": "redis://localhost:6379/0"}, clear=True
        ):
            result = parse_cache_url()
            assert result is not None
            assert result["host"] == "localhost"
            assert result["port"] == 6379
            assert result["db"] == 0

    def test_valkey_url_takes_priority_over_redis_url(self) -> None:
        with patch.dict(
            os.environ,
            {
                "VALKEY_URL": "redis://valkey-host:6380/1",
                "REDIS_URL": "redis://redis-host:6379/0",
            },
            clear=True,
        ):
            result = parse_cache_url()
            assert result is not None
            assert result["host"] == "valkey-host"
            assert result["port"] == 6380

    def test_parses_password(self) -> None:
        with patch.dict(
            os.environ,
            {"VALKEY_URL": "redis://:s3cret@myhost:6379/0"},
            clear=True,
        ):
            result = parse_cache_url()
            assert result is not None
            assert result["password"] == "s3cret"

    def test_defaults_for_minimal_url(self) -> None:
        with patch.dict(
            os.environ, {"VALKEY_URL": "redis://myhost"}, clear=True
        ):
            result = parse_cache_url()
            assert result is not None
            assert result["host"] == "myhost"
            assert result["port"] == 6379
            assert result["db"] == 0
            assert result["password"] is None

    def test_non_numeric_db_path_defaults_to_zero(self) -> None:
        with patch.dict(
            os.environ, {"VALKEY_URL": "redis://myhost:6379/mydb"}, clear=True
        ):
            result = parse_cache_url()
            assert result is not None
            assert result["db"] == 0


class TestGetAiocacheConfig:
    """Tests for get_aiocache_config()."""

    def test_returns_memory_cache_when_no_url(self) -> None:
        with patch.dict(os.environ, {}, clear=True):
            config = get_aiocache_config()
            assert config["default"]["cache"] == "aiocache.SimpleMemoryCache"

    def test_returns_redis_cache_when_url_set(self) -> None:
        with patch.dict(
            os.environ, {"VALKEY_URL": "redis://myhost:6380/2"}, clear=True
        ):
            config = get_aiocache_config()
            assert config["default"]["cache"] == "aiocache.RedisCache"
            assert config["default"]["endpoint"] == "myhost"
            assert config["default"]["port"] == 6380
            assert config["default"]["db"] == 2


class TestUseValkeyCache:
    """Tests for use_valkey_cache()."""

    def test_returns_false_when_not_set(self) -> None:
        with patch.dict(os.environ, {}, clear=True):
            assert use_valkey_cache() is False

    def test_returns_true_when_set(self) -> None:
        with patch.dict(
            os.environ, {"VALKEY_URL": "redis://localhost:6379"}, clear=True
        ):
            assert use_valkey_cache() is True

    def test_returns_false_when_only_redis_url_set(self) -> None:
        with patch.dict(
            os.environ, {"REDIS_URL": "redis://localhost:6379"}, clear=True
        ):
            assert use_valkey_cache() is False
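Taken together, the `TestParseCacheUrl` assertions pin down a small URL-parsing contract. A hypothetical implementation satisfying them (name and structure are assumptions; the real `crewai.utilities.cache_config` may differ):

```python
from __future__ import annotations

import os
from urllib.parse import urlparse


def parse_cache_url_sketch() -> dict | None:
    # VALKEY_URL takes priority over REDIS_URL, per the priority test;
    # None when neither is set.
    url = os.environ.get("VALKEY_URL") or os.environ.get("REDIS_URL")
    if not url:
        return None
    parsed = urlparse(url)
    path = parsed.path.lstrip("/")
    return {
        "host": parsed.hostname or "localhost",
        "port": parsed.port or 6379,  # default port when URL omits it
        "db": int(path) if path.isdigit() else 0,  # non-numeric path -> 0
        "password": parsed.password,
    }
```

`urlparse` already exposes `hostname`, `port`, and `password` separately, which is why the `redis://:s3cret@myhost:6379/0` case needs no extra handling.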
@@ -205,8 +205,6 @@ override-dependencies = [
    "gitpython>=3.1.50,<4",
    "langsmith>=0.7.31,<0.8",
    "authlib>=1.6.11",
    # scrapegraph-py 2.x removed Client class; pin until upstream fixes type ignores
    "scrapegraph-py>=1.46.0,<2",
]

[tool.uv.workspace]
46 uv.lock generated
@@ -13,7 +13,7 @@ resolution-markers = [
]

[options]
exclude-newer = "2026-05-08T20:07:25.621408Z"
exclude-newer = "2026-05-08T16:33:02.834109Z"
exclude-newer-span = "P3D"

[manifest]
@@ -38,7 +38,6 @@ overrides = [
    { name = "pypdf", specifier = ">=6.10.2,<7" },
    { name = "python-multipart", specifier = ">=0.0.27,<1" },
    { name = "rich", specifier = ">=13.7.1" },
    { name = "scrapegraph-py", specifier = ">=1.46.0,<2" },
    { name = "transformers", marker = "python_full_version >= '3.10'", specifier = ">=5.4.0" },
    { name = "urllib3", specifier = ">=2.7.0" },
    { name = "uv", specifier = ">=0.11.6,<1" },
@@ -1366,9 +1365,6 @@ qdrant-edge = [
tools = [
    { name = "crewai-tools" },
]
valkey = [
    { name = "valkey-glide" },
]
voyageai = [
    { name = "voyageai" },
]
@@ -1430,10 +1426,9 @@ requires-dist = [
    { name = "tokenizers", specifier = ">=0.21,<1" },
    { name = "tomli", specifier = "~=2.0.2" },
    { name = "tomli-w", specifier = "~=1.1.0" },
    { name = "valkey-glide", marker = "extra == 'valkey'", specifier = ">=1.3.0" },
    { name = "voyageai", marker = "extra == 'voyageai'", specifier = "~=0.3.5" },
]
provides-extras = ["a2a", "anthropic", "aws", "azure-ai-inference", "bedrock", "docling", "embeddings", "file-processing", "google-genai", "litellm", "mem0", "openpyxl", "pandas", "qdrant", "qdrant-edge", "tools", "valkey", "voyageai", "watson"]
provides-extras = ["a2a", "anthropic", "aws", "azure-ai-inference", "bedrock", "docling", "embeddings", "file-processing", "google-genai", "litellm", "mem0", "openpyxl", "pandas", "qdrant", "qdrant-edge", "tools", "voyageai", "watson"]

[[package]]
name = "crewai-cli"
@@ -9538,43 +9533,6 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/fa/6e/3e955517e22cbdd565f2f8b2e73d52528b14b8bcfdb04f62466b071de847/validators-0.35.0-py3-none-any.whl", hash = "sha256:e8c947097eae7892cb3d26868d637f79f47b4a0554bc6b80065dfe5aac3705dd", size = 44712, upload-time = "2025-05-01T05:42:04.203Z" },
]

[[package]]
name = "valkey-glide"
version = "2.0.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "anyio" },
    { name = "protobuf" },
    { name = "typing-extensions", marker = "python_full_version < '3.11'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/32/35/fb0401c4bc7be748d937e95213786d21d9e56767b3ad816db5bad6f92c01/valkey_glide-2.0.1.tar.gz", hash = "sha256:4f9c62a88aedffd725cced7d28a9488b27e3f675d1a5294b4962624e97d346c4", size = 1026255, upload-time = "2025-06-20T01:08:15.861Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/44/a3/bf5ff3841538d0bb337371e073dc2c0e93f748f7f8b10a44806f36ab5fa1/valkey_glide-2.0.1-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:b3307934b76557b18ac559f327592cc09fc895fc653ba46010dd6d70fb6239dc", size = 5074638, upload-time = "2025-06-20T01:07:30.16Z" },
    { url = "https://files.pythonhosted.org/packages/0f/c4/20b66dced96bdca81aa294b39bc03018ed22628c52076752e8d1d3540a7d/valkey_glide-2.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6b83d34e2e723e97c41682479b0dce5882069066e808316292b363855992b449", size = 4750261, upload-time = "2025-06-20T01:07:32.452Z" },
    { url = "https://files.pythonhosted.org/packages/53/58/6440e66bde8963d86bc3c44d88f993059f2a9d7ebdb3256a695d035cff50/valkey_glide-2.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1baaf14d09d464ae645be5bdb5dc6b8a38b7eacf22f9dcb2907200c74fbdcdd3", size = 4767755, upload-time = "2025-06-20T01:07:33.86Z" },
    { url = "https://files.pythonhosted.org/packages/3b/69/dd5c350ce4d2cadde0d83beb601f05e1e62622895f268135e252e8bfc307/valkey_glide-2.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4427e7b4d54c9de289a35032c19d5956f94376f5d4335206c5ac4524cbd1c64a", size = 5094507, upload-time = "2025-06-20T01:07:35.349Z" },
    { url = "https://files.pythonhosted.org/packages/b5/dd/0dd6614e09123a5bd7273bf1159c958d1ea65e7decc2190b225d212e0cb9/valkey_glide-2.0.1-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:6379582d6fbd817697fb119274e37d397db450103cd15d4bd71e555e6d88fb6b", size = 5072939, upload-time = "2025-06-20T01:07:36.948Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c6/04/986188e407231a5f0bfaf31f31b68e3605ab66f4f4c656adfbb0345669d9/valkey_glide-2.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:0f1c0fe003026d8ae172369e0eb2337cbff16f41d4c085332487d6ca2e5282e6", size = 4750491, upload-time = "2025-06-20T01:07:38.659Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ac/fb/2f5cec71ae51c464502a892b6825426cd74a2c325827981726e557926c94/valkey_glide-2.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82c5f33598e50bcfec6fc924864931f3c6e30cd327a9c9562e1c7ac4e17e79fd", size = 4767597, upload-time = "2025-06-20T01:07:40.091Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/3a/31/851a1a734fe5da5d520106fcfd824e4da09c3be8a0a2123bb4b1980db1ea/valkey_glide-2.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:79039a9dc23bb074680f171c12b36b3322357a0af85125534993e81a619dce21", size = 5094383, upload-time = "2025-06-20T01:07:41.329Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fc/6d/1e7b432cbc02fe63e7496b984b7fc830fb7de388c877b237e0579a6300fc/valkey_glide-2.0.1-cp312-cp312-macosx_10_7_x86_64.whl", hash = "sha256:f55ec8968b0fde364a5b3399be34b89dcb9068994b5cd384e20db0773ad12723", size = 5075024, upload-time = "2025-06-20T01:07:42.917Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ca/39/6e9f83970590d17d19f596e1b3a366d39077624888e3dd709309efc67690/valkey_glide-2.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:21598f49313912ad27dc700d7b13a3b4bfed7ed9dffad207235cac7d218f4966", size = 4748418, upload-time = "2025-06-20T01:07:44.64Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/98/0e/91335c13dc8e7ceb95063234c16010b46e2dd874a2edef62dea155081647/valkey_glide-2.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f662285146529328e2b5a0a7047f699339b4e0d250eb1f252b15c9befa0dea05", size = 4767264, upload-time = "2025-06-20T01:07:46.185Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5f/94/ee4d9d441f83fec1464d9f4e52f7940bdd2aeb917589e6abd57498880876/valkey_glide-2.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3939aaa8411fcbba00cb1ff7c7ba73f388bb1deca919972f65cba7eda1d5fa95", size = 5093543, upload-time = "2025-06-20T01:07:47.345Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ed/7e/257a2e4b61ac29d5923f89bad5fe62be7b4a19e7bec78d191af3ce77aa39/valkey_glide-2.0.1-cp313-cp313-macosx_10_7_x86_64.whl", hash = "sha256:c49b53011a05b5820d0c660ee5c76574183b413a54faa33cf5c01ce77164d9c8", size = 5073114, upload-time = "2025-06-20T01:07:48.885Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/20/14/a8a470679953980af7eac3ccb09638f2a76d4547116d48cbc69ae6f25080/valkey_glide-2.0.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:3a23572b83877537916ba36ad0a6b2fd96581534f0bc67ef8f8498bf4dbb2b40", size = 4747717, upload-time = "2025-06-20T01:07:50.092Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9f/49/f168dd0c778d9f6ff1be70d5d3bad7a86928fee563de7de5f4f575eddfd8/valkey_glide-2.0.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:943a2c4a5c38b8a6b53281201d5a4997ec454a6fdda72d27050eeb6aaef12afb", size = 4767128, upload-time = "2025-06-20T01:07:51.306Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/43/be/68961b14ea133d1792ce50f6df1753848b5377c3e06a8dbe4e39188a549a/valkey_glide-2.0.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d770ec581acc59d5597e7ccaac37aee7e3b5e716a77a7fa44e2967db3a715f53", size = 5093522, upload-time = "2025-06-20T01:07:52.546Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/51/2e/ad8595ffe84317385d52ceab8de1e9ef06a4da6b81ca8cd61b7961923de4/valkey_glide-2.0.1-pp310-pypy310_pp73-macosx_10_7_x86_64.whl", hash = "sha256:d4a9ccfe2b190c90622849dab62f9468acf76a282719a1245d272b649e7c12d1", size = 5074539, upload-time = "2025-06-20T01:07:59.87Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/db/e5/2122541c7a64706f3631655209bb0b13723fb99db3c190d9a792b4e7d494/valkey_glide-2.0.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:9aa004077b82f64b23ea0d38d948b5116c23f7228dae3a5b4fcfa1799f8ff7de", size = 4753222, upload-time = "2025-06-20T01:08:01.376Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6c/13/cd9a20988a820ff61b127d3f850887b28bb734daf2c26d512d8e4c2e8e9e/valkey_glide-2.0.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:631a7a0e2045f7e5e3706e1903beeddf381a6529e318c27230798f4382579e4f", size = 4771530, upload-time = "2025-06-20T01:08:02.6Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c7/fc/047e89cc01b4cc71db1b6b8160d3b5d050097b408028022c002351238641/valkey_glide-2.0.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c5ed905fb62368c9bc6aef9df8d66269ef51f968dc527da4d7c956927382c1d", size = 5091242, upload-time = "2025-06-20T01:08:04.111Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1c/9e/68790c1a263f3a0094d67d0109be34631f6f79c2fbce5ced7e33a65ad363/valkey_glide-2.0.1-pp311-pypy311_pp73-macosx_10_7_x86_64.whl", hash = "sha256:53da3cc47c8d946ac76ecc4b468a469d3486778833a59162ea69aa7ce70cbb27", size = 5072793, upload-time = "2025-06-20T01:08:05.562Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1f/ae/a935af65ae4069d76c69f28f6bfb4533da8b89f7fc418beb7a1482cdd9ee/valkey_glide-2.0.1-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:e526a7d718cdd299d6b03091c12dcc15cd02ff22fe420f253341a4891c50824d", size = 4753435, upload-time = "2025-06-20T01:08:07.149Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/3b/c2/c91d753a89dd87dce2fc8932cfbe174c7a1226c657b3cd64c063f21d4fe6/valkey_glide-2.0.1-pp311-pypy311_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8d3345ea2adf6f745733fa5157d8709bcf5ffbb2674391aeebd8f166a37cbc96", size = 4771401, upload-time = "2025-06-20T01:08:08.359Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/00/fe/ad83cfc2ac87bf6bad2b75fa64fca5a6dd54568c1de551d36d369e07f948/valkey_glide-2.0.1-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d1c5fff0f12d2aa4277ddc335035b2c8e12bb11243c1a0f3c35071f4a8b11064", size = 5091360, upload-time = "2025-06-20T01:08:09.622Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "vcrpy"
|
||||
version = "7.0.0"
|
||||
|
||||