feat: upgrade upload cache to aiocache with atexit cleanup

Greyson LaLonde
2026-01-21 19:35:56 -05:00
parent d8ebfe7ee0
commit 42ca4eacff
3 changed files with 520 additions and 123 deletions

View File

@@ -177,10 +177,4 @@ def _get_providers_from_cache(cache: UploadCache) -> set[str]:
     Returns:
         Set of provider names.
     """
-    providers: set[str] = set()
-    with cache._lock:
-        for _, provider in cache._cache.keys():
-            providers.add(provider)
-    return providers
+    return cache.get_providers()
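
The old helper reached into `UploadCache._lock` and `_cache`; callers now go through a public accessor instead. A minimal sketch of the new call path, assuming the sync wrappers introduced later in this commit:

```python
from crewai.utilities.files.upload_cache import UploadCache

cache = UploadCache()
cache.set_by_hash(
    file_hash="abc123",
    content_type="image/png",
    provider="gemini",
    file_id="file-1",
)
print(cache.get_providers())  # {"gemini"}
```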

View File

@@ -1,24 +1,36 @@
"""Cache for tracking uploaded files to avoid redundant uploads."""
"""Cache for tracking uploaded files using aiocache."""
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from __future__ import annotations
import asyncio
import atexit
import builtins
from dataclasses import dataclass
from datetime import datetime, timezone
import hashlib
import logging
import threading
from typing import TYPE_CHECKING, Any
from crewai.utilities.files.content_types import (
AudioFile,
File,
ImageFile,
PDFFile,
TextFile,
VideoFile,
)
from aiocache import Cache # type: ignore[import-untyped]
from aiocache.serializers import PickleSerializer # type: ignore[import-untyped]
if TYPE_CHECKING:
from crewai.utilities.files.content_types import (
AudioFile,
File,
ImageFile,
PDFFile,
TextFile,
VideoFile,
)
FileInput = AudioFile | File | ImageFile | PDFFile | TextFile | VideoFile
logger = logging.getLogger(__name__)
FileInput = AudioFile | File | ImageFile | PDFFile | TextFile | VideoFile
DEFAULT_TTL_SECONDS = 24 * 60 * 60 # 24 hours
@dataclass
@@ -42,44 +54,83 @@ class CachedUpload:
     expires_at: datetime | None = None
 
     def is_expired(self) -> bool:
-        """Check if this cached upload has expired.
-
-        Returns:
-            True if expired, False otherwise.
-        """
+        """Check if this cached upload has expired."""
         if self.expires_at is None:
             return False
         return datetime.now(timezone.utc) >= self.expires_at
 
 
-@dataclass
-class UploadCache:
-    """Thread-safe cache for tracking uploaded files.
-
-    Uses file content hash and provider as composite key to avoid
-    uploading the same file multiple times.
+def _make_key(file_hash: str, provider: str) -> str:
+    """Create a cache key from file hash and provider."""
+    return f"upload:{provider}:{file_hash}"
+
+
+def _compute_file_hash(file: FileInput) -> str:
+    """Compute SHA-256 hash of file content."""
+    content = file.read()
+    return hashlib.sha256(content).hexdigest()
+
+
+class UploadCache:
+    """Async cache for tracking uploaded files using aiocache.
+
+    Supports in-memory caching by default, with optional Redis backend
+    for distributed setups.
 
     Attributes:
-        default_ttl: Default time-to-live for cached entries.
+        ttl: Default time-to-live in seconds for cached entries.
+        namespace: Cache namespace for isolation.
     """
 
-    default_ttl: timedelta = field(default_factory=lambda: timedelta(hours=24))
-    _cache: dict[tuple[str, str], CachedUpload] = field(default_factory=dict)
-    _lock: threading.Lock = field(default_factory=threading.Lock)
-
-    def _compute_hash(self, file: FileInput) -> str:
-        """Compute a hash of file content for cache key.
+    def __init__(
+        self,
+        ttl: int = DEFAULT_TTL_SECONDS,
+        namespace: str = "crewai_uploads",
+        cache_type: str = "memory",
+        **cache_kwargs: Any,
+    ) -> None:
+        """Initialize the upload cache.
 
         Args:
-            file: The file to hash.
-
-        Returns:
-            SHA-256 hash of the file content.
+            ttl: Default TTL in seconds.
+            namespace: Cache namespace.
+            cache_type: Backend type ("memory" or "redis").
+            **cache_kwargs: Additional args for cache backend.
         """
-        content = file.source.read()
-        return hashlib.sha256(content).hexdigest()
+        self.ttl = ttl
+        self.namespace = namespace
+        self._provider_keys: dict[str, set[str]] = {}
 
-    def get(self, file: FileInput, provider: str) -> CachedUpload | None:
+        if cache_type == "redis":
+            self._cache = Cache(
+                Cache.REDIS,
+                serializer=PickleSerializer(),
+                namespace=namespace,
+                **cache_kwargs,
+            )
+        else:
+            self._cache = Cache(
+                Cache.MEMORY,
+                serializer=PickleSerializer(),
+                namespace=namespace,
+            )
+
+    def _track_key(self, provider: str, key: str) -> None:
+        """Track a key for a provider (for cleanup)."""
+        if provider not in self._provider_keys:
+            self._provider_keys[provider] = set()
+        self._provider_keys[provider].add(key)
+
+    def _untrack_key(self, provider: str, key: str) -> None:
+        """Remove key tracking for a provider."""
+        if provider in self._provider_keys:
+            self._provider_keys[provider].discard(key)
+
+    # -------------------------------------------------------------------------
+    # Async methods (primary interface)
+    # -------------------------------------------------------------------------
+
+    async def aget(self, file: FileInput, provider: str) -> CachedUpload | None:
        """Get a cached upload for a file.
 
         Args:
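
Construction is the main behavioral switch in this hunk: `cache_type` picks the backend, and everything else passes through to aiocache. A rough usage sketch; the `endpoint`/`port` kwargs are an assumption about the Redis backend of the aiocache version in use:

```python
from crewai.utilities.files.upload_cache import UploadCache

# Default: process-local in-memory backend.
mem_cache = UploadCache()

# Distributed setups: extra kwargs are forwarded to aiocache's Redis
# backend. endpoint/port are assumptions; check your aiocache version.
redis_cache = UploadCache(
    cache_type="redis",
    endpoint="127.0.0.1",
    port=6379,
)
```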
@@ -89,21 +140,10 @@ class UploadCache:
         Returns:
             Cached upload if found and not expired, None otherwise.
         """
-        file_hash = self._compute_hash(file)
-        key = (file_hash, provider)
-
-        with self._lock:
-            cached = self._cache.get(key)
-            if cached is None:
-                return None
-            if cached.is_expired():
-                del self._cache[key]
-                return None
-            return cached
+        file_hash = _compute_file_hash(file)
+        return await self.aget_by_hash(file_hash, provider)
 
-    def get_by_hash(self, file_hash: str, provider: str) -> CachedUpload | None:
+    async def aget_by_hash(self, file_hash: str, provider: str) -> CachedUpload | None:
         """Get a cached upload by file hash.
 
         Args:
@@ -113,20 +153,20 @@ class UploadCache:
         Returns:
             Cached upload if found and not expired, None otherwise.
         """
-        key = (file_hash, provider)
-
-        with self._lock:
-            cached = self._cache.get(key)
-            if cached is None:
-                return None
-            if cached.is_expired():
-                del self._cache[key]
-                return None
-            return cached
+        key = _make_key(file_hash, provider)
+        result = await self._cache.get(key)
+
+        if result is None:
+            return None
+
+        if isinstance(result, CachedUpload):
+            if result.is_expired():
+                await self._cache.delete(key)
+                self._untrack_key(provider, key)
+                return None
+            return result
+
+        return None
 
-    def set(
+    async def aset(
         self,
         file: FileInput,
         provider: str,
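
`aget` now just hashes and delegates, so `aget` and `aget_by_hash` share a single lookup path. A minimal async caller, assuming a cold in-memory cache:

```python
import asyncio

from crewai.utilities.files.upload_cache import UploadCache

async def main() -> None:
    cache = UploadCache()
    # Expired hits are deleted and untracked on read, then reported as None.
    hit = await cache.aget_by_hash("abc123", "gemini")
    print(hit)  # None: nothing cached yet

asyncio.run(main())
```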
@@ -146,26 +186,17 @@ class UploadCache:
         Returns:
             The created cache entry.
         """
-        file_hash = self._compute_hash(file)
-        key = (file_hash, provider)
-        now = datetime.now(timezone.utc)
-
-        cached = CachedUpload(
-            file_id=file_id,
-            provider=provider,
-            file_uri=file_uri,
-            content_type=file.content_type,
-            uploaded_at=now,
-            expires_at=expires_at,
-        )
-
-        with self._lock:
-            self._cache[key] = cached
-
-        logger.debug(f"Cached upload: {file_id} for provider {provider}")
-        return cached
+        file_hash = _compute_file_hash(file)
+        return await self.aset_by_hash(
+            file_hash=file_hash,
+            content_type=file.content_type,
+            provider=provider,
+            file_id=file_id,
+            file_uri=file_uri,
+            expires_at=expires_at,
+        )
 
-    def set_by_hash(
+    async def aset_by_hash(
         self,
         file_hash: str,
         content_type: str,
@@ -187,9 +218,9 @@ class UploadCache:
         Returns:
             The created cache entry.
         """
-        key = (file_hash, provider)
+        key = _make_key(file_hash, provider)
         now = datetime.now(timezone.utc)
 
         cached = CachedUpload(
             file_id=file_id,
             provider=provider,
@@ -199,13 +230,16 @@ class UploadCache:
             expires_at=expires_at,
         )
 
-        with self._lock:
-            self._cache[key] = cached
+        ttl = self.ttl
+        if expires_at is not None:
+            ttl = max(0, int((expires_at - now).total_seconds()))
+
+        await self._cache.set(key, cached, ttl=ttl)
+        self._track_key(provider, key)
 
         logger.debug(f"Cached upload: {file_id} for provider {provider}")
         return cached
 
-    def remove(self, file: FileInput, provider: str) -> bool:
+    async def aremove(self, file: FileInput, provider: str) -> bool:
         """Remove a cached upload.
 
         Args:
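
The TTL handed to aiocache is derived from the caller's `expires_at` when one is given, clamped at zero so an already-past expiry cannot produce a negative TTL. Worked standalone:

```python
from datetime import datetime, timedelta, timezone

now = datetime.now(timezone.utc)

# A provider-supplied expiry two hours out overrides the default TTL.
expires_at = now + timedelta(hours=2)
print(max(0, int((expires_at - now).total_seconds())))  # 7200

# A stale expiry clamps to zero rather than going negative.
stale = now - timedelta(minutes=5)
print(max(0, int((stale - now).total_seconds())))  # 0, not -300
```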
@@ -215,16 +249,16 @@ class UploadCache:
         Returns:
             True if entry was removed, False if not found.
         """
-        file_hash = self._compute_hash(file)
-        key = (file_hash, provider)
-
-        with self._lock:
-            if key in self._cache:
-                del self._cache[key]
-                return True
-            return False
+        file_hash = _compute_file_hash(file)
+        key = _make_key(file_hash, provider)
+
+        result = await self._cache.delete(key)
+        removed = bool(result > 0 if isinstance(result, int) else result)
+        if removed:
+            self._untrack_key(provider, key)
+        return removed
 
-    def remove_by_file_id(self, file_id: str, provider: str) -> bool:
+    async def aremove_by_file_id(self, file_id: str, provider: str) -> bool:
         """Remove a cached upload by file ID.
 
         Args:
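
The bundled aiocache backends report deletions as an int count, but the coercion above also tolerates a plain bool from other backends. Isolated for clarity:

```python
def coerce_removed(result: int | bool) -> bool:
    # Mirrors aremove: int counts become > 0 checks, bools pass through.
    return bool(result > 0 if isinstance(result, int) else result)

assert coerce_removed(1) is True
assert coerce_removed(0) is False
assert coerce_removed(True) is True
```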
@@ -234,14 +268,18 @@ class UploadCache:
         Returns:
             True if entry was removed, False if not found.
         """
-        with self._lock:
-            for key, cached in list(self._cache.items()):
-                if cached.file_id == file_id and cached.provider == provider:
-                    del self._cache[key]
-                    return True
-            return False
+        if provider not in self._provider_keys:
+            return False
+
+        for key in list(self._provider_keys[provider]):
+            cached = await self._cache.get(key)
+            if isinstance(cached, CachedUpload) and cached.file_id == file_id:
+                await self._cache.delete(key)
+                self._untrack_key(provider, key)
+                return True
+
+        return False
 
-    def clear_expired(self) -> int:
+    async def aclear_expired(self) -> int:
         """Remove all expired entries from the cache.
 
         Returns:
@@ -249,33 +287,35 @@ class UploadCache:
"""
removed = 0
with self._lock:
for key in list(self._cache.keys()):
if self._cache[key].is_expired():
del self._cache[key]
for provider, keys in list(self._provider_keys.items()):
for key in list(keys):
cached = await self._cache.get(key)
if cached is None or (
isinstance(cached, CachedUpload) and cached.is_expired()
):
await self._cache.delete(key)
self._untrack_key(provider, key)
removed += 1
if removed > 0:
logger.debug(f"Cleared {removed} expired cache entries")
return removed
def clear(self) -> int:
async def aclear(self) -> int:
"""Clear all entries from the cache.
Returns:
Number of entries cleared.
"""
with self._lock:
count = len(self._cache)
self._cache.clear()
count = sum(len(keys) for keys in self._provider_keys.values())
await self._cache.clear(namespace=self.namespace)
self._provider_keys.clear()
if count > 0:
logger.debug(f"Cleared {count} cache entries")
return count
def get_all_for_provider(self, provider: str) -> list[CachedUpload]:
async def aget_all_for_provider(self, provider: str) -> list[CachedUpload]:
"""Get all cached uploads for a provider.
Args:
@@ -284,14 +324,171 @@ class UploadCache:
         Returns:
             List of cached uploads for the provider.
         """
-        with self._lock:
-            return [
-                cached
-                for (_, p), cached in self._cache.items()
-                if p == provider and not cached.is_expired()
-            ]
+        if provider not in self._provider_keys:
+            return []
+
+        results: list[CachedUpload] = []
+        for key in list(self._provider_keys[provider]):
+            cached = await self._cache.get(key)
+            if isinstance(cached, CachedUpload) and not cached.is_expired():
+                results.append(cached)
+        return results
+    # -------------------------------------------------------------------------
+    # Sync wrappers (convenience)
+    # -------------------------------------------------------------------------
+
+    def _run_sync(self, coro: Any) -> Any:
+        """Run an async coroutine from sync context."""
+        try:
+            loop = asyncio.get_running_loop()
+        except RuntimeError:
+            loop = None
+
+        if loop is not None and loop.is_running():
+            import concurrent.futures
+
+            with concurrent.futures.ThreadPoolExecutor() as pool:
+                future = pool.submit(asyncio.run, coro)
+                return future.result()
+
+        return asyncio.run(coro)
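
`_run_sync` is the bridge for all the wrappers below: plain sync code gets a straight `asyncio.run`, while a caller already inside a running loop has the coroutine shipped to a throwaway worker thread (at the cost of a new thread and event loop per call). A small sketch of both paths:

```python
import asyncio

from crewai.utilities.files.upload_cache import UploadCache

cache = UploadCache()

# No running loop here, so the wrapper uses asyncio.run() directly.
cache.set_by_hash(
    file_hash="abc123",
    content_type="image/png",
    provider="gemini",
    file_id="file-1",
)

async def inside_loop() -> None:
    # asyncio.run() would raise inside a running loop, so the wrapper
    # runs the coroutine on a fresh loop in a worker thread instead.
    entry = cache.get_by_hash("abc123", "gemini")
    print(entry.file_id if entry else None)  # file-1

asyncio.run(inside_loop())
```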
+    def get(self, file: FileInput, provider: str) -> CachedUpload | None:
+        """Sync wrapper for aget."""
+        result: CachedUpload | None = self._run_sync(self.aget(file, provider))
+        return result
+
+    def get_by_hash(self, file_hash: str, provider: str) -> CachedUpload | None:
+        """Sync wrapper for aget_by_hash."""
+        result: CachedUpload | None = self._run_sync(
+            self.aget_by_hash(file_hash, provider)
+        )
+        return result
+
+    def set(
+        self,
+        file: FileInput,
+        provider: str,
+        file_id: str,
+        file_uri: str | None = None,
+        expires_at: datetime | None = None,
+    ) -> CachedUpload:
+        """Sync wrapper for aset."""
+        result: CachedUpload = self._run_sync(
+            self.aset(file, provider, file_id, file_uri, expires_at)
+        )
+        return result
+
+    def set_by_hash(
+        self,
+        file_hash: str,
+        content_type: str,
+        provider: str,
+        file_id: str,
+        file_uri: str | None = None,
+        expires_at: datetime | None = None,
+    ) -> CachedUpload:
+        """Sync wrapper for aset_by_hash."""
+        result: CachedUpload = self._run_sync(
+            self.aset_by_hash(
+                file_hash, content_type, provider, file_id, file_uri, expires_at
+            )
+        )
+        return result
+
+    def remove(self, file: FileInput, provider: str) -> bool:
+        """Sync wrapper for aremove."""
+        result: bool = self._run_sync(self.aremove(file, provider))
+        return result
+
+    def remove_by_file_id(self, file_id: str, provider: str) -> bool:
+        """Sync wrapper for aremove_by_file_id."""
+        result: bool = self._run_sync(self.aremove_by_file_id(file_id, provider))
+        return result
+
+    def clear_expired(self) -> int:
+        """Sync wrapper for aclear_expired."""
+        result: int = self._run_sync(self.aclear_expired())
+        return result
+
+    def clear(self) -> int:
+        """Sync wrapper for aclear."""
+        result: int = self._run_sync(self.aclear())
+        return result
+
+    def get_all_for_provider(self, provider: str) -> list[CachedUpload]:
+        """Sync wrapper for aget_all_for_provider."""
+        result: list[CachedUpload] = self._run_sync(
+            self.aget_all_for_provider(provider)
+        )
+        return result
     def __len__(self) -> int:
         """Return the number of cached entries."""
-        with self._lock:
-            return len(self._cache)
+        return sum(len(keys) for keys in self._provider_keys.values())
+
+    def get_providers(self) -> builtins.set[str]:
+        """Get all provider names that have cached entries.
+
+        Returns:
+            Set of provider names.
+        """
+        return builtins.set(self._provider_keys.keys())
+# Module-level cache instance
+_default_cache: UploadCache | None = None
+
+
+def get_upload_cache(
+    ttl: int = DEFAULT_TTL_SECONDS,
+    namespace: str = "crewai_uploads",
+    cache_type: str = "memory",
+    **cache_kwargs: Any,
+) -> UploadCache:
+    """Get or create the default upload cache.
+
+    Args:
+        ttl: Default TTL in seconds.
+        namespace: Cache namespace.
+        cache_type: Backend type ("memory" or "redis").
+        **cache_kwargs: Additional args for cache backend.
+
+    Returns:
+        The upload cache instance.
+    """
+    global _default_cache
+    if _default_cache is None:
+        _default_cache = UploadCache(
+            ttl=ttl,
+            namespace=namespace,
+            cache_type=cache_type,
+            **cache_kwargs,
+        )
+    return _default_cache
+
+
+def reset_upload_cache() -> None:
+    """Reset the default upload cache (useful for testing)."""
+    global _default_cache
+    if _default_cache is not None:
+        _default_cache.clear()
+    _default_cache = None
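
`get_upload_cache` is a lazy singleton: constructor arguments only take effect on the first call, and `reset_upload_cache` is the way to reconfigure afterwards. Illustrated:

```python
from crewai.utilities.files.upload_cache import (
    get_upload_cache,
    reset_upload_cache,
)

cache = get_upload_cache(ttl=3600)
assert get_upload_cache(ttl=60) is cache  # later args are ignored

reset_upload_cache()  # clears entries, then drops the singleton
assert get_upload_cache() is not cache
```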
+def _cleanup_on_exit() -> None:
+    """Clean up uploaded files on process exit."""
+    global _default_cache
+    if _default_cache is None or len(_default_cache) == 0:
+        return
+
+    # Import here to avoid circular imports
+    from crewai.utilities.files.cleanup import cleanup_uploaded_files
+
+    try:
+        cleanup_uploaded_files(_default_cache, delete_from_provider=True)
+    except Exception as e:
+        logger.debug(f"Error during exit cleanup: {e}")
+
+
+atexit.register(_cleanup_on_exit)
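
The atexit hook is best-effort and swallows errors; long-running services that want deterministic teardown can call the same helper directly. A sketch using only names this commit imports:

```python
from crewai.utilities.files.cleanup import cleanup_uploaded_files
from crewai.utilities.files.upload_cache import get_upload_cache

cache = get_upload_cache()
try:
    # Same call the exit hook makes, but at a moment you control.
    cleanup_uploaded_files(cache, delete_from_provider=True)
finally:
    cache.clear()
```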

View File

@@ -0,0 +1,206 @@
"""Tests for upload cache."""
from datetime import datetime, timedelta, timezone
import pytest
from crewai.utilities.files import FileBytes, ImageFile
from crewai.utilities.files.upload_cache import CachedUpload, UploadCache
# Minimal valid PNG
MINIMAL_PNG = (
b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x08\x00\x00\x00\x08"
b"\x01\x00\x00\x00\x00\xf9Y\xab\xcd\x00\x00\x00\nIDATx\x9cc`\x00\x00"
b"\x00\x02\x00\x01\xe2!\xbc3\x00\x00\x00\x00IEND\xaeB`\x82"
)
+class TestCachedUpload:
+    """Tests for CachedUpload dataclass."""
+
+    def test_cached_upload_creation(self):
+        """Test creating a cached upload."""
+        now = datetime.now(timezone.utc)
+        cached = CachedUpload(
+            file_id="file-123",
+            provider="gemini",
+            file_uri="files/file-123",
+            content_type="image/png",
+            uploaded_at=now,
+            expires_at=now + timedelta(hours=48),
+        )
+
+        assert cached.file_id == "file-123"
+        assert cached.provider == "gemini"
+        assert cached.file_uri == "files/file-123"
+        assert cached.content_type == "image/png"
+
+    def test_is_expired_false(self):
+        """Test is_expired returns False for non-expired upload."""
+        future = datetime.now(timezone.utc) + timedelta(hours=24)
+        cached = CachedUpload(
+            file_id="file-123",
+            provider="gemini",
+            file_uri=None,
+            content_type="image/png",
+            uploaded_at=datetime.now(timezone.utc),
+            expires_at=future,
+        )
+
+        assert cached.is_expired() is False
+
+    def test_is_expired_true(self):
+        """Test is_expired returns True for expired upload."""
+        past = datetime.now(timezone.utc) - timedelta(hours=1)
+        cached = CachedUpload(
+            file_id="file-123",
+            provider="gemini",
+            file_uri=None,
+            content_type="image/png",
+            uploaded_at=datetime.now(timezone.utc) - timedelta(hours=2),
+            expires_at=past,
+        )
+
+        assert cached.is_expired() is True
+
+    def test_is_expired_no_expiry(self):
+        """Test is_expired returns False when no expiry set."""
+        cached = CachedUpload(
+            file_id="file-123",
+            provider="anthropic",
+            file_uri=None,
+            content_type="image/png",
+            uploaded_at=datetime.now(timezone.utc),
+            expires_at=None,
+        )
+
+        assert cached.is_expired() is False
+class TestUploadCache:
+    """Tests for UploadCache class."""
+
+    def test_cache_creation(self):
+        """Test creating an empty cache."""
+        cache = UploadCache()
+        assert len(cache) == 0
+
+    def test_set_and_get(self):
+        """Test setting and getting cached uploads."""
+        cache = UploadCache()
+        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
+
+        cache.set(
+            file=file,
+            provider="gemini",
+            file_id="file-123",
+            file_uri="files/file-123",
+        )
+
+        result = cache.get(file, "gemini")
+        assert result is not None
+        assert result.file_id == "file-123"
+        assert result.provider == "gemini"
+
+    def test_get_missing(self):
+        """Test getting non-existent entry returns None."""
+        cache = UploadCache()
+        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
+
+        result = cache.get(file, "gemini")
+        assert result is None
+
+    def test_get_different_provider(self):
+        """Test getting with different provider returns None."""
+        cache = UploadCache()
+        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
+
+        cache.set(file=file, provider="gemini", file_id="file-123")
+
+        result = cache.get(file, "anthropic")  # Different provider
+        assert result is None
+
+    def test_remove(self):
+        """Test removing cached entry."""
+        cache = UploadCache()
+        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
+
+        cache.set(file=file, provider="gemini", file_id="file-123")
+        removed = cache.remove(file, "gemini")
+
+        assert removed is True
+        assert cache.get(file, "gemini") is None
+
+    def test_remove_missing(self):
+        """Test removing non-existent entry returns False."""
+        cache = UploadCache()
+        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
+
+        removed = cache.remove(file, "gemini")
+        assert removed is False
+
+    def test_remove_by_file_id(self):
+        """Test removing by file ID."""
+        cache = UploadCache()
+        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
+
+        cache.set(file=file, provider="gemini", file_id="file-123")
+        removed = cache.remove_by_file_id("file-123", "gemini")
+
+        assert removed is True
+        assert len(cache) == 0
+
+    def test_clear_expired(self):
+        """Test clearing expired entries."""
+        cache = UploadCache()
+        file1 = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test1.png"))
+        file2 = ImageFile(source=FileBytes(data=MINIMAL_PNG + b"x", filename="test2.png"))
+
+        # Add one expired and one valid entry
+        past = datetime.now(timezone.utc) - timedelta(hours=1)
+        future = datetime.now(timezone.utc) + timedelta(hours=24)
+
+        cache.set(file=file1, provider="gemini", file_id="expired", expires_at=past)
+        cache.set(file=file2, provider="gemini", file_id="valid", expires_at=future)
+
+        removed = cache.clear_expired()
+
+        assert removed == 1
+        assert len(cache) == 1
+        assert cache.get(file2, "gemini") is not None
+
+    def test_clear(self):
+        """Test clearing all entries."""
+        cache = UploadCache()
+        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
+
+        cache.set(file=file, provider="gemini", file_id="file-123")
+        cache.set(file=file, provider="anthropic", file_id="file-456")
+
+        cleared = cache.clear()
+
+        assert cleared == 2
+        assert len(cache) == 0
+
+    def test_get_all_for_provider(self):
+        """Test getting all cached uploads for a provider."""
+        cache = UploadCache()
+        file1 = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test1.png"))
+        file2 = ImageFile(source=FileBytes(data=MINIMAL_PNG + b"x", filename="test2.png"))
+        file3 = ImageFile(source=FileBytes(data=MINIMAL_PNG + b"xx", filename="test3.png"))
+
+        cache.set(file=file1, provider="gemini", file_id="file-1")
+        cache.set(file=file2, provider="gemini", file_id="file-2")
+        cache.set(file=file3, provider="anthropic", file_id="file-3")
+
+        gemini_uploads = cache.get_all_for_provider("gemini")
+        anthropic_uploads = cache.get_all_for_provider("anthropic")
+
+        assert len(gemini_uploads) == 2
+        assert len(anthropic_uploads) == 1
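
These tests exercise only the sync wrappers. A possible companion for the async path, assuming pytest-asyncio is installed and reusing the names defined in this test module:

```python
@pytest.mark.asyncio
async def test_aset_and_aget():
    """Exercise the async interface directly, without _run_sync."""
    cache = UploadCache()
    file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))

    await cache.aset(file=file, provider="gemini", file_id="file-123")
    result = await cache.aget(file, "gemini")

    assert result is not None
    assert result.file_id == "file-123"
```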