feat: upgrade upload cache to aiocache with atexit cleanup

Greyson LaLonde
2026-01-21 19:35:56 -05:00
parent d8ebfe7ee0
commit 42ca4eacff
3 changed files with 520 additions and 123 deletions
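For orientation, a minimal usage sketch of the interface this commit introduces (a sketch under assumptions: the import paths follow the test file below, `png_bytes` is a placeholder for a real image payload, and the default in-memory backend is used):

import asyncio

from crewai.utilities.files import FileBytes, ImageFile
from crewai.utilities.files.upload_cache import UploadCache

png_bytes = b"..."  # placeholder: substitute a valid PNG payload (see MINIMAL_PNG in the tests)


async def main() -> None:
    cache = UploadCache(ttl=3600)  # TTL in seconds; defaults to 24 hours
    file = ImageFile(source=FileBytes(data=png_bytes, filename="pic.png"))

    # First sight of these bytes: record the provider-issued file ID.
    await cache.aset(file, provider="gemini", file_id="file-123")

    # Later attempts can check the cache first and skip the re-upload.
    hit = await cache.aget(file, "gemini")
    assert hit is not None and hit.file_id == "file-123"


asyncio.run(main())

The same operations are available without the `a` prefix (`cache.set`, `cache.get`) for synchronous callers; see the sync wrappers in the diff below.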

View File

@@ -177,10 +177,4 @@ def _get_providers_from_cache(cache: UploadCache) -> set[str]:
     Returns:
         Set of provider names.
     """
-    providers: set[str] = set()
-    with cache._lock:
-        for _, provider in cache._cache.keys():
-            providers.add(provider)
-    return providers
+    return cache.get_providers()

View File

@@ -1,24 +1,36 @@
"""Cache for tracking uploaded files to avoid redundant uploads.""" """Cache for tracking uploaded files using aiocache."""
from dataclasses import dataclass, field from __future__ import annotations
from datetime import datetime, timedelta, timezone
import asyncio
import atexit
import builtins
from dataclasses import dataclass
from datetime import datetime, timezone
import hashlib import hashlib
import logging import logging
import threading from typing import TYPE_CHECKING, Any
from crewai.utilities.files.content_types import ( from aiocache import Cache # type: ignore[import-untyped]
AudioFile, from aiocache.serializers import PickleSerializer # type: ignore[import-untyped]
File,
ImageFile,
PDFFile, if TYPE_CHECKING:
TextFile, from crewai.utilities.files.content_types import (
VideoFile, AudioFile,
) File,
ImageFile,
PDFFile,
TextFile,
VideoFile,
)
FileInput = AudioFile | File | ImageFile | PDFFile | TextFile | VideoFile
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
FileInput = AudioFile | File | ImageFile | PDFFile | TextFile | VideoFile DEFAULT_TTL_SECONDS = 24 * 60 * 60 # 24 hours
@dataclass @dataclass
@@ -42,44 +54,83 @@ class CachedUpload:
     expires_at: datetime | None = None
 
     def is_expired(self) -> bool:
-        """Check if this cached upload has expired.
-
-        Returns:
-            True if expired, False otherwise.
-        """
+        """Check if this cached upload has expired."""
         if self.expires_at is None:
             return False
         return datetime.now(timezone.utc) >= self.expires_at
 
 
-@dataclass
-class UploadCache:
-    """Thread-safe cache for tracking uploaded files.
+def _make_key(file_hash: str, provider: str) -> str:
+    """Create a cache key from file hash and provider."""
+    return f"upload:{provider}:{file_hash}"
+
+
+def _compute_file_hash(file: FileInput) -> str:
+    """Compute SHA-256 hash of file content."""
+    content = file.read()
+    return hashlib.sha256(content).hexdigest()
 
-    Uses file content hash and provider as composite key to avoid
-    uploading the same file multiple times.
+
+class UploadCache:
+    """Async cache for tracking uploaded files using aiocache.
+
+    Supports in-memory caching by default, with optional Redis backend
+    for distributed setups.
 
     Attributes:
-        default_ttl: Default time-to-live for cached entries.
+        ttl: Default time-to-live in seconds for cached entries.
+        namespace: Cache namespace for isolation.
     """
 
-    default_ttl: timedelta = field(default_factory=lambda: timedelta(hours=24))
-    _cache: dict[tuple[str, str], CachedUpload] = field(default_factory=dict)
-    _lock: threading.Lock = field(default_factory=threading.Lock)
-
-    def _compute_hash(self, file: FileInput) -> str:
-        """Compute a hash of file content for cache key.
+    def __init__(
+        self,
+        ttl: int = DEFAULT_TTL_SECONDS,
+        namespace: str = "crewai_uploads",
+        cache_type: str = "memory",
+        **cache_kwargs: Any,
+    ) -> None:
+        """Initialize the upload cache.
 
         Args:
-            file: The file to hash.
-
-        Returns:
-            SHA-256 hash of the file content.
+            ttl: Default TTL in seconds.
+            namespace: Cache namespace.
+            cache_type: Backend type ("memory" or "redis").
+            **cache_kwargs: Additional args for cache backend.
         """
-        content = file.source.read()
-        return hashlib.sha256(content).hexdigest()
+        self.ttl = ttl
+        self.namespace = namespace
+        self._provider_keys: dict[str, set[str]] = {}
+
+        if cache_type == "redis":
+            self._cache = Cache(
+                Cache.REDIS,
+                serializer=PickleSerializer(),
+                namespace=namespace,
+                **cache_kwargs,
+            )
+        else:
+            self._cache = Cache(
+                Cache.MEMORY,
+                serializer=PickleSerializer(),
+                namespace=namespace,
+            )
+
+    def _track_key(self, provider: str, key: str) -> None:
+        """Track a key for a provider (for cleanup)."""
+        if provider not in self._provider_keys:
+            self._provider_keys[provider] = set()
+        self._provider_keys[provider].add(key)
+
+    def _untrack_key(self, provider: str, key: str) -> None:
+        """Remove key tracking for a provider."""
+        if provider in self._provider_keys:
+            self._provider_keys[provider].discard(key)
+
+    # -------------------------------------------------------------------------
+    # Async methods (primary interface)
+    # -------------------------------------------------------------------------
 
-    def get(self, file: FileInput, provider: str) -> CachedUpload | None:
+    async def aget(self, file: FileInput, provider: str) -> CachedUpload | None:
         """Get a cached upload for a file.
 
         Args:
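Worth noting for reviewers: the constructor in the hunk above selects the backend by `cache_type`. A hedged sketch of a Redis-backed instance (the `endpoint`/`port` keyword names are aiocache RedisCache parameters forwarded through `**cache_kwargs`; verify them against the installed aiocache version):

# Sketch, not a tested configuration: keyword names assume aiocache's RedisCache.
cache = UploadCache(
    cache_type="redis",
    namespace="crewai_uploads",
    endpoint="127.0.0.1",  # Redis host
    port=6379,
)

One design consequence visible in the diff: `_provider_keys` lives in process memory, so provider-key tracking (and with it cleanup and `__len__`) stays per-process even when the cached entries themselves sit in Redis.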
@@ -89,21 +140,10 @@ class UploadCache:
         Returns:
             Cached upload if found and not expired, None otherwise.
         """
-        file_hash = self._compute_hash(file)
-        key = (file_hash, provider)
-
-        with self._lock:
-            cached = self._cache.get(key)
-            if cached is None:
-                return None
-            if cached.is_expired():
-                del self._cache[key]
-                return None
-            return cached
+        file_hash = _compute_file_hash(file)
+        return await self.aget_by_hash(file_hash, provider)
 
-    def get_by_hash(self, file_hash: str, provider: str) -> CachedUpload | None:
+    async def aget_by_hash(self, file_hash: str, provider: str) -> CachedUpload | None:
         """Get a cached upload by file hash.
 
         Args:
@@ -113,20 +153,20 @@ class UploadCache:
         Returns:
             Cached upload if found and not expired, None otherwise.
         """
-        key = (file_hash, provider)
-
-        with self._lock:
-            cached = self._cache.get(key)
-            if cached is None:
-                return None
-            if cached.is_expired():
-                del self._cache[key]
-                return None
-            return cached
+        key = _make_key(file_hash, provider)
+        result = await self._cache.get(key)
+
+        if result is None:
+            return None
+
+        if isinstance(result, CachedUpload):
+            if result.is_expired():
+                await self._cache.delete(key)
+                self._untrack_key(provider, key)
+                return None
+            return result
+
+        return None
 
-    def set(
+    async def aset(
         self,
         file: FileInput,
         provider: str,
@@ -146,26 +186,17 @@ class UploadCache:
         Returns:
             The created cache entry.
         """
-        file_hash = self._compute_hash(file)
-        key = (file_hash, provider)
-
-        now = datetime.now(timezone.utc)
-        cached = CachedUpload(
-            file_id=file_id,
-            provider=provider,
-            file_uri=file_uri,
-            content_type=file.content_type,
-            uploaded_at=now,
-            expires_at=expires_at,
-        )
-
-        with self._lock:
-            self._cache[key] = cached
-
-        logger.debug(f"Cached upload: {file_id} for provider {provider}")
-        return cached
+        file_hash = _compute_file_hash(file)
+        return await self.aset_by_hash(
+            file_hash=file_hash,
+            content_type=file.content_type,
+            provider=provider,
+            file_id=file_id,
+            file_uri=file_uri,
+            expires_at=expires_at,
+        )
 
-    def set_by_hash(
+    async def aset_by_hash(
         self,
         file_hash: str,
         content_type: str,
@@ -187,9 +218,9 @@ class UploadCache:
         Returns:
             The created cache entry.
         """
-        key = (file_hash, provider)
+        key = _make_key(file_hash, provider)
 
         now = datetime.now(timezone.utc)
         cached = CachedUpload(
             file_id=file_id,
             provider=provider,
@@ -199,13 +230,16 @@ class UploadCache:
             expires_at=expires_at,
         )
 
-        with self._lock:
-            self._cache[key] = cached
+        ttl = self.ttl
+        if expires_at is not None:
+            ttl = max(0, int((expires_at - now).total_seconds()))
+
+        await self._cache.set(key, cached, ttl=ttl)
+        self._track_key(provider, key)
 
         logger.debug(f"Cached upload: {file_id} for provider {provider}")
         return cached
 
-    def remove(self, file: FileInput, provider: str) -> bool:
+    async def aremove(self, file: FileInput, provider: str) -> bool:
         """Remove a cached upload.
 
         Args:
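The `aset_by_hash` hunk above derives the aiocache TTL from a provider-supplied `expires_at`: the remaining seconds, floored at zero so an already-past timestamp never yields a negative TTL. A small self-contained check of that arithmetic:

from datetime import datetime, timedelta, timezone

now = datetime.now(timezone.utc)

# 90 minutes from now -> a 5400-second TTL
ttl = max(0, int((now + timedelta(minutes=90) - now).total_seconds()))
assert ttl == 5400

# already in the past -> clamped to 0 rather than negative
ttl = max(0, int((now - timedelta(hours=1) - now).total_seconds()))
assert ttl == 0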
@@ -215,16 +249,16 @@ class UploadCache:
         Returns:
             True if entry was removed, False if not found.
         """
-        file_hash = self._compute_hash(file)
-        key = (file_hash, provider)
-
-        with self._lock:
-            if key in self._cache:
-                del self._cache[key]
-                return True
-            return False
+        file_hash = _compute_file_hash(file)
+        key = _make_key(file_hash, provider)
+
+        result = await self._cache.delete(key)
+        removed = bool(result > 0 if isinstance(result, int) else result)
+        if removed:
+            self._untrack_key(provider, key)
+        return removed
 
-    def remove_by_file_id(self, file_id: str, provider: str) -> bool:
+    async def aremove_by_file_id(self, file_id: str, provider: str) -> bool:
         """Remove a cached upload by file ID.
 
         Args:
@@ -234,14 +268,18 @@ class UploadCache:
         Returns:
             True if entry was removed, False if not found.
         """
-        with self._lock:
-            for key, cached in list(self._cache.items()):
-                if cached.file_id == file_id and cached.provider == provider:
-                    del self._cache[key]
-                    return True
-        return False
+        if provider not in self._provider_keys:
+            return False
+
+        for key in list(self._provider_keys[provider]):
+            cached = await self._cache.get(key)
+            if isinstance(cached, CachedUpload) and cached.file_id == file_id:
+                await self._cache.delete(key)
+                self._untrack_key(provider, key)
+                return True
+
+        return False
 
-    def clear_expired(self) -> int:
+    async def aclear_expired(self) -> int:
         """Remove all expired entries from the cache.
 
         Returns:
@@ -249,33 +287,35 @@ class UploadCache:
""" """
removed = 0 removed = 0
with self._lock: for provider, keys in list(self._provider_keys.items()):
for key in list(self._cache.keys()): for key in list(keys):
if self._cache[key].is_expired(): cached = await self._cache.get(key)
del self._cache[key] if cached is None or (
isinstance(cached, CachedUpload) and cached.is_expired()
):
await self._cache.delete(key)
self._untrack_key(provider, key)
removed += 1 removed += 1
if removed > 0: if removed > 0:
logger.debug(f"Cleared {removed} expired cache entries") logger.debug(f"Cleared {removed} expired cache entries")
return removed return removed
def clear(self) -> int: async def aclear(self) -> int:
"""Clear all entries from the cache. """Clear all entries from the cache.
Returns: Returns:
Number of entries cleared. Number of entries cleared.
""" """
with self._lock: count = sum(len(keys) for keys in self._provider_keys.values())
count = len(self._cache) await self._cache.clear(namespace=self.namespace)
self._cache.clear() self._provider_keys.clear()
if count > 0: if count > 0:
logger.debug(f"Cleared {count} cache entries") logger.debug(f"Cleared {count} cache entries")
return count return count
def get_all_for_provider(self, provider: str) -> list[CachedUpload]: async def aget_all_for_provider(self, provider: str) -> list[CachedUpload]:
"""Get all cached uploads for a provider. """Get all cached uploads for a provider.
Args: Args:
@@ -284,14 +324,171 @@ class UploadCache:
         Returns:
             List of cached uploads for the provider.
         """
-        with self._lock:
-            return [
-                cached
-                for (_, p), cached in self._cache.items()
-                if p == provider and not cached.is_expired()
-            ]
+        if provider not in self._provider_keys:
+            return []
+
+        results: list[CachedUpload] = []
+        for key in list(self._provider_keys[provider]):
+            cached = await self._cache.get(key)
+            if isinstance(cached, CachedUpload) and not cached.is_expired():
+                results.append(cached)
+        return results
+
+    # -------------------------------------------------------------------------
+    # Sync wrappers (convenience)
+    # -------------------------------------------------------------------------
+
+    def _run_sync(self, coro: Any) -> Any:
+        """Run an async coroutine from sync context."""
+        try:
+            loop = asyncio.get_running_loop()
+        except RuntimeError:
+            loop = None
+
+        if loop is not None and loop.is_running():
+            import concurrent.futures
+
+            with concurrent.futures.ThreadPoolExecutor() as pool:
+                future = pool.submit(asyncio.run, coro)
+                return future.result()
+
+        return asyncio.run(coro)
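The `_run_sync` helper above is the bridge for the sync wrappers that follow: with no running event loop it calls `asyncio.run` directly, but inside a running loop it must hop to a worker thread, since `asyncio.run` refuses to nest. A minimal standalone sketch of the same pattern:

import asyncio
import concurrent.futures


async def work() -> int:
    await asyncio.sleep(0)
    return 42


def call_sync() -> int:
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        return asyncio.run(work())  # no loop running: run directly

    # A loop is running (e.g. called from async code): run the coroutine
    # in a fresh thread with its own event loop, then block on the result.
    with concurrent.futures.ThreadPoolExecutor() as pool:
        return pool.submit(asyncio.run, work()).result()


print(call_sync())  # 42

One trade-off to flag: each such call spins up a throwaway thread and event loop, which is fine for occasional cache operations but would be costly on a hot path.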
+
+    def get(self, file: FileInput, provider: str) -> CachedUpload | None:
+        """Sync wrapper for aget."""
+        result: CachedUpload | None = self._run_sync(self.aget(file, provider))
+        return result
+
+    def get_by_hash(self, file_hash: str, provider: str) -> CachedUpload | None:
+        """Sync wrapper for aget_by_hash."""
+        result: CachedUpload | None = self._run_sync(
+            self.aget_by_hash(file_hash, provider)
+        )
+        return result
+
+    def set(
+        self,
+        file: FileInput,
+        provider: str,
+        file_id: str,
+        file_uri: str | None = None,
+        expires_at: datetime | None = None,
+    ) -> CachedUpload:
+        """Sync wrapper for aset."""
+        result: CachedUpload = self._run_sync(
+            self.aset(file, provider, file_id, file_uri, expires_at)
+        )
+        return result
+
+    def set_by_hash(
+        self,
+        file_hash: str,
+        content_type: str,
+        provider: str,
+        file_id: str,
+        file_uri: str | None = None,
+        expires_at: datetime | None = None,
+    ) -> CachedUpload:
+        """Sync wrapper for aset_by_hash."""
+        result: CachedUpload = self._run_sync(
+            self.aset_by_hash(
+                file_hash, content_type, provider, file_id, file_uri, expires_at
+            )
+        )
+        return result
+
+    def remove(self, file: FileInput, provider: str) -> bool:
+        """Sync wrapper for aremove."""
+        result: bool = self._run_sync(self.aremove(file, provider))
+        return result
+
+    def remove_by_file_id(self, file_id: str, provider: str) -> bool:
+        """Sync wrapper for aremove_by_file_id."""
+        result: bool = self._run_sync(self.aremove_by_file_id(file_id, provider))
+        return result
+
+    def clear_expired(self) -> int:
+        """Sync wrapper for aclear_expired."""
+        result: int = self._run_sync(self.aclear_expired())
+        return result
+
+    def clear(self) -> int:
+        """Sync wrapper for aclear."""
+        result: int = self._run_sync(self.aclear())
+        return result
+
+    def get_all_for_provider(self, provider: str) -> list[CachedUpload]:
+        """Sync wrapper for aget_all_for_provider."""
+        result: list[CachedUpload] = self._run_sync(
+            self.aget_all_for_provider(provider)
+        )
+        return result
 
     def __len__(self) -> int:
         """Return the number of cached entries."""
-        with self._lock:
-            return len(self._cache)
+        return sum(len(keys) for keys in self._provider_keys.values())
+
+    def get_providers(self) -> builtins.set[str]:
+        """Get all provider names that have cached entries.
+
+        Returns:
+            Set of provider names.
+        """
+        return builtins.set(self._provider_keys.keys())
+
+
+# Module-level cache instance
+_default_cache: UploadCache | None = None
+
+
+def get_upload_cache(
+    ttl: int = DEFAULT_TTL_SECONDS,
+    namespace: str = "crewai_uploads",
+    cache_type: str = "memory",
+    **cache_kwargs: Any,
+) -> UploadCache:
+    """Get or create the default upload cache.
+
+    Args:
+        ttl: Default TTL in seconds.
+        namespace: Cache namespace.
+        cache_type: Backend type ("memory" or "redis").
+        **cache_kwargs: Additional args for cache backend.
+
+    Returns:
+        The upload cache instance.
+    """
+    global _default_cache
+    if _default_cache is None:
+        _default_cache = UploadCache(
+            ttl=ttl,
+            namespace=namespace,
+            cache_type=cache_type,
+            **cache_kwargs,
+        )
+    return _default_cache
+
+
+def reset_upload_cache() -> None:
+    """Reset the default upload cache (useful for testing)."""
+    global _default_cache
+    if _default_cache is not None:
+        _default_cache.clear()
+    _default_cache = None
+
+
+def _cleanup_on_exit() -> None:
+    """Clean up uploaded files on process exit."""
+    global _default_cache
+    if _default_cache is None or len(_default_cache) == 0:
+        return
+
+    # Import here to avoid circular imports
+    from crewai.utilities.files.cleanup import cleanup_uploaded_files
+
+    try:
+        cleanup_uploaded_files(_default_cache, delete_from_provider=True)
+    except Exception as e:
+        logger.debug(f"Error during exit cleanup: {e}")
+
+
+atexit.register(_cleanup_on_exit)
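Since `_cleanup_on_exit` is registered unconditionally and will try to delete files from providers whenever the default cache is non-empty at interpreter shutdown, test code that touches the module-level cache should reset it. A hypothetical fixture sketch (the fixture name and placement are assumptions, not part of this commit):

import pytest

from crewai.utilities.files.upload_cache import reset_upload_cache


@pytest.fixture(autouse=True)
def fresh_upload_cache():
    # Hypothetical: keep the module-level cache empty across tests so the
    # atexit hook has nothing to delete from providers afterwards.
    reset_upload_cache()
    yield
    reset_upload_cache()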

View File

@@ -0,0 +1,206 @@
"""Tests for upload cache."""
from datetime import datetime, timedelta, timezone
import pytest
from crewai.utilities.files import FileBytes, ImageFile
from crewai.utilities.files.upload_cache import CachedUpload, UploadCache
# Minimal valid PNG
MINIMAL_PNG = (
b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x08\x00\x00\x00\x08"
b"\x01\x00\x00\x00\x00\xf9Y\xab\xcd\x00\x00\x00\nIDATx\x9cc`\x00\x00"
b"\x00\x02\x00\x01\xe2!\xbc3\x00\x00\x00\x00IEND\xaeB`\x82"
)
class TestCachedUpload:
"""Tests for CachedUpload dataclass."""
def test_cached_upload_creation(self):
"""Test creating a cached upload."""
now = datetime.now(timezone.utc)
cached = CachedUpload(
file_id="file-123",
provider="gemini",
file_uri="files/file-123",
content_type="image/png",
uploaded_at=now,
expires_at=now + timedelta(hours=48),
)
assert cached.file_id == "file-123"
assert cached.provider == "gemini"
assert cached.file_uri == "files/file-123"
assert cached.content_type == "image/png"
def test_is_expired_false(self):
"""Test is_expired returns False for non-expired upload."""
future = datetime.now(timezone.utc) + timedelta(hours=24)
cached = CachedUpload(
file_id="file-123",
provider="gemini",
file_uri=None,
content_type="image/png",
uploaded_at=datetime.now(timezone.utc),
expires_at=future,
)
assert cached.is_expired() is False
def test_is_expired_true(self):
"""Test is_expired returns True for expired upload."""
past = datetime.now(timezone.utc) - timedelta(hours=1)
cached = CachedUpload(
file_id="file-123",
provider="gemini",
file_uri=None,
content_type="image/png",
uploaded_at=datetime.now(timezone.utc) - timedelta(hours=2),
expires_at=past,
)
assert cached.is_expired() is True
def test_is_expired_no_expiry(self):
"""Test is_expired returns False when no expiry set."""
cached = CachedUpload(
file_id="file-123",
provider="anthropic",
file_uri=None,
content_type="image/png",
uploaded_at=datetime.now(timezone.utc),
expires_at=None,
)
assert cached.is_expired() is False
class TestUploadCache:
"""Tests for UploadCache class."""
def test_cache_creation(self):
"""Test creating an empty cache."""
cache = UploadCache()
assert len(cache) == 0
def test_set_and_get(self):
"""Test setting and getting cached uploads."""
cache = UploadCache()
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
cached = cache.set(
file=file,
provider="gemini",
file_id="file-123",
file_uri="files/file-123",
)
result = cache.get(file, "gemini")
assert result is not None
assert result.file_id == "file-123"
assert result.provider == "gemini"
def test_get_missing(self):
"""Test getting non-existent entry returns None."""
cache = UploadCache()
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
result = cache.get(file, "gemini")
assert result is None
def test_get_different_provider(self):
"""Test getting with different provider returns None."""
cache = UploadCache()
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
cache.set(file=file, provider="gemini", file_id="file-123")
result = cache.get(file, "anthropic") # Different provider
assert result is None
def test_remove(self):
"""Test removing cached entry."""
cache = UploadCache()
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
cache.set(file=file, provider="gemini", file_id="file-123")
removed = cache.remove(file, "gemini")
assert removed is True
assert cache.get(file, "gemini") is None
def test_remove_missing(self):
"""Test removing non-existent entry returns False."""
cache = UploadCache()
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
removed = cache.remove(file, "gemini")
assert removed is False
def test_remove_by_file_id(self):
"""Test removing by file ID."""
cache = UploadCache()
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
cache.set(file=file, provider="gemini", file_id="file-123")
removed = cache.remove_by_file_id("file-123", "gemini")
assert removed is True
assert len(cache) == 0
def test_clear_expired(self):
"""Test clearing expired entries."""
cache = UploadCache()
file1 = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test1.png"))
file2 = ImageFile(source=FileBytes(data=MINIMAL_PNG + b"x", filename="test2.png"))
# Add one expired and one valid entry
past = datetime.now(timezone.utc) - timedelta(hours=1)
future = datetime.now(timezone.utc) + timedelta(hours=24)
cache.set(file=file1, provider="gemini", file_id="expired", expires_at=past)
cache.set(file=file2, provider="gemini", file_id="valid", expires_at=future)
removed = cache.clear_expired()
assert removed == 1
assert len(cache) == 1
assert cache.get(file2, "gemini") is not None
def test_clear(self):
"""Test clearing all entries."""
cache = UploadCache()
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
cache.set(file=file, provider="gemini", file_id="file-123")
cache.set(file=file, provider="anthropic", file_id="file-456")
cleared = cache.clear()
assert cleared == 2
assert len(cache) == 0
def test_get_all_for_provider(self):
"""Test getting all cached uploads for a provider."""
cache = UploadCache()
file1 = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test1.png"))
file2 = ImageFile(source=FileBytes(data=MINIMAL_PNG + b"x", filename="test2.png"))
file3 = ImageFile(source=FileBytes(data=MINIMAL_PNG + b"xx", filename="test3.png"))
cache.set(file=file1, provider="gemini", file_id="file-1")
cache.set(file=file2, provider="gemini", file_id="file-2")
cache.set(file=file3, provider="anthropic", file_id="file-3")
gemini_uploads = cache.get_all_for_provider("gemini")
anthropic_uploads = cache.get_all_for_provider("anthropic")
assert len(gemini_uploads) == 2
assert len(anthropic_uploads) == 1
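These tests exercise the sync wrappers; a hedged sketch of driving the primary async interface directly, reusing the imports and MINIMAL_PNG constant from this file (assumes pytest-asyncio or an equivalent async test runner is installed, which this diff does not show):

@pytest.mark.asyncio  # assumption: pytest-asyncio is available
async def test_aset_and_aget():
    """Async counterpart of test_set_and_get."""
    cache = UploadCache()
    file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
    await cache.aset(file=file, provider="gemini", file_id="file-123")
    result = await cache.aget(file, "gemini")
    assert result is not None
    assert result.file_id == "file-123"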