"""Tests for upload cache.""" from datetime import datetime, timedelta, timezone from crewai_files import FileBytes, ImageFile from crewai_files.cache.upload_cache import ( CachedUpload, UploadCache, _CachedUploadSerializer, ) # Minimal valid PNG MINIMAL_PNG = ( b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x08\x00\x00\x00\x08" b"\x01\x00\x00\x00\x00\xf9Y\xab\xcd\x00\x00\x00\nIDATx\x9cc`\x00\x00" b"\x00\x02\x00\x01\xe2!\xbc3\x00\x00\x00\x00IEND\xaeB`\x82" ) class TestCachedUpload: """Tests for CachedUpload dataclass.""" def test_cached_upload_creation(self): """Test creating a cached upload.""" now = datetime.now(timezone.utc) cached = CachedUpload( file_id="file-123", provider="gemini", file_uri="files/file-123", content_type="image/png", uploaded_at=now, expires_at=now + timedelta(hours=48), ) assert cached.file_id == "file-123" assert cached.provider == "gemini" assert cached.file_uri == "files/file-123" assert cached.content_type == "image/png" def test_is_expired_false(self): """Test is_expired returns False for non-expired upload.""" future = datetime.now(timezone.utc) + timedelta(hours=24) cached = CachedUpload( file_id="file-123", provider="gemini", file_uri=None, content_type="image/png", uploaded_at=datetime.now(timezone.utc), expires_at=future, ) assert cached.is_expired() is False def test_is_expired_true(self): """Test is_expired returns True for expired upload.""" past = datetime.now(timezone.utc) - timedelta(hours=1) cached = CachedUpload( file_id="file-123", provider="gemini", file_uri=None, content_type="image/png", uploaded_at=datetime.now(timezone.utc) - timedelta(hours=2), expires_at=past, ) assert cached.is_expired() is True def test_is_expired_no_expiry(self): """Test is_expired returns False when no expiry set.""" cached = CachedUpload( file_id="file-123", provider="anthropic", file_uri=None, content_type="image/png", uploaded_at=datetime.now(timezone.utc), expires_at=None, ) assert cached.is_expired() is False class TestUploadCache: """Tests for UploadCache class.""" def test_cache_creation(self): """Test creating an empty cache.""" cache = UploadCache() assert len(cache) == 0 def test_set_and_get(self): """Test setting and getting cached uploads.""" cache = UploadCache() file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png")) cache.set( file=file, provider="gemini", file_id="file-123", file_uri="files/file-123", ) result = cache.get(file, "gemini") assert result is not None assert result.file_id == "file-123" assert result.provider == "gemini" def test_get_missing(self): """Test getting non-existent entry returns None.""" cache = UploadCache() file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png")) result = cache.get(file, "gemini") assert result is None def test_get_different_provider(self): """Test getting with different provider returns None.""" cache = UploadCache() file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png")) cache.set(file=file, provider="gemini", file_id="file-123") result = cache.get(file, "anthropic") # Different provider assert result is None def test_remove(self): """Test removing cached entry.""" cache = UploadCache() file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png")) cache.set(file=file, provider="gemini", file_id="file-123") removed = cache.remove(file, "gemini") assert removed is True assert cache.get(file, "gemini") is None def test_remove_missing(self): """Test removing non-existent entry returns False.""" cache = UploadCache() file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png")) removed = cache.remove(file, "gemini") assert removed is False def test_remove_by_file_id(self): """Test removing by file ID.""" cache = UploadCache() file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png")) cache.set(file=file, provider="gemini", file_id="file-123") removed = cache.remove_by_file_id("file-123", "gemini") assert removed is True assert len(cache) == 0 def test_clear_expired(self): """Test clearing expired entries.""" cache = UploadCache() file1 = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test1.png")) file2 = ImageFile( source=FileBytes(data=MINIMAL_PNG + b"x", filename="test2.png") ) # Add one expired and one valid entry past = datetime.now(timezone.utc) - timedelta(hours=1) future = datetime.now(timezone.utc) + timedelta(hours=24) cache.set(file=file1, provider="gemini", file_id="expired", expires_at=past) cache.set(file=file2, provider="gemini", file_id="valid", expires_at=future) removed = cache.clear_expired() assert removed == 1 assert len(cache) == 1 assert cache.get(file2, "gemini") is not None def test_clear(self): """Test clearing all entries.""" cache = UploadCache() file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png")) cache.set(file=file, provider="gemini", file_id="file-123") cache.set(file=file, provider="anthropic", file_id="file-456") cleared = cache.clear() assert cleared == 2 assert len(cache) == 0 def test_get_all_for_provider(self): """Test getting all cached uploads for a provider.""" cache = UploadCache() file1 = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test1.png")) file2 = ImageFile( source=FileBytes(data=MINIMAL_PNG + b"x", filename="test2.png") ) file3 = ImageFile( source=FileBytes(data=MINIMAL_PNG + b"xx", filename="test3.png") ) cache.set(file=file1, provider="gemini", file_id="file-1") cache.set(file=file2, provider="gemini", file_id="file-2") cache.set(file=file3, provider="anthropic", file_id="file-3") gemini_uploads = cache.get_all_for_provider("gemini") anthropic_uploads = cache.get_all_for_provider("anthropic") assert len(gemini_uploads) == 2 assert len(anthropic_uploads) == 1 class TestCachedUploadSerializer: """Tests for the JSON-based CachedUpload serializer (security fix).""" def test_serializer_uses_json_not_pickle(self): """Test that the serializer produces JSON output, not pickle bytes.""" serializer = _CachedUploadSerializer() now = datetime.now(timezone.utc) cached = CachedUpload( file_id="file-123", provider="gemini", file_uri="files/file-123", content_type="image/png", uploaded_at=now, expires_at=now + timedelta(hours=48), ) dumped = serializer.dumps(cached) # Should be a JSON string, not pickle bytes assert isinstance(dumped, str) import json data = json.loads(dumped) assert data["file_id"] == "file-123" assert data["provider"] == "gemini" assert data["__cached_upload__"] is True def test_serializer_roundtrip(self): """Test that CachedUpload survives serialization/deserialization.""" serializer = _CachedUploadSerializer() now = datetime.now(timezone.utc) original = CachedUpload( file_id="file-456", provider="anthropic", file_uri=None, content_type="application/pdf", uploaded_at=now, expires_at=now + timedelta(hours=24), ) dumped = serializer.dumps(original) loaded = serializer.loads(dumped) assert isinstance(loaded, CachedUpload) assert loaded.file_id == original.file_id assert loaded.provider == original.provider assert loaded.file_uri == original.file_uri assert loaded.content_type == original.content_type def test_serializer_handles_none_expiry(self): """Test serializer handles CachedUpload with no expiry.""" serializer = _CachedUploadSerializer() now = datetime.now(timezone.utc) cached = CachedUpload( file_id="file-789", provider="gemini", file_uri=None, content_type="image/jpeg", uploaded_at=now, expires_at=None, ) dumped = serializer.dumps(cached) loaded = serializer.loads(dumped) assert isinstance(loaded, CachedUpload) assert loaded.expires_at is None def test_serializer_rejects_invalid_data(self): """Test serializer returns None for invalid/corrupted data.""" serializer = _CachedUploadSerializer() assert serializer.loads(None) is None assert serializer.loads("not valid json {{{") is None assert serializer.loads(b"binary garbage \x80\x04") is None def test_cache_set_get_roundtrip_uses_json_serializer(self): """Test that the cache properly round-trips CachedUpload through JSON.""" cache = UploadCache() file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png")) now = datetime.now(timezone.utc) future = now + timedelta(hours=24) cache.set( file=file, provider="gemini", file_id="file-sec-test", file_uri="files/file-sec-test", expires_at=future, ) result = cache.get(file, "gemini") assert result is not None assert isinstance(result, CachedUpload) assert result.file_id == "file-sec-test" assert result.file_uri == "files/file-sec-test"