"""Cache for tracking uploaded files using aiocache.""" from __future__ import annotations import asyncio import atexit import builtins from collections.abc import Iterator from dataclasses import dataclass from datetime import datetime, timezone import hashlib import logging from typing import TYPE_CHECKING, Any from aiocache import Cache # type: ignore[import-untyped] from aiocache.serializers import PickleSerializer # type: ignore[import-untyped] from crewai_files.core.constants import DEFAULT_MAX_CACHE_ENTRIES, DEFAULT_TTL_SECONDS from crewai_files.uploaders.factory import ProviderType if TYPE_CHECKING: from crewai_files.core.types import FileInput logger = logging.getLogger(__name__) @dataclass class CachedUpload: """Represents a cached file upload. Attributes: file_id: Provider-specific file identifier. provider: Name of the provider. file_uri: Optional URI for accessing the file. content_type: MIME type of the uploaded file. uploaded_at: When the file was uploaded. expires_at: When the upload expires (if applicable). """ file_id: str provider: ProviderType file_uri: str | None content_type: str uploaded_at: datetime expires_at: datetime | None = None def is_expired(self) -> bool: """Check if this cached upload has expired.""" if self.expires_at is None: return False return datetime.now(timezone.utc) >= self.expires_at def _make_key(file_hash: str, provider: str) -> str: """Create a cache key from file hash and provider.""" return f"upload:{provider}:{file_hash}" def _compute_file_hash_streaming(chunks: Iterator[bytes]) -> str: """Compute SHA-256 hash from streaming chunks. Args: chunks: Iterator of byte chunks. Returns: Hexadecimal hash string. """ hasher = hashlib.sha256() for chunk in chunks: hasher.update(chunk) return hasher.hexdigest() def _compute_file_hash(file: FileInput) -> str: """Compute SHA-256 hash of file content. Uses streaming for FilePath sources to avoid loading large files into memory. """ from crewai_files.core.sources import FilePath source = file._file_source if isinstance(source, FilePath): return _compute_file_hash_streaming(source.read_chunks(chunk_size=1024 * 1024)) content = file.read() return hashlib.sha256(content).hexdigest() class UploadCache: """Async cache for tracking uploaded files using aiocache. Supports in-memory caching by default, with optional Redis backend for distributed setups. Attributes: ttl: Default time-to-live in seconds for cached entries. namespace: Cache namespace for isolation. """ def __init__( self, ttl: int = DEFAULT_TTL_SECONDS, namespace: str = "crewai_uploads", cache_type: str = "memory", max_entries: int | None = DEFAULT_MAX_CACHE_ENTRIES, **cache_kwargs: Any, ) -> None: """Initialize the upload cache. Args: ttl: Default TTL in seconds. namespace: Cache namespace. cache_type: Backend type ("memory" or "redis"). max_entries: Maximum cache entries (None for unlimited). **cache_kwargs: Additional args for cache backend. 
""" self.ttl = ttl self.namespace = namespace self.max_entries = max_entries self._provider_keys: dict[ProviderType, set[str]] = {} self._key_access_order: list[str] = [] if cache_type == "redis": self._cache = Cache( Cache.REDIS, serializer=PickleSerializer(), namespace=namespace, **cache_kwargs, ) else: self._cache = Cache( serializer=PickleSerializer(), namespace=namespace, ) def _track_key(self, provider: ProviderType, key: str) -> None: """Track a key for a provider (for cleanup) and access order.""" if provider not in self._provider_keys: self._provider_keys[provider] = set() self._provider_keys[provider].add(key) if key in self._key_access_order: self._key_access_order.remove(key) self._key_access_order.append(key) def _untrack_key(self, provider: ProviderType, key: str) -> None: """Remove key tracking for a provider.""" if provider in self._provider_keys: self._provider_keys[provider].discard(key) if key in self._key_access_order: self._key_access_order.remove(key) async def _evict_if_needed(self) -> int: """Evict oldest entries if limit exceeded. Returns: Number of entries evicted. """ if self.max_entries is None: return 0 current_count = len(self) if current_count < self.max_entries: return 0 to_evict = max(1, self.max_entries // 10) return await self._evict_oldest(to_evict) async def _evict_oldest(self, count: int) -> int: """Evict the oldest entries from the cache. Args: count: Number of entries to evict. Returns: Number of entries actually evicted. """ evicted = 0 keys_to_evict = self._key_access_order[:count] for key in keys_to_evict: await self._cache.delete(key) self._key_access_order.remove(key) for provider_keys in self._provider_keys.values(): provider_keys.discard(key) evicted += 1 if evicted > 0: logger.debug(f"Evicted {evicted} oldest cache entries") return evicted async def aget( self, file: FileInput, provider: ProviderType ) -> CachedUpload | None: """Get a cached upload for a file. Args: file: The file to look up. provider: The provider name. Returns: Cached upload if found and not expired, None otherwise. """ file_hash = _compute_file_hash(file) return await self.aget_by_hash(file_hash, provider) async def aget_by_hash( self, file_hash: str, provider: ProviderType ) -> CachedUpload | None: """Get a cached upload by file hash. Args: file_hash: Hash of the file content. provider: The provider name. Returns: Cached upload if found and not expired, None otherwise. """ key = _make_key(file_hash, provider) result = await self._cache.get(key) if result is None: return None if isinstance(result, CachedUpload): if result.is_expired(): await self._cache.delete(key) self._untrack_key(provider, key) return None return result return None async def aset( self, file: FileInput, provider: ProviderType, file_id: str, file_uri: str | None = None, expires_at: datetime | None = None, ) -> CachedUpload: """Cache an uploaded file. Args: file: The file that was uploaded. provider: The provider name. file_id: Provider-specific file identifier. file_uri: Optional URI for accessing the file. expires_at: When the upload expires. Returns: The created cache entry. """ file_hash = _compute_file_hash(file) return await self.aset_by_hash( file_hash=file_hash, content_type=file.content_type, provider=provider, file_id=file_id, file_uri=file_uri, expires_at=expires_at, ) async def aset_by_hash( self, file_hash: str, content_type: str, provider: ProviderType, file_id: str, file_uri: str | None = None, expires_at: datetime | None = None, ) -> CachedUpload: """Cache an uploaded file by hash. 
    async def aset_by_hash(
        self,
        file_hash: str,
        content_type: str,
        provider: ProviderType,
        file_id: str,
        file_uri: str | None = None,
        expires_at: datetime | None = None,
    ) -> CachedUpload:
        """Cache an uploaded file by hash.

        Args:
            file_hash: Hash of the file content.
            content_type: MIME type of the file.
            provider: The provider name.
            file_id: Provider-specific file identifier.
            file_uri: Optional URI for accessing the file.
            expires_at: When the upload expires.

        Returns:
            The created cache entry.
        """
        await self._evict_if_needed()
        key = _make_key(file_hash, provider)
        now = datetime.now(timezone.utc)
        cached = CachedUpload(
            file_id=file_id,
            provider=provider,
            file_uri=file_uri,
            content_type=content_type,
            uploaded_at=now,
            expires_at=expires_at,
        )
        ttl = self.ttl
        if expires_at is not None:
            ttl = max(0, int((expires_at - now).total_seconds()))
        await self._cache.set(key, cached, ttl=ttl)
        self._track_key(provider, key)
        logger.debug(f"Cached upload: {file_id} for provider {provider}")
        return cached

    async def aremove(self, file: FileInput, provider: ProviderType) -> bool:
        """Remove a cached upload.

        Args:
            file: The file to remove.
            provider: The provider name.

        Returns:
            True if entry was removed, False if not found.
        """
        file_hash = _compute_file_hash(file)
        key = _make_key(file_hash, provider)
        result = await self._cache.delete(key)
        removed = bool(result > 0 if isinstance(result, int) else result)
        if removed:
            self._untrack_key(provider, key)
        return removed

    async def aremove_by_file_id(self, file_id: str, provider: ProviderType) -> bool:
        """Remove a cached upload by file ID.

        Args:
            file_id: The file ID to remove.
            provider: The provider name.

        Returns:
            True if entry was removed, False if not found.
        """
        if provider not in self._provider_keys:
            return False
        for key in list(self._provider_keys[provider]):
            cached = await self._cache.get(key)
            if isinstance(cached, CachedUpload) and cached.file_id == file_id:
                await self._cache.delete(key)
                self._untrack_key(provider, key)
                return True
        return False

    async def aclear_expired(self) -> int:
        """Remove all expired entries from the cache.

        Returns:
            Number of entries removed.
        """
        removed = 0
        for provider, keys in list(self._provider_keys.items()):
            for key in list(keys):
                cached = await self._cache.get(key)
                if cached is None or (
                    isinstance(cached, CachedUpload) and cached.is_expired()
                ):
                    await self._cache.delete(key)
                    self._untrack_key(provider, key)
                    removed += 1
        if removed > 0:
            logger.debug(f"Cleared {removed} expired cache entries")
        return removed

    async def aclear(self) -> int:
        """Clear all entries from the cache.

        Returns:
            Number of entries cleared.
        """
        count = sum(len(keys) for keys in self._provider_keys.values())
        await self._cache.clear(namespace=self.namespace)
        self._provider_keys.clear()
        # Drop access-order tracking too, so eviction bookkeeping stays consistent.
        self._key_access_order.clear()
        if count > 0:
            logger.debug(f"Cleared {count} cache entries")
        return count

    async def aget_all_for_provider(self, provider: ProviderType) -> list[CachedUpload]:
        """Get all cached uploads for a provider.

        Args:
            provider: The provider name.

        Returns:
            List of cached uploads for the provider.
        """
        if provider not in self._provider_keys:
            return []
        results: list[CachedUpload] = []
        for key in list(self._provider_keys[provider]):
            cached = await self._cache.get(key)
            if isinstance(cached, CachedUpload) and not cached.is_expired():
                results.append(cached)
        return results
