mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-01-23 15:18:14 +00:00)

feat: upgrade upload cache to aiocache with atexit cleanup
@@ -177,10 +177,4 @@ def _get_providers_from_cache(cache: UploadCache) -> set[str]:
     Returns:
         Set of provider names.
     """
-    providers: set[str] = set()
-
-    with cache._lock:
-        for _, provider in cache._cache.keys():
-            providers.add(provider)
-
-    return providers
+    return cache.get_providers()
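The refactor above drops the helper's reach into the cache's private `_lock` and `_cache` attributes in favor of the new public accessor. A minimal sketch of the resulting behavior, using the sync wrapper API introduced later in this diff (the hash and IDs are illustrative placeholders):

```python
from crewai.utilities.files.upload_cache import UploadCache

cache = UploadCache()
cache.set_by_hash(
    file_hash="abc123",        # placeholder hash
    content_type="image/png",
    provider="gemini",
    file_id="file-1",
)

print(cache.get_providers())   # {'gemini'}
```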
@@ -1,24 +1,36 @@
-"""Cache for tracking uploaded files to avoid redundant uploads."""
+"""Cache for tracking uploaded files using aiocache."""
 
-from dataclasses import dataclass, field
-from datetime import datetime, timedelta, timezone
+from __future__ import annotations
+
+import asyncio
+import atexit
+import builtins
+from dataclasses import dataclass
+from datetime import datetime, timezone
 import hashlib
 import logging
-import threading
+from typing import TYPE_CHECKING, Any
 
-from crewai.utilities.files.content_types import (
-    AudioFile,
-    File,
-    ImageFile,
-    PDFFile,
-    TextFile,
-    VideoFile,
-)
+from aiocache import Cache  # type: ignore[import-untyped]
+from aiocache.serializers import PickleSerializer  # type: ignore[import-untyped]
+
+
+if TYPE_CHECKING:
+    from crewai.utilities.files.content_types import (
+        AudioFile,
+        File,
+        ImageFile,
+        PDFFile,
+        TextFile,
+        VideoFile,
+    )
+
+    FileInput = AudioFile | File | ImageFile | PDFFile | TextFile | VideoFile
 
 
 logger = logging.getLogger(__name__)
 
-FileInput = AudioFile | File | ImageFile | PDFFile | TextFile | VideoFile
+DEFAULT_TTL_SECONDS = 24 * 60 * 60  # 24 hours
 
 
 @dataclass
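For readers new to aiocache, here is a standalone sketch of the primitives this module now builds on: a namespaced cache with a pickle serializer and TTL-bearing async `set`/`get`. The key and value are illustrative:

```python
import asyncio

from aiocache import Cache
from aiocache.serializers import PickleSerializer


async def main() -> None:
    # Same configuration style the module uses for its memory backend.
    cache = Cache(Cache.MEMORY, serializer=PickleSerializer(), namespace="demo")

    await cache.set("upload:gemini:abc123", {"file_id": "file-1"}, ttl=60)
    print(await cache.get("upload:gemini:abc123"))  # {'file_id': 'file-1'}
    print(await cache.get("missing-key"))           # None


asyncio.run(main())
```

`PickleSerializer` matters here because the values stored are `CachedUpload` dataclass instances, not plain strings.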
@@ -42,44 +54,83 @@ class CachedUpload:
     expires_at: datetime | None = None
 
     def is_expired(self) -> bool:
-        """Check if this cached upload has expired.
-
-        Returns:
-            True if expired, False otherwise.
-        """
+        """Check if this cached upload has expired."""
         if self.expires_at is None:
             return False
         return datetime.now(timezone.utc) >= self.expires_at
 
 
-@dataclass
-class UploadCache:
-    """Thread-safe cache for tracking uploaded files.
-
-    Uses file content hash and provider as composite key to avoid
-    uploading the same file multiple times.
+def _make_key(file_hash: str, provider: str) -> str:
+    """Create a cache key from file hash and provider."""
+    return f"upload:{provider}:{file_hash}"
+
+
+def _compute_file_hash(file: FileInput) -> str:
+    """Compute SHA-256 hash of file content."""
+    content = file.read()
+    return hashlib.sha256(content).hexdigest()
+
+
+class UploadCache:
+    """Async cache for tracking uploaded files using aiocache.
+
+    Supports in-memory caching by default, with optional Redis backend
+    for distributed setups.
 
     Attributes:
-        default_ttl: Default time-to-live for cached entries.
+        ttl: Default time-to-live in seconds for cached entries.
+        namespace: Cache namespace for isolation.
     """
 
-    default_ttl: timedelta = field(default_factory=lambda: timedelta(hours=24))
-    _cache: dict[tuple[str, str], CachedUpload] = field(default_factory=dict)
-    _lock: threading.Lock = field(default_factory=threading.Lock)
-
-    def _compute_hash(self, file: FileInput) -> str:
-        """Compute a hash of file content for cache key.
+    def __init__(
+        self,
+        ttl: int = DEFAULT_TTL_SECONDS,
+        namespace: str = "crewai_uploads",
+        cache_type: str = "memory",
+        **cache_kwargs: Any,
+    ) -> None:
+        """Initialize the upload cache.
 
         Args:
-            file: The file to hash.
-
-        Returns:
-            SHA-256 hash of the file content.
+            ttl: Default TTL in seconds.
+            namespace: Cache namespace.
+            cache_type: Backend type ("memory" or "redis").
+            **cache_kwargs: Additional args for cache backend.
         """
-        content = file.source.read()
-        return hashlib.sha256(content).hexdigest()
+        self.ttl = ttl
+        self.namespace = namespace
+        self._provider_keys: dict[str, set[str]] = {}
 
-    def get(self, file: FileInput, provider: str) -> CachedUpload | None:
+        if cache_type == "redis":
+            self._cache = Cache(
+                Cache.REDIS,
+                serializer=PickleSerializer(),
+                namespace=namespace,
+                **cache_kwargs,
+            )
+        else:
+            self._cache = Cache(
+                Cache.MEMORY,
+                serializer=PickleSerializer(),
+                namespace=namespace,
+            )
+
+    def _track_key(self, provider: str, key: str) -> None:
+        """Track a key for a provider (for cleanup)."""
+        if provider not in self._provider_keys:
+            self._provider_keys[provider] = set()
+        self._provider_keys[provider].add(key)
+
+    def _untrack_key(self, provider: str, key: str) -> None:
+        """Remove key tracking for a provider."""
+        if provider in self._provider_keys:
+            self._provider_keys[provider].discard(key)
+
+    # -------------------------------------------------------------------------
+    # Async methods (primary interface)
+    # -------------------------------------------------------------------------
+
+    async def aget(self, file: FileInput, provider: str) -> CachedUpload | None:
         """Get a cached upload for a file.
 
         Args:
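A sketch of constructing the cache against either backend. The Redis keyword arguments shown (`endpoint`, `port`) are aiocache's and are an assumption about your deployment; the host value is a placeholder:

```python
from crewai.utilities.files.upload_cache import UploadCache

# Default: process-local, in-memory cache with a one-hour TTL.
local_cache = UploadCache(ttl=3600)

# Distributed setup: extra kwargs pass through to the Redis backend.
shared_cache = UploadCache(
    cache_type="redis",
    endpoint="127.0.0.1",  # placeholder Redis host
    port=6379,
)
```

Note that `_provider_keys` stays process-local even with Redis, so per-provider listing and exit cleanup only cover uploads recorded by the current process.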
@@ -89,21 +140,10 @@ class UploadCache:
         Returns:
             Cached upload if found and not expired, None otherwise.
         """
-        file_hash = self._compute_hash(file)
-        key = (file_hash, provider)
-
-        with self._lock:
-            cached = self._cache.get(key)
-            if cached is None:
-                return None
-
-            if cached.is_expired():
-                del self._cache[key]
-                return None
-
-            return cached
+        file_hash = _compute_file_hash(file)
+        return await self.aget_by_hash(file_hash, provider)
 
-    def get_by_hash(self, file_hash: str, provider: str) -> CachedUpload | None:
+    async def aget_by_hash(self, file_hash: str, provider: str) -> CachedUpload | None:
         """Get a cached upload by file hash.
 
         Args:
@@ -113,20 +153,20 @@ class UploadCache:
         Returns:
             Cached upload if found and not expired, None otherwise.
         """
-        key = (file_hash, provider)
-
-        with self._lock:
-            cached = self._cache.get(key)
-            if cached is None:
-                return None
-
-            if cached.is_expired():
-                del self._cache[key]
-                return None
-
-            return cached
+        key = _make_key(file_hash, provider)
+        result = await self._cache.get(key)
+
+        if result is None:
+            return None
+        if isinstance(result, CachedUpload):
+            if result.is_expired():
+                await self._cache.delete(key)
+                self._untrack_key(provider, key)
+                return None
+            return result
+        return None
 
-    def set(
+    async def aset(
         self,
         file: FileInput,
         provider: str,
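The async getters compose: `aget` hashes the file and delegates to `aget_by_hash`, which folds misses, expired entries, and unexpected payloads into `None`. A usage sketch (the PNG bytes are copied from this commit's tests):

```python
import asyncio

from crewai.utilities.files import FileBytes, ImageFile
from crewai.utilities.files.upload_cache import UploadCache

MINIMAL_PNG = (
    b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x08\x00\x00\x00\x08"
    b"\x01\x00\x00\x00\x00\xf9Y\xab\xcd\x00\x00\x00\nIDATx\x9cc`\x00\x00"
    b"\x00\x02\x00\x01\xe2!\xbc3\x00\x00\x00\x00IEND\xaeB`\x82"
)


async def main() -> None:
    cache = UploadCache()
    file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="pic.png"))

    if await cache.aget(file, "gemini") is None:
        # Cache miss: upload to the provider, then record the result.
        await cache.aset(file, provider="gemini", file_id="file-123")

    cached = await cache.aget(file, "gemini")
    print(cached.file_id if cached else None)  # file-123


asyncio.run(main())
```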
@@ -146,26 +186,17 @@ class UploadCache:
         Returns:
             The created cache entry.
         """
-        file_hash = self._compute_hash(file)
-        key = (file_hash, provider)
-
-        now = datetime.now(timezone.utc)
-        cached = CachedUpload(
-            file_id=file_id,
-            provider=provider,
-            file_uri=file_uri,
+        file_hash = _compute_file_hash(file)
+        return await self.aset_by_hash(
+            file_hash=file_hash,
             content_type=file.content_type,
-            uploaded_at=now,
+            provider=provider,
+            file_id=file_id,
+            file_uri=file_uri,
             expires_at=expires_at,
         )
 
-        with self._lock:
-            self._cache[key] = cached
-
-        logger.debug(f"Cached upload: {file_id} for provider {provider}")
-        return cached
-
-    def set_by_hash(
+    async def aset_by_hash(
         self,
         file_hash: str,
         content_type: str,
@@ -187,9 +218,9 @@ class UploadCache:
         Returns:
             The created cache entry.
         """
-        key = (file_hash, provider)
-
+        key = _make_key(file_hash, provider)
         now = datetime.now(timezone.utc)
+
         cached = CachedUpload(
             file_id=file_id,
             provider=provider,
@@ -199,13 +230,16 @@ class UploadCache:
             expires_at=expires_at,
         )
 
-        with self._lock:
-            self._cache[key] = cached
+        ttl = self.ttl
+        if expires_at is not None:
+            ttl = max(0, int((expires_at - now).total_seconds()))
+
+        await self._cache.set(key, cached, ttl=ttl)
+        self._track_key(provider, key)
 
         logger.debug(f"Cached upload: {file_id} for provider {provider}")
         return cached
 
-    def remove(self, file: FileInput, provider: str) -> bool:
+    async def aremove(self, file: FileInput, provider: str) -> bool:
         """Remove a cached upload.
 
         Args:
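The write path converts the caller's absolute `expires_at` into the relative TTL that aiocache expects, clamped at zero so an already-past expiry evicts immediately rather than producing a negative TTL. The arithmetic in isolation:

```python
from datetime import datetime, timedelta, timezone

now = datetime.now(timezone.utc)

# Two hours from now -> 7200 seconds of TTL.
expires_at = now + timedelta(hours=2)
print(max(0, int((expires_at - now).total_seconds())))  # 7200

# Five minutes in the past -> clamped to 0, never negative.
expired_at = now - timedelta(minutes=5)
print(max(0, int((expired_at - now).total_seconds())))  # 0
```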
@@ -215,16 +249,16 @@ class UploadCache:
         Returns:
             True if entry was removed, False if not found.
         """
-        file_hash = self._compute_hash(file)
-        key = (file_hash, provider)
+        file_hash = _compute_file_hash(file)
+        key = _make_key(file_hash, provider)
 
-        with self._lock:
-            if key in self._cache:
-                del self._cache[key]
-                return True
-            return False
+        result = await self._cache.delete(key)
+        removed = bool(result > 0 if isinstance(result, int) else result)
+        if removed:
+            self._untrack_key(provider, key)
+        return removed
 
-    def remove_by_file_id(self, file_id: str, provider: str) -> bool:
+    async def aremove_by_file_id(self, file_id: str, provider: str) -> bool:
         """Remove a cached upload by file ID.
 
         Args:
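`aremove` has to normalize the backend's `delete` result, since an aiocache backend may report the number of deleted keys as an int rather than a bool. The guard in isolation, as a small pure function:

```python
def normalize_delete_result(result: int | bool) -> bool:
    """Treat a deleted-key count (0 or 1) or a bare bool uniformly."""
    return bool(result > 0 if isinstance(result, int) else result)


assert normalize_delete_result(1) is True
assert normalize_delete_result(0) is False
assert normalize_delete_result(True) is True
```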
@@ -234,14 +268,18 @@ class UploadCache:
         Returns:
             True if entry was removed, False if not found.
         """
-        with self._lock:
-            for key, cached in list(self._cache.items()):
-                if cached.file_id == file_id and cached.provider == provider:
-                    del self._cache[key]
-                    return True
-            return False
+        if provider not in self._provider_keys:
+            return False
+
+        for key in list(self._provider_keys[provider]):
+            cached = await self._cache.get(key)
+            if isinstance(cached, CachedUpload) and cached.file_id == file_id:
+                await self._cache.delete(key)
+                self._untrack_key(provider, key)
+                return True
+        return False
 
-    def clear_expired(self) -> int:
+    async def aclear_expired(self) -> int:
         """Remove all expired entries from the cache.
 
         Returns:
@@ -249,33 +287,35 @@ class UploadCache:
         """
         removed = 0
 
-        with self._lock:
-            for key in list(self._cache.keys()):
-                if self._cache[key].is_expired():
-                    del self._cache[key]
-                    removed += 1
+        for provider, keys in list(self._provider_keys.items()):
+            for key in list(keys):
+                cached = await self._cache.get(key)
+                if cached is None or (
+                    isinstance(cached, CachedUpload) and cached.is_expired()
+                ):
+                    await self._cache.delete(key)
+                    self._untrack_key(provider, key)
+                    removed += 1
 
         if removed > 0:
             logger.debug(f"Cleared {removed} expired cache entries")
 
         return removed
 
-    def clear(self) -> int:
+    async def aclear(self) -> int:
         """Clear all entries from the cache.
 
         Returns:
             Number of entries cleared.
         """
-        with self._lock:
-            count = len(self._cache)
-            self._cache.clear()
+        count = sum(len(keys) for keys in self._provider_keys.values())
+        await self._cache.clear(namespace=self.namespace)
+        self._provider_keys.clear()
 
         if count > 0:
             logger.debug(f"Cleared {count} cache entries")
 
         return count
 
-    def get_all_for_provider(self, provider: str) -> list[CachedUpload]:
+    async def aget_all_for_provider(self, provider: str) -> list[CachedUpload]:
         """Get all cached uploads for a provider.
 
         Args:
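Because aiocache exposes no key enumeration, the sweep in `aclear_expired` walks the module's own `_provider_keys` index and also counts entries the backend has already dropped (`cached is None`), keeping the index honest. A sketch of wiring it into a periodic background task; the interval is an illustrative choice:

```python
import asyncio

from crewai.utilities.files.upload_cache import UploadCache


async def sweep_periodically(cache: UploadCache, interval: float = 300.0) -> None:
    """Illustrative background task: evict stale uploads every five minutes."""
    while True:
        removed = await cache.aclear_expired()
        if removed:
            print(f"evicted {removed} expired upload entries")
        await asyncio.sleep(interval)
```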
@@ -284,14 +324,171 @@ class UploadCache:
         Returns:
             List of cached uploads for the provider.
         """
-        with self._lock:
-            return [
-                cached
-                for (_, p), cached in self._cache.items()
-                if p == provider and not cached.is_expired()
-            ]
+        if provider not in self._provider_keys:
+            return []
+
+        results: list[CachedUpload] = []
+        for key in list(self._provider_keys[provider]):
+            cached = await self._cache.get(key)
+            if isinstance(cached, CachedUpload) and not cached.is_expired():
+                results.append(cached)
+        return results
+
+    # -------------------------------------------------------------------------
+    # Sync wrappers (convenience)
+    # -------------------------------------------------------------------------
+
+    def _run_sync(self, coro: Any) -> Any:
+        """Run an async coroutine from sync context."""
+        try:
+            loop = asyncio.get_running_loop()
+        except RuntimeError:
+            loop = None
+
+        if loop is not None and loop.is_running():
+            import concurrent.futures
+
+            with concurrent.futures.ThreadPoolExecutor() as pool:
+                future = pool.submit(asyncio.run, coro)
+                return future.result()
+        return asyncio.run(coro)
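`_run_sync` is the bridge that keeps the old synchronous API working: with no event loop in the current thread, `asyncio.run` suffices; if a loop is already running, the coroutine is handed to a fresh loop on a worker thread so the running loop is never re-entered. The same pattern as a standalone function:

```python
import asyncio
import concurrent.futures
from collections.abc import Coroutine
from typing import Any


def run_sync(coro: Coroutine[Any, Any, Any]) -> Any:
    """Run a coroutine to completion from synchronous code."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop in this thread: safe to start one here.
        return asyncio.run(coro)

    # A loop is already running: run the coroutine on a worker thread.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()


async def _demo() -> str:
    return "done"


print(run_sync(_demo()))  # done
```

The trade-off: each wrapped call pays for a worker thread and a fresh event loop when invoked from async code, so the async `a*` methods remain the preferred interface.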
+    def get(self, file: FileInput, provider: str) -> CachedUpload | None:
+        """Sync wrapper for aget."""
+        result: CachedUpload | None = self._run_sync(self.aget(file, provider))
+        return result
+
+    def get_by_hash(self, file_hash: str, provider: str) -> CachedUpload | None:
+        """Sync wrapper for aget_by_hash."""
+        result: CachedUpload | None = self._run_sync(
+            self.aget_by_hash(file_hash, provider)
+        )
+        return result
+
+    def set(
+        self,
+        file: FileInput,
+        provider: str,
+        file_id: str,
+        file_uri: str | None = None,
+        expires_at: datetime | None = None,
+    ) -> CachedUpload:
+        """Sync wrapper for aset."""
+        result: CachedUpload = self._run_sync(
+            self.aset(file, provider, file_id, file_uri, expires_at)
+        )
+        return result
+
+    def set_by_hash(
+        self,
+        file_hash: str,
+        content_type: str,
+        provider: str,
+        file_id: str,
+        file_uri: str | None = None,
+        expires_at: datetime | None = None,
+    ) -> CachedUpload:
+        """Sync wrapper for aset_by_hash."""
+        result: CachedUpload = self._run_sync(
+            self.aset_by_hash(
+                file_hash, content_type, provider, file_id, file_uri, expires_at
+            )
+        )
+        return result
+
+    def remove(self, file: FileInput, provider: str) -> bool:
+        """Sync wrapper for aremove."""
+        result: bool = self._run_sync(self.aremove(file, provider))
+        return result
+
+    def remove_by_file_id(self, file_id: str, provider: str) -> bool:
+        """Sync wrapper for aremove_by_file_id."""
+        result: bool = self._run_sync(self.aremove_by_file_id(file_id, provider))
+        return result
+
+    def clear_expired(self) -> int:
+        """Sync wrapper for aclear_expired."""
+        result: int = self._run_sync(self.aclear_expired())
+        return result
+
+    def clear(self) -> int:
+        """Sync wrapper for aclear."""
+        result: int = self._run_sync(self.aclear())
+        return result
+
+    def get_all_for_provider(self, provider: str) -> list[CachedUpload]:
+        """Sync wrapper for aget_all_for_provider."""
+        result: list[CachedUpload] = self._run_sync(
+            self.aget_all_for_provider(provider)
+        )
+        return result
+
     def __len__(self) -> int:
         """Return the number of cached entries."""
-        with self._lock:
-            return len(self._cache)
+        return sum(len(keys) for keys in self._provider_keys.values())
+
+    def get_providers(self) -> builtins.set[str]:
+        """Get all provider names that have cached entries.
+
+        Returns:
+            Set of provider names.
+        """
+        return builtins.set(self._provider_keys.keys())
+
+
+# Module-level cache instance
+_default_cache: UploadCache | None = None
+
+
+def get_upload_cache(
+    ttl: int = DEFAULT_TTL_SECONDS,
+    namespace: str = "crewai_uploads",
+    cache_type: str = "memory",
+    **cache_kwargs: Any,
+) -> UploadCache:
+    """Get or create the default upload cache.
+
+    Args:
+        ttl: Default TTL in seconds.
+        namespace: Cache namespace.
+        cache_type: Backend type ("memory" or "redis").
+        **cache_kwargs: Additional args for cache backend.
+
+    Returns:
+        The upload cache instance.
+    """
+    global _default_cache
+    if _default_cache is None:
+        _default_cache = UploadCache(
+            ttl=ttl,
+            namespace=namespace,
+            cache_type=cache_type,
+            **cache_kwargs,
+        )
+    return _default_cache
+
+
+def reset_upload_cache() -> None:
+    """Reset the default upload cache (useful for testing)."""
+    global _default_cache
+    if _default_cache is not None:
+        _default_cache.clear()
+    _default_cache = None
+
+
+def _cleanup_on_exit() -> None:
+    """Clean up uploaded files on process exit."""
+    global _default_cache
+    if _default_cache is None or len(_default_cache) == 0:
+        return
+
+    # Import here to avoid circular imports
+    from crewai.utilities.files.cleanup import cleanup_uploaded_files
+
+    try:
+        cleanup_uploaded_files(_default_cache, delete_from_provider=True)
+    except Exception as e:
+        logger.debug(f"Error during exit cleanup: {e}")
+
+
+atexit.register(_cleanup_on_exit)
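With the singleton and the atexit hook in place, callers share one process-wide cache and get best-effort provider cleanup at interpreter shutdown. A usage sketch (hash and IDs are placeholders):

```python
from crewai.utilities.files.upload_cache import (
    get_upload_cache,
    reset_upload_cache,
)

cache = get_upload_cache()            # created on first call
assert cache is get_upload_cache()    # same instance on later calls

cache.set_by_hash(
    file_hash="abc123",               # placeholder hash
    content_type="image/png",
    provider="gemini",
    file_id="file-1",
)

# On normal interpreter exit, the registered hook calls
# cleanup_uploaded_files(...) for any still-tracked uploads.
# Tests call reset_upload_cache() to clear and drop the singleton.
reset_upload_cache()
```

Note the atexit hook only runs on normal interpreter shutdown; a killed process (SIGKILL, hard crash) skips cleanup, which is one reason entries also carry TTLs.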
lib/crewai/tests/utilities/files/test_upload_cache.py (new file, 206 lines)
@@ -0,0 +1,206 @@
"""Tests for upload cache."""
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.utilities.files import FileBytes, ImageFile
|
||||
from crewai.utilities.files.upload_cache import CachedUpload, UploadCache
|
||||
|
||||
|
||||
# Minimal valid PNG
|
||||
MINIMAL_PNG = (
|
||||
b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x08\x00\x00\x00\x08"
|
||||
b"\x01\x00\x00\x00\x00\xf9Y\xab\xcd\x00\x00\x00\nIDATx\x9cc`\x00\x00"
|
||||
b"\x00\x02\x00\x01\xe2!\xbc3\x00\x00\x00\x00IEND\xaeB`\x82"
|
||||
)
|
||||
|
||||
|
||||
class TestCachedUpload:
|
||||
"""Tests for CachedUpload dataclass."""
|
||||
|
||||
def test_cached_upload_creation(self):
|
||||
"""Test creating a cached upload."""
|
||||
now = datetime.now(timezone.utc)
|
||||
cached = CachedUpload(
|
||||
file_id="file-123",
|
||||
provider="gemini",
|
||||
file_uri="files/file-123",
|
||||
content_type="image/png",
|
||||
uploaded_at=now,
|
||||
expires_at=now + timedelta(hours=48),
|
||||
)
|
||||
|
||||
assert cached.file_id == "file-123"
|
||||
assert cached.provider == "gemini"
|
||||
assert cached.file_uri == "files/file-123"
|
||||
assert cached.content_type == "image/png"
|
||||
|
||||
def test_is_expired_false(self):
|
||||
"""Test is_expired returns False for non-expired upload."""
|
||||
future = datetime.now(timezone.utc) + timedelta(hours=24)
|
||||
cached = CachedUpload(
|
||||
file_id="file-123",
|
||||
provider="gemini",
|
||||
file_uri=None,
|
||||
content_type="image/png",
|
||||
uploaded_at=datetime.now(timezone.utc),
|
||||
expires_at=future,
|
||||
)
|
||||
|
||||
assert cached.is_expired() is False
|
||||
|
||||
def test_is_expired_true(self):
|
||||
"""Test is_expired returns True for expired upload."""
|
||||
past = datetime.now(timezone.utc) - timedelta(hours=1)
|
||||
cached = CachedUpload(
|
||||
file_id="file-123",
|
||||
provider="gemini",
|
||||
file_uri=None,
|
||||
content_type="image/png",
|
||||
uploaded_at=datetime.now(timezone.utc) - timedelta(hours=2),
|
||||
expires_at=past,
|
||||
)
|
||||
|
||||
assert cached.is_expired() is True
|
||||
|
||||
def test_is_expired_no_expiry(self):
|
||||
"""Test is_expired returns False when no expiry set."""
|
||||
cached = CachedUpload(
|
||||
file_id="file-123",
|
||||
provider="anthropic",
|
||||
file_uri=None,
|
||||
content_type="image/png",
|
||||
uploaded_at=datetime.now(timezone.utc),
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
assert cached.is_expired() is False
|
||||
|
||||
|
||||
class TestUploadCache:
|
||||
"""Tests for UploadCache class."""
|
||||
|
||||
def test_cache_creation(self):
|
||||
"""Test creating an empty cache."""
|
||||
cache = UploadCache()
|
||||
|
||||
assert len(cache) == 0
|
||||
|
||||
def test_set_and_get(self):
|
||||
"""Test setting and getting cached uploads."""
|
||||
cache = UploadCache()
|
||||
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
|
||||
|
||||
cached = cache.set(
|
||||
file=file,
|
||||
provider="gemini",
|
||||
file_id="file-123",
|
||||
file_uri="files/file-123",
|
||||
)
|
||||
|
||||
result = cache.get(file, "gemini")
|
||||
|
||||
assert result is not None
|
||||
assert result.file_id == "file-123"
|
||||
assert result.provider == "gemini"
|
||||
|
||||
def test_get_missing(self):
|
||||
"""Test getting non-existent entry returns None."""
|
||||
cache = UploadCache()
|
||||
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
|
||||
|
||||
result = cache.get(file, "gemini")
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_get_different_provider(self):
|
||||
"""Test getting with different provider returns None."""
|
||||
cache = UploadCache()
|
||||
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
|
||||
|
||||
cache.set(file=file, provider="gemini", file_id="file-123")
|
||||
|
||||
result = cache.get(file, "anthropic") # Different provider
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_remove(self):
|
||||
"""Test removing cached entry."""
|
||||
cache = UploadCache()
|
||||
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
|
||||
|
||||
cache.set(file=file, provider="gemini", file_id="file-123")
|
||||
removed = cache.remove(file, "gemini")
|
||||
|
||||
assert removed is True
|
||||
assert cache.get(file, "gemini") is None
|
||||
|
||||
def test_remove_missing(self):
|
||||
"""Test removing non-existent entry returns False."""
|
||||
cache = UploadCache()
|
||||
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
|
||||
|
||||
removed = cache.remove(file, "gemini")
|
||||
|
||||
assert removed is False
|
||||
|
||||
def test_remove_by_file_id(self):
|
||||
"""Test removing by file ID."""
|
||||
cache = UploadCache()
|
||||
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
|
||||
|
||||
cache.set(file=file, provider="gemini", file_id="file-123")
|
||||
removed = cache.remove_by_file_id("file-123", "gemini")
|
||||
|
||||
assert removed is True
|
||||
assert len(cache) == 0
|
||||
|
||||
def test_clear_expired(self):
|
||||
"""Test clearing expired entries."""
|
||||
cache = UploadCache()
|
||||
file1 = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test1.png"))
|
||||
file2 = ImageFile(source=FileBytes(data=MINIMAL_PNG + b"x", filename="test2.png"))
|
||||
|
||||
# Add one expired and one valid entry
|
||||
past = datetime.now(timezone.utc) - timedelta(hours=1)
|
||||
future = datetime.now(timezone.utc) + timedelta(hours=24)
|
||||
|
||||
cache.set(file=file1, provider="gemini", file_id="expired", expires_at=past)
|
||||
cache.set(file=file2, provider="gemini", file_id="valid", expires_at=future)
|
||||
|
||||
removed = cache.clear_expired()
|
||||
|
||||
assert removed == 1
|
||||
assert len(cache) == 1
|
||||
assert cache.get(file2, "gemini") is not None
|
||||
|
||||
def test_clear(self):
|
||||
"""Test clearing all entries."""
|
||||
cache = UploadCache()
|
||||
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
|
||||
|
||||
cache.set(file=file, provider="gemini", file_id="file-123")
|
||||
cache.set(file=file, provider="anthropic", file_id="file-456")
|
||||
|
||||
cleared = cache.clear()
|
||||
|
||||
assert cleared == 2
|
||||
assert len(cache) == 0
|
||||
|
||||
def test_get_all_for_provider(self):
|
||||
"""Test getting all cached uploads for a provider."""
|
||||
cache = UploadCache()
|
||||
file1 = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test1.png"))
|
||||
file2 = ImageFile(source=FileBytes(data=MINIMAL_PNG + b"x", filename="test2.png"))
|
||||
file3 = ImageFile(source=FileBytes(data=MINIMAL_PNG + b"xx", filename="test3.png"))
|
||||
|
||||
cache.set(file=file1, provider="gemini", file_id="file-1")
|
||||
cache.set(file=file2, provider="gemini", file_id="file-2")
|
||||
cache.set(file=file3, provider="anthropic", file_id="file-3")
|
||||
|
||||
gemini_uploads = cache.get_all_for_provider("gemini")
|
||||
anthropic_uploads = cache.get_all_for_provider("anthropic")
|
||||
|
||||
assert len(gemini_uploads) == 2
|
||||
assert len(anthropic_uploads) == 1
|
||||