diff --git a/lib/crewai/src/crewai/utilities/files/cleanup.py b/lib/crewai/src/crewai/utilities/files/cleanup.py
index b86737456..1444d1a80 100644
--- a/lib/crewai/src/crewai/utilities/files/cleanup.py
+++ b/lib/crewai/src/crewai/utilities/files/cleanup.py
@@ -177,10 +177,4 @@ def _get_providers_from_cache(cache: UploadCache) -> set[str]:
     Returns:
         Set of provider names.
     """
-    providers: set[str] = set()
-
-    with cache._lock:
-        for _, provider in cache._cache.keys():
-            providers.add(provider)
-
-    return providers
+    return cache.get_providers()
diff --git a/lib/crewai/src/crewai/utilities/files/upload_cache.py b/lib/crewai/src/crewai/utilities/files/upload_cache.py
index baa21a23e..ea83d2bf4 100644
--- a/lib/crewai/src/crewai/utilities/files/upload_cache.py
+++ b/lib/crewai/src/crewai/utilities/files/upload_cache.py
@@ -1,24 +1,36 @@
-"""Cache for tracking uploaded files to avoid redundant uploads."""
+"""Cache for tracking uploaded files using aiocache."""
 
-from dataclasses import dataclass, field
-from datetime import datetime, timedelta, timezone
+from __future__ import annotations
+
+import asyncio
+import atexit
+import builtins
+from dataclasses import dataclass
+from datetime import datetime, timezone
 import hashlib
 import logging
-import threading
+from typing import TYPE_CHECKING, Any
 
-from crewai.utilities.files.content_types import (
-    AudioFile,
-    File,
-    ImageFile,
-    PDFFile,
-    TextFile,
-    VideoFile,
-)
+from aiocache import Cache  # type: ignore[import-untyped]
+from aiocache.serializers import PickleSerializer  # type: ignore[import-untyped]
+
+
+if TYPE_CHECKING:
+    from crewai.utilities.files.content_types import (
+        AudioFile,
+        File,
+        ImageFile,
+        PDFFile,
+        TextFile,
+        VideoFile,
+    )
+
+    FileInput = AudioFile | File | ImageFile | PDFFile | TextFile | VideoFile
 
 logger = logging.getLogger(__name__)
 
-FileInput = AudioFile | File | ImageFile | PDFFile | TextFile | VideoFile
+DEFAULT_TTL_SECONDS = 24 * 60 * 60  # 24 hours
 
 
 @dataclass
@@ -42,44 +54,83 @@ class CachedUpload:
     expires_at: datetime | None = None
 
     def is_expired(self) -> bool:
-        """Check if this cached upload has expired.
-
-        Returns:
-            True if expired, False otherwise.
-        """
+        """Check if this cached upload has expired."""
         if self.expires_at is None:
             return False
         return datetime.now(timezone.utc) >= self.expires_at
 
 
-@dataclass
-class UploadCache:
-    """Thread-safe cache for tracking uploaded files.
+def _make_key(file_hash: str, provider: str) -> str:
+    """Create a cache key from file hash and provider."""
+    return f"upload:{provider}:{file_hash}"
 
-    Uses file content hash and provider as composite key to avoid
-    uploading the same file multiple times.
+
+def _compute_file_hash(file: FileInput) -> str:
+    """Compute SHA-256 hash of file content."""
+    content = file.source.read()
+    return hashlib.sha256(content).hexdigest()
+
+
+class UploadCache:
+    """Async cache for tracking uploaded files using aiocache.
+
+    Supports in-memory caching by default, with optional Redis backend
+    for distributed setups.
 
     Attributes:
-        default_ttl: Default time-to-live for cached entries.
+        ttl: Default time-to-live in seconds for cached entries.
+        namespace: Cache namespace for isolation.
     """
 
-    default_ttl: timedelta = field(default_factory=lambda: timedelta(hours=24))
-    _cache: dict[tuple[str, str], CachedUpload] = field(default_factory=dict)
-    _lock: threading.Lock = field(default_factory=threading.Lock)
-
-    def _compute_hash(self, file: FileInput) -> str:
-        """Compute a hash of file content for cache key.
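+    # _provider_keys (set up in __init__ below) is a per-instance,
+    # in-process index of cache keys by provider, used by the cleanup
+    # helpers. With the Redis backend it is not shared across processes,
+    # so entries written by other processes are invisible to
+    # aremove_by_file_id / aclear_expired / aget_all_for_provider here.
+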
+ def __init__( + self, + ttl: int = DEFAULT_TTL_SECONDS, + namespace: str = "crewai_uploads", + cache_type: str = "memory", + **cache_kwargs: Any, + ) -> None: + """Initialize the upload cache. Args: - file: The file to hash. - - Returns: - SHA-256 hash of the file content. + ttl: Default TTL in seconds. + namespace: Cache namespace. + cache_type: Backend type ("memory" or "redis"). + **cache_kwargs: Additional args for cache backend. """ - content = file.source.read() - return hashlib.sha256(content).hexdigest() + self.ttl = ttl + self.namespace = namespace + self._provider_keys: dict[str, set[str]] = {} - def get(self, file: FileInput, provider: str) -> CachedUpload | None: + if cache_type == "redis": + self._cache = Cache( + Cache.REDIS, + serializer=PickleSerializer(), + namespace=namespace, + **cache_kwargs, + ) + else: + self._cache = Cache( + Cache.MEMORY, + serializer=PickleSerializer(), + namespace=namespace, + ) + + def _track_key(self, provider: str, key: str) -> None: + """Track a key for a provider (for cleanup).""" + if provider not in self._provider_keys: + self._provider_keys[provider] = set() + self._provider_keys[provider].add(key) + + def _untrack_key(self, provider: str, key: str) -> None: + """Remove key tracking for a provider.""" + if provider in self._provider_keys: + self._provider_keys[provider].discard(key) + + # ------------------------------------------------------------------------- + # Async methods (primary interface) + # ------------------------------------------------------------------------- + + async def aget(self, file: FileInput, provider: str) -> CachedUpload | None: """Get a cached upload for a file. Args: @@ -89,21 +140,10 @@ class UploadCache: Returns: Cached upload if found and not expired, None otherwise. """ - file_hash = self._compute_hash(file) - key = (file_hash, provider) + file_hash = _compute_file_hash(file) + return await self.aget_by_hash(file_hash, provider) - with self._lock: - cached = self._cache.get(key) - if cached is None: - return None - - if cached.is_expired(): - del self._cache[key] - return None - - return cached - - def get_by_hash(self, file_hash: str, provider: str) -> CachedUpload | None: + async def aget_by_hash(self, file_hash: str, provider: str) -> CachedUpload | None: """Get a cached upload by file hash. Args: @@ -113,20 +153,20 @@ class UploadCache: Returns: Cached upload if found and not expired, None otherwise. """ - key = (file_hash, provider) + key = _make_key(file_hash, provider) + result = await self._cache.get(key) - with self._lock: - cached = self._cache.get(key) - if cached is None: + if result is None: + return None + if isinstance(result, CachedUpload): + if result.is_expired(): + await self._cache.delete(key) + self._untrack_key(provider, key) return None + return result + return None - if cached.is_expired(): - del self._cache[key] - return None - - return cached - - def set( + async def aset( self, file: FileInput, provider: str, @@ -146,26 +186,17 @@ class UploadCache: Returns: The created cache entry. 
""" - file_hash = self._compute_hash(file) - key = (file_hash, provider) - - now = datetime.now(timezone.utc) - cached = CachedUpload( - file_id=file_id, - provider=provider, - file_uri=file_uri, + file_hash = _compute_file_hash(file) + return await self.aset_by_hash( + file_hash=file_hash, content_type=file.content_type, - uploaded_at=now, + provider=provider, + file_id=file_id, + file_uri=file_uri, expires_at=expires_at, ) - with self._lock: - self._cache[key] = cached - - logger.debug(f"Cached upload: {file_id} for provider {provider}") - return cached - - def set_by_hash( + async def aset_by_hash( self, file_hash: str, content_type: str, @@ -187,9 +218,9 @@ class UploadCache: Returns: The created cache entry. """ - key = (file_hash, provider) - + key = _make_key(file_hash, provider) now = datetime.now(timezone.utc) + cached = CachedUpload( file_id=file_id, provider=provider, @@ -199,13 +230,16 @@ class UploadCache: expires_at=expires_at, ) - with self._lock: - self._cache[key] = cached + ttl = self.ttl + if expires_at is not None: + ttl = max(0, int((expires_at - now).total_seconds())) + await self._cache.set(key, cached, ttl=ttl) + self._track_key(provider, key) logger.debug(f"Cached upload: {file_id} for provider {provider}") return cached - def remove(self, file: FileInput, provider: str) -> bool: + async def aremove(self, file: FileInput, provider: str) -> bool: """Remove a cached upload. Args: @@ -215,16 +249,16 @@ class UploadCache: Returns: True if entry was removed, False if not found. """ - file_hash = self._compute_hash(file) - key = (file_hash, provider) + file_hash = _compute_file_hash(file) + key = _make_key(file_hash, provider) - with self._lock: - if key in self._cache: - del self._cache[key] - return True - return False + result = await self._cache.delete(key) + removed = bool(result > 0 if isinstance(result, int) else result) + if removed: + self._untrack_key(provider, key) + return removed - def remove_by_file_id(self, file_id: str, provider: str) -> bool: + async def aremove_by_file_id(self, file_id: str, provider: str) -> bool: """Remove a cached upload by file ID. Args: @@ -234,14 +268,18 @@ class UploadCache: Returns: True if entry was removed, False if not found. """ - with self._lock: - for key, cached in list(self._cache.items()): - if cached.file_id == file_id and cached.provider == provider: - del self._cache[key] - return True + if provider not in self._provider_keys: return False - def clear_expired(self) -> int: + for key in list(self._provider_keys[provider]): + cached = await self._cache.get(key) + if isinstance(cached, CachedUpload) and cached.file_id == file_id: + await self._cache.delete(key) + self._untrack_key(provider, key) + return True + return False + + async def aclear_expired(self) -> int: """Remove all expired entries from the cache. Returns: @@ -249,33 +287,35 @@ class UploadCache: """ removed = 0 - with self._lock: - for key in list(self._cache.keys()): - if self._cache[key].is_expired(): - del self._cache[key] + for provider, keys in list(self._provider_keys.items()): + for key in list(keys): + cached = await self._cache.get(key) + if cached is None or ( + isinstance(cached, CachedUpload) and cached.is_expired() + ): + await self._cache.delete(key) + self._untrack_key(provider, key) removed += 1 if removed > 0: logger.debug(f"Cleared {removed} expired cache entries") - return removed - def clear(self) -> int: + async def aclear(self) -> int: """Clear all entries from the cache. Returns: Number of entries cleared. 
""" - with self._lock: - count = len(self._cache) - self._cache.clear() + count = sum(len(keys) for keys in self._provider_keys.values()) + await self._cache.clear(namespace=self.namespace) + self._provider_keys.clear() if count > 0: logger.debug(f"Cleared {count} cache entries") - return count - def get_all_for_provider(self, provider: str) -> list[CachedUpload]: + async def aget_all_for_provider(self, provider: str) -> list[CachedUpload]: """Get all cached uploads for a provider. Args: @@ -284,14 +324,171 @@ class UploadCache: Returns: List of cached uploads for the provider. """ - with self._lock: - return [ - cached - for (_, p), cached in self._cache.items() - if p == provider and not cached.is_expired() - ] + if provider not in self._provider_keys: + return [] + + results: list[CachedUpload] = [] + for key in list(self._provider_keys[provider]): + cached = await self._cache.get(key) + if isinstance(cached, CachedUpload) and not cached.is_expired(): + results.append(cached) + return results + + # ------------------------------------------------------------------------- + # Sync wrappers (convenience) + # ------------------------------------------------------------------------- + + def _run_sync(self, coro: Any) -> Any: + """Run an async coroutine from sync context.""" + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop is not None and loop.is_running(): + import concurrent.futures + + with concurrent.futures.ThreadPoolExecutor() as pool: + future = pool.submit(asyncio.run, coro) + return future.result() + return asyncio.run(coro) + + def get(self, file: FileInput, provider: str) -> CachedUpload | None: + """Sync wrapper for aget.""" + result: CachedUpload | None = self._run_sync(self.aget(file, provider)) + return result + + def get_by_hash(self, file_hash: str, provider: str) -> CachedUpload | None: + """Sync wrapper for aget_by_hash.""" + result: CachedUpload | None = self._run_sync( + self.aget_by_hash(file_hash, provider) + ) + return result + + def set( + self, + file: FileInput, + provider: str, + file_id: str, + file_uri: str | None = None, + expires_at: datetime | None = None, + ) -> CachedUpload: + """Sync wrapper for aset.""" + result: CachedUpload = self._run_sync( + self.aset(file, provider, file_id, file_uri, expires_at) + ) + return result + + def set_by_hash( + self, + file_hash: str, + content_type: str, + provider: str, + file_id: str, + file_uri: str | None = None, + expires_at: datetime | None = None, + ) -> CachedUpload: + """Sync wrapper for aset_by_hash.""" + result: CachedUpload = self._run_sync( + self.aset_by_hash( + file_hash, content_type, provider, file_id, file_uri, expires_at + ) + ) + return result + + def remove(self, file: FileInput, provider: str) -> bool: + """Sync wrapper for aremove.""" + result: bool = self._run_sync(self.aremove(file, provider)) + return result + + def remove_by_file_id(self, file_id: str, provider: str) -> bool: + """Sync wrapper for aremove_by_file_id.""" + result: bool = self._run_sync(self.aremove_by_file_id(file_id, provider)) + return result + + def clear_expired(self) -> int: + """Sync wrapper for aclear_expired.""" + result: int = self._run_sync(self.aclear_expired()) + return result + + def clear(self) -> int: + """Sync wrapper for aclear.""" + result: int = self._run_sync(self.aclear()) + return result + + def get_all_for_provider(self, provider: str) -> list[CachedUpload]: + """Sync wrapper for aget_all_for_provider.""" + result: list[CachedUpload] = self._run_sync( + 
self.aget_all_for_provider(provider) + ) + return result def __len__(self) -> int: """Return the number of cached entries.""" - with self._lock: - return len(self._cache) + return sum(len(keys) for keys in self._provider_keys.values()) + + def get_providers(self) -> builtins.set[str]: + """Get all provider names that have cached entries. + + Returns: + Set of provider names. + """ + return builtins.set(self._provider_keys.keys()) + + +# Module-level cache instance +_default_cache: UploadCache | None = None + + +def get_upload_cache( + ttl: int = DEFAULT_TTL_SECONDS, + namespace: str = "crewai_uploads", + cache_type: str = "memory", + **cache_kwargs: Any, +) -> UploadCache: + """Get or create the default upload cache. + + Args: + ttl: Default TTL in seconds. + namespace: Cache namespace. + cache_type: Backend type ("memory" or "redis"). + **cache_kwargs: Additional args for cache backend. + + Returns: + The upload cache instance. + """ + global _default_cache + if _default_cache is None: + _default_cache = UploadCache( + ttl=ttl, + namespace=namespace, + cache_type=cache_type, + **cache_kwargs, + ) + return _default_cache + + +def reset_upload_cache() -> None: + """Reset the default upload cache (useful for testing).""" + global _default_cache + if _default_cache is not None: + _default_cache.clear() + _default_cache = None + + +def _cleanup_on_exit() -> None: + """Clean up uploaded files on process exit.""" + global _default_cache + if _default_cache is None or len(_default_cache) == 0: + return + + # Import here to avoid circular imports + from crewai.utilities.files.cleanup import cleanup_uploaded_files + + try: + cleanup_uploaded_files(_default_cache, delete_from_provider=True) + except Exception as e: + logger.debug(f"Error during exit cleanup: {e}") + + +atexit.register(_cleanup_on_exit) diff --git a/lib/crewai/tests/utilities/files/test_upload_cache.py b/lib/crewai/tests/utilities/files/test_upload_cache.py new file mode 100644 index 000000000..e3b8ebe72 --- /dev/null +++ b/lib/crewai/tests/utilities/files/test_upload_cache.py @@ -0,0 +1,206 @@ +"""Tests for upload cache.""" + +from datetime import datetime, timedelta, timezone + +import pytest + +from crewai.utilities.files import FileBytes, ImageFile +from crewai.utilities.files.upload_cache import CachedUpload, UploadCache + + +# Minimal valid PNG +MINIMAL_PNG = ( + b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x08\x00\x00\x00\x08" + b"\x01\x00\x00\x00\x00\xf9Y\xab\xcd\x00\x00\x00\nIDATx\x9cc`\x00\x00" + b"\x00\x02\x00\x01\xe2!\xbc3\x00\x00\x00\x00IEND\xaeB`\x82" +) + + +class TestCachedUpload: + """Tests for CachedUpload dataclass.""" + + def test_cached_upload_creation(self): + """Test creating a cached upload.""" + now = datetime.now(timezone.utc) + cached = CachedUpload( + file_id="file-123", + provider="gemini", + file_uri="files/file-123", + content_type="image/png", + uploaded_at=now, + expires_at=now + timedelta(hours=48), + ) + + assert cached.file_id == "file-123" + assert cached.provider == "gemini" + assert cached.file_uri == "files/file-123" + assert cached.content_type == "image/png" + + def test_is_expired_false(self): + """Test is_expired returns False for non-expired upload.""" + future = datetime.now(timezone.utc) + timedelta(hours=24) + cached = CachedUpload( + file_id="file-123", + provider="gemini", + file_uri=None, + content_type="image/png", + uploaded_at=datetime.now(timezone.utc), + expires_at=future, + ) + + assert cached.is_expired() is False + + def test_is_expired_true(self): + """Test is_expired 
returns True for expired upload.""" + past = datetime.now(timezone.utc) - timedelta(hours=1) + cached = CachedUpload( + file_id="file-123", + provider="gemini", + file_uri=None, + content_type="image/png", + uploaded_at=datetime.now(timezone.utc) - timedelta(hours=2), + expires_at=past, + ) + + assert cached.is_expired() is True + + def test_is_expired_no_expiry(self): + """Test is_expired returns False when no expiry set.""" + cached = CachedUpload( + file_id="file-123", + provider="anthropic", + file_uri=None, + content_type="image/png", + uploaded_at=datetime.now(timezone.utc), + expires_at=None, + ) + + assert cached.is_expired() is False + + +class TestUploadCache: + """Tests for UploadCache class.""" + + def test_cache_creation(self): + """Test creating an empty cache.""" + cache = UploadCache() + + assert len(cache) == 0 + + def test_set_and_get(self): + """Test setting and getting cached uploads.""" + cache = UploadCache() + file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png")) + + cached = cache.set( + file=file, + provider="gemini", + file_id="file-123", + file_uri="files/file-123", + ) + + result = cache.get(file, "gemini") + + assert result is not None + assert result.file_id == "file-123" + assert result.provider == "gemini" + + def test_get_missing(self): + """Test getting non-existent entry returns None.""" + cache = UploadCache() + file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png")) + + result = cache.get(file, "gemini") + + assert result is None + + def test_get_different_provider(self): + """Test getting with different provider returns None.""" + cache = UploadCache() + file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png")) + + cache.set(file=file, provider="gemini", file_id="file-123") + + result = cache.get(file, "anthropic") # Different provider + + assert result is None + + def test_remove(self): + """Test removing cached entry.""" + cache = UploadCache() + file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png")) + + cache.set(file=file, provider="gemini", file_id="file-123") + removed = cache.remove(file, "gemini") + + assert removed is True + assert cache.get(file, "gemini") is None + + def test_remove_missing(self): + """Test removing non-existent entry returns False.""" + cache = UploadCache() + file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png")) + + removed = cache.remove(file, "gemini") + + assert removed is False + + def test_remove_by_file_id(self): + """Test removing by file ID.""" + cache = UploadCache() + file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png")) + + cache.set(file=file, provider="gemini", file_id="file-123") + removed = cache.remove_by_file_id("file-123", "gemini") + + assert removed is True + assert len(cache) == 0 + + def test_clear_expired(self): + """Test clearing expired entries.""" + cache = UploadCache() + file1 = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test1.png")) + file2 = ImageFile(source=FileBytes(data=MINIMAL_PNG + b"x", filename="test2.png")) + + # Add one expired and one valid entry + past = datetime.now(timezone.utc) - timedelta(hours=1) + future = datetime.now(timezone.utc) + timedelta(hours=24) + + cache.set(file=file1, provider="gemini", file_id="expired", expires_at=past) + cache.set(file=file2, provider="gemini", file_id="valid", expires_at=future) + + removed = cache.clear_expired() + + assert removed == 1 + assert len(cache) == 1 + assert cache.get(file2, "gemini") is not None + + def 
test_clear(self): + """Test clearing all entries.""" + cache = UploadCache() + file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png")) + + cache.set(file=file, provider="gemini", file_id="file-123") + cache.set(file=file, provider="anthropic", file_id="file-456") + + cleared = cache.clear() + + assert cleared == 2 + assert len(cache) == 0 + + def test_get_all_for_provider(self): + """Test getting all cached uploads for a provider.""" + cache = UploadCache() + file1 = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test1.png")) + file2 = ImageFile(source=FileBytes(data=MINIMAL_PNG + b"x", filename="test2.png")) + file3 = ImageFile(source=FileBytes(data=MINIMAL_PNG + b"xx", filename="test3.png")) + + cache.set(file=file1, provider="gemini", file_id="file-1") + cache.set(file=file2, provider="gemini", file_id="file-2") + cache.set(file=file3, provider="anthropic", file_id="file-3") + + gemini_uploads = cache.get_all_for_provider("gemini") + anthropic_uploads = cache.get_all_for_provider("anthropic") + + assert len(gemini_uploads) == 2 + assert len(anthropic_uploads) == 1
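
Usage sketch (a minimal example, not part of the patch): the new async interface, assuming aiocache is installed. The provider name and file id are placeholders, and the PNG bytes are copied from the test file above.

    import asyncio

    from crewai.utilities.files import FileBytes, ImageFile
    from crewai.utilities.files.upload_cache import UploadCache

    # Minimal valid PNG (copied from the tests above).
    MINIMAL_PNG = (
        b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x08\x00\x00\x00\x08"
        b"\x01\x00\x00\x00\x00\xf9Y\xab\xcd\x00\x00\x00\nIDATx\x9cc`\x00\x00"
        b"\x00\x02\x00\x01\xe2!\xbc3\x00\x00\x00\x00IEND\xaeB`\x82"
    )

    async def main() -> None:
        cache = UploadCache(ttl=3600)  # in-memory backend, 1-hour default TTL
        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="demo.png"))

        # First sight of this content: cache miss, so the caller uploads the
        # file and records the provider-assigned id.
        if await cache.aget(file, "gemini") is None:
            await cache.aset(file=file, provider="gemini", file_id="file-123")

        # Later lookups are keyed by content hash, so they hit the cache and
        # skip the redundant upload.
        cached = await cache.aget(file, "gemini")
        assert cached is not None and cached.file_id == "file-123"

    asyncio.run(main())

Non-async callers can use the sync wrappers (cache.get, cache.set, and so on), which run the same coroutines via asyncio.run. For a distributed setup, UploadCache(cache_type="redis", endpoint="localhost", port=6379) is the intended shape, with the extra keyword arguments passed through to aiocache's Redis backend; the exact endpoint/port parameter names are an assumption based on aiocache's API.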