""" if provider not in self._provider_keys: return [] results: list[CachedUpload] = [] for key in list(self._provider_keys[provider]): cached = await self._cache.get(key) if isinstance(cached, CachedUpload) and not cached.is_expired(): results.append(cached) return results @staticmethod def _run_sync(coro: Any) -> Any: """Run an async coroutine from sync context without blocking event loop.""" try: loop = asyncio.get_running_loop() except RuntimeError: loop = None if loop is not None and loop.is_running(): future = asyncio.run_coroutine_threadsafe(coro, loop) return future.result(timeout=30) return asyncio.run(coro) def get(self, file: FileInput, provider: ProviderType) -> CachedUpload | None: """Sync wrapper for aget.""" result: CachedUpload | None = self._run_sync(self.aget(file, provider)) return result def get_by_hash( self, file_hash: str, provider: ProviderType ) -> CachedUpload | None: """Sync wrapper for aget_by_hash.""" result: CachedUpload | None = self._run_sync( self.aget_by_hash(file_hash, provider) ) return result def set( self, file: FileInput, provider: ProviderType, file_id: str, file_uri: str | None = None, expires_at: datetime | None = None, ) -> CachedUpload: """Sync wrapper for aset.""" result: CachedUpload = self._run_sync( self.aset(file, provider, file_id, file_uri, expires_at) ) return result def set_by_hash( self, file_hash: str, content_type: str, provider: ProviderType, file_id: str, file_uri: str | None = None, expires_at: datetime | None = None, ) -> CachedUpload: """Sync wrapper for aset_by_hash.""" result: CachedUpload = self._run_sync( self.aset_by_hash( file_hash, content_type, provider, file_id, file_uri, expires_at ) ) return result def remove(self, file: FileInput, provider: ProviderType) -> bool: """Sync wrapper for aremove.""" result: bool = self._run_sync(self.aremove(file, provider)) return result def remove_by_file_id(self, file_id: str, provider: ProviderType) -> bool: """Sync wrapper for aremove_by_file_id.""" result: bool = self._run_sync(self.aremove_by_file_id(file_id, provider)) return result def clear_expired(self) -> int: """Sync wrapper for aclear_expired.""" result: int = self._run_sync(self.aclear_expired()) return result def clear(self) -> int: """Sync wrapper for aclear.""" result: int = self._run_sync(self.aclear()) return result def get_all_for_provider(self, provider: ProviderType) -> list[CachedUpload]: """Sync wrapper for aget_all_for_provider.""" result: list[CachedUpload] = self._run_sync( self.aget_all_for_provider(provider) ) return result def __len__(self) -> int: """Return the number of cached entries.""" return sum(len(keys) for keys in self._provider_keys.values()) def get_providers(self) -> builtins.set[ProviderType]: """Get all provider names that have cached entries. Returns: Set of provider names. """ return builtins.set(self._provider_keys.keys()) _default_cache: UploadCache | None = None def get_upload_cache( ttl: int = DEFAULT_TTL_SECONDS, namespace: str = "crewai_uploads", cache_type: str = "memory", **cache_kwargs: Any, ) -> UploadCache: """Get or create the default upload cache. Args: ttl: Default TTL in seconds. namespace: Cache namespace. cache_type: Backend type ("memory" or "redis"). **cache_kwargs: Additional args for cache backend. Returns: The upload cache instance. 
""" global _default_cache if _default_cache is None: _default_cache = UploadCache( ttl=ttl, namespace=namespace, cache_type=cache_type, **cache_kwargs, ) return _default_cache def reset_upload_cache() -> None: """Reset the default upload cache (useful for testing).""" global _default_cache if _default_cache is not None: _default_cache.clear() _default_cache = None def _cleanup_on_exit() -> None: """Clean up uploaded files on process exit.""" global _default_cache if _default_cache is None or len(_default_cache) == 0: return from crewai_files.cache.cleanup import cleanup_uploaded_files try: cleanup_uploaded_files(_default_cache) except Exception as e: logger.debug(f"Error during exit cleanup: {e}") atexit.register(_cleanup_on_exit)