Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-23 15:18:14 +00:00
feat: upgrade upload cache to aiocache with atexit cleanup
@@ -177,10 +177,4 @@ def _get_providers_from_cache(cache: UploadCache) -> set[str]:
     Returns:
         Set of provider names.
     """
-    providers: set[str] = set()
-
-    with cache._lock:
-        for _, provider in cache._cache.keys():
-            providers.add(provider)
-
-    return providers
+    return cache.get_providers()
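
The change above stops reaching into the cache's private _lock and _cache and goes through the new public accessor instead. A minimal sketch of that accessor in use, assuming the in-memory backend and the sync wrappers introduced further down in this diff:

from crewai.utilities.files.upload_cache import UploadCache

cache = UploadCache()
cache.set_by_hash(
    file_hash="abc123",
    content_type="image/png",
    provider="gemini",
    file_id="file-1",
)
print(cache.get_providers())  # {'gemini'}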
@@ -1,24 +1,36 @@
-"""Cache for tracking uploaded files to avoid redundant uploads."""
+"""Cache for tracking uploaded files using aiocache."""
 
-from dataclasses import dataclass, field
-from datetime import datetime, timedelta, timezone
+from __future__ import annotations
+
+import asyncio
+import atexit
+import builtins
+from dataclasses import dataclass
+from datetime import datetime, timezone
 import hashlib
 import logging
-import threading
+from typing import TYPE_CHECKING, Any
 
-from crewai.utilities.files.content_types import (
-    AudioFile,
-    File,
-    ImageFile,
-    PDFFile,
-    TextFile,
-    VideoFile,
-)
+from aiocache import Cache  # type: ignore[import-untyped]
+from aiocache.serializers import PickleSerializer  # type: ignore[import-untyped]
+
+if TYPE_CHECKING:
+    from crewai.utilities.files.content_types import (
+        AudioFile,
+        File,
+        ImageFile,
+        PDFFile,
+        TextFile,
+        VideoFile,
+    )
+
+    FileInput = AudioFile | File | ImageFile | PDFFile | TextFile | VideoFile
 
 
 logger = logging.getLogger(__name__)
 
-FileInput = AudioFile | File | ImageFile | PDFFile | TextFile | VideoFile
+DEFAULT_TTL_SECONDS = 24 * 60 * 60  # 24 hours
 
 
 @dataclass
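
For readers unfamiliar with aiocache, here is a standalone sketch of the calls the rewritten module builds on; the key and value below are illustrative, not taken from the diff:

import asyncio

from aiocache import Cache
from aiocache.serializers import PickleSerializer


async def demo() -> None:
    # Same construction as UploadCache's memory backend below.
    cache = Cache(
        Cache.MEMORY, serializer=PickleSerializer(), namespace="crewai_uploads"
    )
    await cache.set("upload:gemini:abc123", {"file_id": "file-123"}, ttl=60)
    print(await cache.get("upload:gemini:abc123"))  # {'file_id': 'file-123'}
    await cache.delete("upload:gemini:abc123")


asyncio.run(demo())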
@@ -42,44 +54,83 @@ class CachedUpload:
     expires_at: datetime | None = None
 
     def is_expired(self) -> bool:
-        """Check if this cached upload has expired.
-
-        Returns:
-            True if expired, False otherwise.
-        """
+        """Check if this cached upload has expired."""
         if self.expires_at is None:
             return False
         return datetime.now(timezone.utc) >= self.expires_at
 
 
-@dataclass
-class UploadCache:
-    """Thread-safe cache for tracking uploaded files.
-
-    Uses file content hash and provider as composite key to avoid
-    uploading the same file multiple times.
+def _make_key(file_hash: str, provider: str) -> str:
+    """Create a cache key from file hash and provider."""
+    return f"upload:{provider}:{file_hash}"
+
+
+def _compute_file_hash(file: FileInput) -> str:
+    """Compute SHA-256 hash of file content."""
+    content = file.read()
+    return hashlib.sha256(content).hexdigest()
+
+
+class UploadCache:
+    """Async cache for tracking uploaded files using aiocache.
+
+    Supports in-memory caching by default, with optional Redis backend
+    for distributed setups.
 
     Attributes:
-        default_ttl: Default time-to-live for cached entries.
+        ttl: Default time-to-live in seconds for cached entries.
+        namespace: Cache namespace for isolation.
     """
 
-    default_ttl: timedelta = field(default_factory=lambda: timedelta(hours=24))
-    _cache: dict[tuple[str, str], CachedUpload] = field(default_factory=dict)
-    _lock: threading.Lock = field(default_factory=threading.Lock)
-
-    def _compute_hash(self, file: FileInput) -> str:
-        """Compute a hash of file content for cache key.
+    def __init__(
+        self,
+        ttl: int = DEFAULT_TTL_SECONDS,
+        namespace: str = "crewai_uploads",
+        cache_type: str = "memory",
+        **cache_kwargs: Any,
+    ) -> None:
+        """Initialize the upload cache.
 
         Args:
-            file: The file to hash.
-
-        Returns:
-            SHA-256 hash of the file content.
+            ttl: Default TTL in seconds.
+            namespace: Cache namespace.
+            cache_type: Backend type ("memory" or "redis").
+            **cache_kwargs: Additional args for cache backend.
         """
-        content = file.source.read()
-        return hashlib.sha256(content).hexdigest()
-
-    def get(self, file: FileInput, provider: str) -> CachedUpload | None:
+        self.ttl = ttl
+        self.namespace = namespace
+        self._provider_keys: dict[str, set[str]] = {}
+
+        if cache_type == "redis":
+            self._cache = Cache(
+                Cache.REDIS,
+                serializer=PickleSerializer(),
+                namespace=namespace,
+                **cache_kwargs,
+            )
+        else:
+            self._cache = Cache(
+                Cache.MEMORY,
+                serializer=PickleSerializer(),
+                namespace=namespace,
+            )
+
+    def _track_key(self, provider: str, key: str) -> None:
+        """Track a key for a provider (for cleanup)."""
+        if provider not in self._provider_keys:
+            self._provider_keys[provider] = set()
+        self._provider_keys[provider].add(key)
+
+    def _untrack_key(self, provider: str, key: str) -> None:
+        """Remove key tracking for a provider."""
+        if provider in self._provider_keys:
+            self._provider_keys[provider].discard(key)
+
+    # -------------------------------------------------------------------------
+    # Async methods (primary interface)
+    # -------------------------------------------------------------------------
+
+    async def aget(self, file: FileInput, provider: str) -> CachedUpload | None:
         """Get a cached upload for a file.
 
         Args:
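
The two module-level helpers replace the old tuple key (file_hash, provider) with a flat string key, which string-keyed aiocache backends need. A quick sketch; the SimpleNamespace object is a stand-in for a real FileInput:

import hashlib
from types import SimpleNamespace

fake_file = SimpleNamespace(read=lambda: b"hello")  # stand-in for a FileInput

file_hash = hashlib.sha256(fake_file.read()).hexdigest()
print(f"upload:gemini:{file_hash}")
# upload:gemini:2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824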
@@ -89,21 +140,10 @@ class UploadCache:
         Returns:
             Cached upload if found and not expired, None otherwise.
         """
-        file_hash = self._compute_hash(file)
-        key = (file_hash, provider)
-
-        with self._lock:
-            cached = self._cache.get(key)
-            if cached is None:
-                return None
-
-            if cached.is_expired():
-                del self._cache[key]
-                return None
-
-            return cached
-
-    def get_by_hash(self, file_hash: str, provider: str) -> CachedUpload | None:
+        file_hash = _compute_file_hash(file)
+        return await self.aget_by_hash(file_hash, provider)
+
+    async def aget_by_hash(self, file_hash: str, provider: str) -> CachedUpload | None:
         """Get a cached upload by file hash.
 
         Args:
@@ -113,20 +153,20 @@ class UploadCache:
         Returns:
             Cached upload if found and not expired, None otherwise.
         """
-        key = (file_hash, provider)
-
-        with self._lock:
-            cached = self._cache.get(key)
-            if cached is None:
-                return None
-
-            if cached.is_expired():
-                del self._cache[key]
-                return None
-
-            return cached
-
-    def set(
+        key = _make_key(file_hash, provider)
+        result = await self._cache.get(key)
+
+        if result is None:
+            return None
+        if isinstance(result, CachedUpload):
+            if result.is_expired():
+                await self._cache.delete(key)
+                self._untrack_key(provider, key)
+                return None
+            return result
+        return None
+
+    async def aset(
         self,
         file: FileInput,
         provider: str,
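
Note the belt-and-braces expiry handling on the read path: entries are written with a backend TTL (see the aset_by_hash hunk below), yet aget_by_hash still checks is_expired() and deletes stale entries, covering backends that have not evicted yet. A sketch via the sync wrappers:

from datetime import datetime, timedelta, timezone

from crewai.utilities.files.upload_cache import UploadCache

cache = UploadCache()
cache.set_by_hash(
    file_hash="abc123",
    content_type="image/png",
    provider="gemini",
    file_id="file-1",
    expires_at=datetime.now(timezone.utc) - timedelta(seconds=1),  # already expired
)
print(cache.get_by_hash("abc123", "gemini"))  # None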
@@ -146,26 +186,17 @@ class UploadCache:
         Returns:
             The created cache entry.
         """
-        file_hash = self._compute_hash(file)
-        key = (file_hash, provider)
-
-        now = datetime.now(timezone.utc)
-        cached = CachedUpload(
-            file_id=file_id,
-            provider=provider,
-            file_uri=file_uri,
-            content_type=file.content_type,
-            uploaded_at=now,
-            expires_at=expires_at,
-        )
-
-        with self._lock:
-            self._cache[key] = cached
-
-        logger.debug(f"Cached upload: {file_id} for provider {provider}")
-        return cached
-
-    def set_by_hash(
+        file_hash = _compute_file_hash(file)
+        return await self.aset_by_hash(
+            file_hash=file_hash,
+            content_type=file.content_type,
+            provider=provider,
+            file_id=file_id,
+            file_uri=file_uri,
+            expires_at=expires_at,
+        )
+
+    async def aset_by_hash(
         self,
         file_hash: str,
         content_type: str,
@@ -187,9 +218,9 @@ class UploadCache:
         Returns:
             The created cache entry.
         """
-        key = (file_hash, provider)
+        key = _make_key(file_hash, provider)
 
         now = datetime.now(timezone.utc)
         cached = CachedUpload(
             file_id=file_id,
             provider=provider,
@@ -199,13 +230,16 @@ class UploadCache:
             expires_at=expires_at,
         )
 
-        with self._lock:
-            self._cache[key] = cached
-
+        ttl = self.ttl
+        if expires_at is not None:
+            ttl = max(0, int((expires_at - now).total_seconds()))
+
+        await self._cache.set(key, cached, ttl=ttl)
+        self._track_key(provider, key)
         logger.debug(f"Cached upload: {file_id} for provider {provider}")
         return cached
 
-    def remove(self, file: FileInput, provider: str) -> bool:
+    async def aremove(self, file: FileInput, provider: str) -> bool:
        """Remove a cached upload.
 
         Args:
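
The TTL handed to aiocache is derived from the provider-supplied expires_at when one is given; otherwise the cache default applies. Worked standalone:

from datetime import datetime, timedelta, timezone

now = datetime.now(timezone.utc)
expires_at = now + timedelta(hours=2)

ttl = max(0, int((expires_at - now).total_seconds()))
print(ttl)  # 7200: the backend evicts the entry after two hours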
@@ -215,16 +249,16 @@ class UploadCache:
         Returns:
             True if entry was removed, False if not found.
         """
-        file_hash = self._compute_hash(file)
-        key = (file_hash, provider)
-
-        with self._lock:
-            if key in self._cache:
-                del self._cache[key]
-                return True
-            return False
-
-    def remove_by_file_id(self, file_id: str, provider: str) -> bool:
+        file_hash = _compute_file_hash(file)
+        key = _make_key(file_hash, provider)
+
+        result = await self._cache.delete(key)
+        removed = bool(result > 0 if isinstance(result, int) else result)
+        if removed:
+            self._untrack_key(provider, key)
+        return removed
+
+    async def aremove_by_file_id(self, file_id: str, provider: str) -> bool:
         """Remove a cached upload by file ID.
 
         Args:
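
The normalization line in aremove exists because cache backends may report deletions as an int count or a bool. Since Python bools are ints, both shapes land in the first branch, and the else covers any other return type:

for result in (1, 0, True, False):
    removed = bool(result > 0 if isinstance(result, int) else result)
    print(repr(result), "->", removed)
# 1 -> True, 0 -> False, True -> True, False -> False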
@@ -234,14 +268,18 @@ class UploadCache:
         Returns:
             True if entry was removed, False if not found.
         """
-        with self._lock:
-            for key, cached in list(self._cache.items()):
-                if cached.file_id == file_id and cached.provider == provider:
-                    del self._cache[key]
-                    return True
-            return False
-
-    def clear_expired(self) -> int:
+        if provider not in self._provider_keys:
+            return False
+
+        for key in list(self._provider_keys[provider]):
+            cached = await self._cache.get(key)
+            if isinstance(cached, CachedUpload) and cached.file_id == file_id:
+                await self._cache.delete(key)
+                self._untrack_key(provider, key)
+                return True
+        return False
+
+    async def aclear_expired(self) -> int:
         """Remove all expired entries from the cache.
 
         Returns:
@@ -249,33 +287,35 @@ class UploadCache:
         """
         removed = 0
 
-        with self._lock:
-            for key in list(self._cache.keys()):
-                if self._cache[key].is_expired():
-                    del self._cache[key]
+        for provider, keys in list(self._provider_keys.items()):
+            for key in list(keys):
+                cached = await self._cache.get(key)
+                if cached is None or (
+                    isinstance(cached, CachedUpload) and cached.is_expired()
+                ):
+                    await self._cache.delete(key)
+                    self._untrack_key(provider, key)
                     removed += 1
 
         if removed > 0:
             logger.debug(f"Cleared {removed} expired cache entries")
 
         return removed
 
-    def clear(self) -> int:
+    async def aclear(self) -> int:
         """Clear all entries from the cache.
 
         Returns:
             Number of entries cleared.
         """
-        with self._lock:
-            count = len(self._cache)
-            self._cache.clear()
+        count = sum(len(keys) for keys in self._provider_keys.values())
+        await self._cache.clear(namespace=self.namespace)
+        self._provider_keys.clear()
 
         if count > 0:
             logger.debug(f"Cleared {count} cache entries")
 
         return count
 
-    def get_all_for_provider(self, provider: str) -> list[CachedUpload]:
+    async def aget_all_for_provider(self, provider: str) -> list[CachedUpload]:
         """Get all cached uploads for a provider.
 
         Args:
@@ -284,14 +324,171 @@ class UploadCache:
         Returns:
             List of cached uploads for the provider.
         """
-        with self._lock:
-            return [
-                cached
-                for (_, p), cached in self._cache.items()
-                if p == provider and not cached.is_expired()
-            ]
+        if provider not in self._provider_keys:
+            return []
+
+        results: list[CachedUpload] = []
+        for key in list(self._provider_keys[provider]):
+            cached = await self._cache.get(key)
+            if isinstance(cached, CachedUpload) and not cached.is_expired():
+                results.append(cached)
+        return results
+
+    # -------------------------------------------------------------------------
+    # Sync wrappers (convenience)
+    # -------------------------------------------------------------------------
+
+    def _run_sync(self, coro: Any) -> Any:
+        """Run an async coroutine from sync context."""
+        try:
+            loop = asyncio.get_running_loop()
+        except RuntimeError:
+            loop = None
+
+        if loop is not None and loop.is_running():
+            import concurrent.futures
+
+            with concurrent.futures.ThreadPoolExecutor() as pool:
+                future = pool.submit(asyncio.run, coro)
+                return future.result()
+        return asyncio.run(coro)
+
+    def get(self, file: FileInput, provider: str) -> CachedUpload | None:
+        """Sync wrapper for aget."""
+        result: CachedUpload | None = self._run_sync(self.aget(file, provider))
+        return result
+
+    def get_by_hash(self, file_hash: str, provider: str) -> CachedUpload | None:
+        """Sync wrapper for aget_by_hash."""
+        result: CachedUpload | None = self._run_sync(
+            self.aget_by_hash(file_hash, provider)
+        )
+        return result
+
+    def set(
+        self,
+        file: FileInput,
+        provider: str,
+        file_id: str,
+        file_uri: str | None = None,
+        expires_at: datetime | None = None,
+    ) -> CachedUpload:
+        """Sync wrapper for aset."""
+        result: CachedUpload = self._run_sync(
+            self.aset(file, provider, file_id, file_uri, expires_at)
+        )
+        return result
+
+    def set_by_hash(
+        self,
+        file_hash: str,
+        content_type: str,
+        provider: str,
+        file_id: str,
+        file_uri: str | None = None,
+        expires_at: datetime | None = None,
+    ) -> CachedUpload:
+        """Sync wrapper for aset_by_hash."""
+        result: CachedUpload = self._run_sync(
+            self.aset_by_hash(
+                file_hash, content_type, provider, file_id, file_uri, expires_at
+            )
+        )
+        return result
+
+    def remove(self, file: FileInput, provider: str) -> bool:
+        """Sync wrapper for aremove."""
+        result: bool = self._run_sync(self.aremove(file, provider))
+        return result
+
+    def remove_by_file_id(self, file_id: str, provider: str) -> bool:
+        """Sync wrapper for aremove_by_file_id."""
+        result: bool = self._run_sync(self.aremove_by_file_id(file_id, provider))
+        return result
+
+    def clear_expired(self) -> int:
+        """Sync wrapper for aclear_expired."""
+        result: int = self._run_sync(self.aclear_expired())
+        return result
+
+    def clear(self) -> int:
+        """Sync wrapper for aclear."""
+        result: int = self._run_sync(self.aclear())
+        return result
+
+    def get_all_for_provider(self, provider: str) -> list[CachedUpload]:
+        """Sync wrapper for aget_all_for_provider."""
+        result: list[CachedUpload] = self._run_sync(
+            self.aget_all_for_provider(provider)
+        )
+        return result
+
     def __len__(self) -> int:
         """Return the number of cached entries."""
-        with self._lock:
-            return len(self._cache)
+        return sum(len(keys) for keys in self._provider_keys.values())
+
+    def get_providers(self) -> builtins.set[str]:
+        """Get all provider names that have cached entries.
+
+        Returns:
+            Set of provider names.
+        """
+        return builtins.set(self._provider_keys.keys())
+
+
+# Module-level cache instance
+_default_cache: UploadCache | None = None
+
+
+def get_upload_cache(
+    ttl: int = DEFAULT_TTL_SECONDS,
+    namespace: str = "crewai_uploads",
+    cache_type: str = "memory",
+    **cache_kwargs: Any,
+) -> UploadCache:
+    """Get or create the default upload cache.
+
+    Args:
+        ttl: Default TTL in seconds.
+        namespace: Cache namespace.
+        cache_type: Backend type ("memory" or "redis").
+        **cache_kwargs: Additional args for cache backend.
+
+    Returns:
+        The upload cache instance.
+    """
+    global _default_cache
+    if _default_cache is None:
+        _default_cache = UploadCache(
+            ttl=ttl,
+            namespace=namespace,
+            cache_type=cache_type,
+            **cache_kwargs,
+        )
+    return _default_cache
+
+
+def reset_upload_cache() -> None:
+    """Reset the default upload cache (useful for testing)."""
+    global _default_cache
+    if _default_cache is not None:
+        _default_cache.clear()
+    _default_cache = None
+
+
+def _cleanup_on_exit() -> None:
+    """Clean up uploaded files on process exit."""
+    global _default_cache
+    if _default_cache is None or len(_default_cache) == 0:
+        return
+
+    # Import here to avoid circular imports
+    from crewai.utilities.files.cleanup import cleanup_uploaded_files
+
+    try:
+        cleanup_uploaded_files(_default_cache, delete_from_provider=True)
+    except Exception as e:
+        logger.debug(f"Error during exit cleanup: {e}")
+
+
+atexit.register(_cleanup_on_exit)
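
The sync wrappers all funnel through _run_sync, which calls asyncio.run directly when no event loop is running, and otherwise ships the coroutine to a fresh loop on a worker thread, since asyncio.run raises RuntimeError when called from inside a running loop. The pattern, extracted as a standalone sketch:

import asyncio
import concurrent.futures
from typing import Any


def run_sync(coro: Any) -> Any:
    """Run a coroutine to completion from synchronous code."""
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        loop = None  # no loop in this thread

    if loop is not None and loop.is_running():
        # Inside a running loop: give the coroutine its own loop on a
        # worker thread rather than calling asyncio.run() here.
        with concurrent.futures.ThreadPoolExecutor() as pool:
            return pool.submit(asyncio.run, coro).result()
    return asyncio.run(coro)


async def answer() -> int:
    return 42


print(run_sync(answer()))  # 42

One trade-off worth noting: spinning up a thread pool per call is simple but not free, and the coroutine runs on a different loop than the caller's, so it must not touch loop-bound resources.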
lib/crewai/tests/utilities/files/test_upload_cache.py (new file, 206 lines)
@@ -0,0 +1,206 @@
+"""Tests for upload cache."""
+
+from datetime import datetime, timedelta, timezone
+
+import pytest
+
+from crewai.utilities.files import FileBytes, ImageFile
+from crewai.utilities.files.upload_cache import CachedUpload, UploadCache
+
+
+# Minimal valid PNG
+MINIMAL_PNG = (
+    b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x08\x00\x00\x00\x08"
+    b"\x01\x00\x00\x00\x00\xf9Y\xab\xcd\x00\x00\x00\nIDATx\x9cc`\x00\x00"
+    b"\x00\x02\x00\x01\xe2!\xbc3\x00\x00\x00\x00IEND\xaeB`\x82"
+)
+
+
+class TestCachedUpload:
+    """Tests for CachedUpload dataclass."""
+
+    def test_cached_upload_creation(self):
+        """Test creating a cached upload."""
+        now = datetime.now(timezone.utc)
+        cached = CachedUpload(
+            file_id="file-123",
+            provider="gemini",
+            file_uri="files/file-123",
+            content_type="image/png",
+            uploaded_at=now,
+            expires_at=now + timedelta(hours=48),
+        )
+
+        assert cached.file_id == "file-123"
+        assert cached.provider == "gemini"
+        assert cached.file_uri == "files/file-123"
+        assert cached.content_type == "image/png"
+
+    def test_is_expired_false(self):
+        """Test is_expired returns False for non-expired upload."""
+        future = datetime.now(timezone.utc) + timedelta(hours=24)
+        cached = CachedUpload(
+            file_id="file-123",
+            provider="gemini",
+            file_uri=None,
+            content_type="image/png",
+            uploaded_at=datetime.now(timezone.utc),
+            expires_at=future,
+        )
+
+        assert cached.is_expired() is False
+
+    def test_is_expired_true(self):
+        """Test is_expired returns True for expired upload."""
+        past = datetime.now(timezone.utc) - timedelta(hours=1)
+        cached = CachedUpload(
+            file_id="file-123",
+            provider="gemini",
+            file_uri=None,
+            content_type="image/png",
+            uploaded_at=datetime.now(timezone.utc) - timedelta(hours=2),
+            expires_at=past,
+        )
+
+        assert cached.is_expired() is True
+
+    def test_is_expired_no_expiry(self):
+        """Test is_expired returns False when no expiry set."""
+        cached = CachedUpload(
+            file_id="file-123",
+            provider="anthropic",
+            file_uri=None,
+            content_type="image/png",
+            uploaded_at=datetime.now(timezone.utc),
+            expires_at=None,
+        )
+
+        assert cached.is_expired() is False
+
+
+class TestUploadCache:
+    """Tests for UploadCache class."""
+
+    def test_cache_creation(self):
+        """Test creating an empty cache."""
+        cache = UploadCache()
+
+        assert len(cache) == 0
+
+    def test_set_and_get(self):
+        """Test setting and getting cached uploads."""
+        cache = UploadCache()
+        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
+
+        cached = cache.set(
+            file=file,
+            provider="gemini",
+            file_id="file-123",
+            file_uri="files/file-123",
+        )
+
+        result = cache.get(file, "gemini")
+
+        assert result is not None
+        assert result.file_id == "file-123"
+        assert result.provider == "gemini"
+
+    def test_get_missing(self):
+        """Test getting non-existent entry returns None."""
+        cache = UploadCache()
+        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
+
+        result = cache.get(file, "gemini")
+
+        assert result is None
+
+    def test_get_different_provider(self):
+        """Test getting with different provider returns None."""
+        cache = UploadCache()
+        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
+
+        cache.set(file=file, provider="gemini", file_id="file-123")
+
+        result = cache.get(file, "anthropic")  # Different provider
+
+        assert result is None
+
+    def test_remove(self):
+        """Test removing cached entry."""
+        cache = UploadCache()
+        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
+
+        cache.set(file=file, provider="gemini", file_id="file-123")
+        removed = cache.remove(file, "gemini")
+
+        assert removed is True
+        assert cache.get(file, "gemini") is None
+
+    def test_remove_missing(self):
+        """Test removing non-existent entry returns False."""
+        cache = UploadCache()
+        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
+
+        removed = cache.remove(file, "gemini")
+
+        assert removed is False
+
+    def test_remove_by_file_id(self):
+        """Test removing by file ID."""
+        cache = UploadCache()
+        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
+
+        cache.set(file=file, provider="gemini", file_id="file-123")
+        removed = cache.remove_by_file_id("file-123", "gemini")
+
+        assert removed is True
+        assert len(cache) == 0
+
+    def test_clear_expired(self):
+        """Test clearing expired entries."""
+        cache = UploadCache()
+        file1 = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test1.png"))
+        file2 = ImageFile(source=FileBytes(data=MINIMAL_PNG + b"x", filename="test2.png"))
+
+        # Add one expired and one valid entry
+        past = datetime.now(timezone.utc) - timedelta(hours=1)
+        future = datetime.now(timezone.utc) + timedelta(hours=24)
+
+        cache.set(file=file1, provider="gemini", file_id="expired", expires_at=past)
+        cache.set(file=file2, provider="gemini", file_id="valid", expires_at=future)
+
+        removed = cache.clear_expired()
+
+        assert removed == 1
+        assert len(cache) == 1
+        assert cache.get(file2, "gemini") is not None
+
+    def test_clear(self):
+        """Test clearing all entries."""
+        cache = UploadCache()
+        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
+
+        cache.set(file=file, provider="gemini", file_id="file-123")
+        cache.set(file=file, provider="anthropic", file_id="file-456")
+
+        cleared = cache.clear()
+
+        assert cleared == 2
+        assert len(cache) == 0
+
+    def test_get_all_for_provider(self):
+        """Test getting all cached uploads for a provider."""
+        cache = UploadCache()
+        file1 = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test1.png"))
+        file2 = ImageFile(source=FileBytes(data=MINIMAL_PNG + b"x", filename="test2.png"))
+        file3 = ImageFile(source=FileBytes(data=MINIMAL_PNG + b"xx", filename="test3.png"))
+
+        cache.set(file=file1, provider="gemini", file_id="file-1")
+        cache.set(file=file2, provider="gemini", file_id="file-2")
+        cache.set(file=file3, provider="anthropic", file_id="file-3")
+
+        gemini_uploads = cache.get_all_for_provider("gemini")
+        anthropic_uploads = cache.get_all_for_provider("anthropic")
+
+        assert len(gemini_uploads) == 2
+        assert len(anthropic_uploads) == 1
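
Putting the pieces together, a hedged end-to-end sketch of the module-level singleton introduced above; the PNG bytes mirror the test fixture:

from crewai.utilities.files import FileBytes, ImageFile
from crewai.utilities.files.upload_cache import get_upload_cache, reset_upload_cache

MINIMAL_PNG = (
    b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x08\x00\x00\x00\x08"
    b"\x01\x00\x00\x00\x00\xf9Y\xab\xcd\x00\x00\x00\nIDATx\x9cc`\x00\x00"
    b"\x00\x02\x00\x01\xe2!\xbc3\x00\x00\x00\x00IEND\xaeB`\x82"
)

cache = get_upload_cache()  # process-wide singleton, memory backend, 24h TTL
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))

cache.set(file=file, provider="gemini", file_id="file-123")
assert cache.get(file, "gemini") is not None

reset_upload_cache()  # clears entries and drops the singleton (handy in tests)

Because atexit.register(_cleanup_on_exit) runs at import time, any process that populated the default cache will also attempt provider-side deletion of the uploads on shutdown.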