Files
crewAI/lib/crewai-files/tests/test_upload_cache.py
Devin AI 614354df4c fix: replace insecure pickle deserialization with JSON serialization
Fixes #4746 - Security: Insecure Pickle Deserialization enables Arbitrary Code Execution

- Replace pickle.load/dump with json.load/dump in PickleHandler (file_handler.py)
- Add backward compatibility to auto-migrate legacy .pkl files to .json
- Replace PickleSerializer with JSON-based _CachedUploadSerializer in upload_cache.py
- Replace PickleSerializer with JsonSerializer in file_store.py and agent_card.py
- Update and add comprehensive security tests for all changes

Co-Authored-By: João <joao@crewai.com>
2026-03-06 15:19:49 +00:00

315 lines
10 KiB
Python

"""Tests for upload cache."""
import json
from datetime import datetime, timedelta, timezone

from crewai_files import FileBytes, ImageFile
from crewai_files.cache.upload_cache import (
    CachedUpload,
    UploadCache,
    _CachedUploadSerializer,
)
# Minimal PNG fixture (8x8, 1-bit grayscale), spelled out one chunk per
# element; every chunk is laid out as length, type, data, CRC.
MINIMAL_PNG = b"".join(
    (
        # PNG file signature.
        b"\x89PNG\r\n\x1a\n",
        # IHDR (13 data bytes): width 8, height 8, bit depth 1,
        # colour type 0, compression 0, filter 0, interlace 0.
        b"\x00\x00\x00\rIHDR\x00\x00\x00\x08\x00\x00\x00\x08"
        b"\x01\x00\x00\x00\x00\xf9Y\xab\xcd",
        # IDAT (10 data bytes): a zlib stream (starts with x\x9c).
        b"\x00\x00\x00\nIDATx\x9cc`\x00\x00\x00\x02\x00\x01\xe2!\xbc3",
        # IEND: empty terminating chunk.
        b"\x00\x00\x00\x00IEND\xaeB`\x82",
    )
)
class TestCachedUpload:
    """Tests for CachedUpload dataclass."""

    def test_cached_upload_creation(self):
        """Test creating a cached upload stores every field as given."""
        now = datetime.now(timezone.utc)
        expires = now + timedelta(hours=48)
        cached = CachedUpload(
            file_id="file-123",
            provider="gemini",
            file_uri="files/file-123",
            content_type="image/png",
            uploaded_at=now,
            expires_at=expires,
        )
        assert cached.file_id == "file-123"
        assert cached.provider == "gemini"
        assert cached.file_uri == "files/file-123"
        assert cached.content_type == "image/png"
        # The timestamps must be kept unchanged too, not only the string fields.
        assert cached.uploaded_at == now
        assert cached.expires_at == expires

    def test_is_expired_false(self):
        """Test is_expired returns False for non-expired upload."""
        # One reference instant per test keeps the timestamps consistent.
        now = datetime.now(timezone.utc)
        cached = CachedUpload(
            file_id="file-123",
            provider="gemini",
            file_uri=None,
            content_type="image/png",
            uploaded_at=now,
            expires_at=now + timedelta(hours=24),
        )
        assert cached.is_expired() is False

    def test_is_expired_true(self):
        """Test is_expired returns True for expired upload."""
        now = datetime.now(timezone.utc)
        cached = CachedUpload(
            file_id="file-123",
            provider="gemini",
            file_uri=None,
            content_type="image/png",
            uploaded_at=now - timedelta(hours=2),
            expires_at=now - timedelta(hours=1),
        )
        assert cached.is_expired() is True

    def test_is_expired_no_expiry(self):
        """Test is_expired returns False when no expiry set."""
        cached = CachedUpload(
            file_id="file-123",
            provider="anthropic",
            file_uri=None,
            content_type="image/png",
            uploaded_at=datetime.now(timezone.utc),
            expires_at=None,
        )
        assert cached.is_expired() is False
class TestUploadCache:
    """Tests for UploadCache class."""

    @staticmethod
    def _image(data=MINIMAL_PNG, filename="test.png"):
        """Wrap raw bytes in an ImageFile for use as a cache key.

        Tests that need several distinct entries vary the bytes as well as the
        filename. Presumably the cache keys on file content; confirm against
        UploadCache before relying on filename-only differences.
        """
        return ImageFile(source=FileBytes(data=data, filename=filename))

    def test_cache_creation(self):
        """Test creating an empty cache."""
        cache = UploadCache()
        assert len(cache) == 0

    def test_set_and_get(self):
        """Test setting and getting cached uploads."""
        cache = UploadCache()
        file = self._image()
        cache.set(
            file=file,
            provider="gemini",
            file_id="file-123",
            file_uri="files/file-123",
        )
        result = cache.get(file, "gemini")
        assert result is not None
        assert result.file_id == "file-123"
        assert result.provider == "gemini"
        assert result.file_uri == "files/file-123"

    def test_get_missing(self):
        """Test getting non-existent entry returns None."""
        cache = UploadCache()
        assert cache.get(self._image(), "gemini") is None

    def test_get_different_provider(self):
        """Test getting with different provider returns None."""
        cache = UploadCache()
        file = self._image()
        cache.set(file=file, provider="gemini", file_id="file-123")
        assert cache.get(file, "anthropic") is None  # Different provider

    def test_remove(self):
        """Test removing cached entry."""
        cache = UploadCache()
        file = self._image()
        cache.set(file=file, provider="gemini", file_id="file-123")
        assert cache.remove(file, "gemini") is True
        assert cache.get(file, "gemini") is None

    def test_remove_missing(self):
        """Test removing non-existent entry returns False."""
        cache = UploadCache()
        assert cache.remove(self._image(), "gemini") is False

    def test_remove_by_file_id(self):
        """Test removing by file ID."""
        cache = UploadCache()
        file = self._image()
        cache.set(file=file, provider="gemini", file_id="file-123")
        assert cache.remove_by_file_id("file-123", "gemini") is True
        assert len(cache) == 0
        # Same post-condition as test_remove: the entry is no longer retrievable.
        assert cache.get(file, "gemini") is None

    def test_clear_expired(self):
        """Test clearing expired entries."""
        cache = UploadCache()
        file1 = self._image(filename="test1.png")
        file2 = self._image(MINIMAL_PNG + b"x", "test2.png")
        # Add one expired and one valid entry
        now = datetime.now(timezone.utc)
        cache.set(
            file=file1,
            provider="gemini",
            file_id="expired",
            expires_at=now - timedelta(hours=1),
        )
        cache.set(
            file=file2,
            provider="gemini",
            file_id="valid",
            expires_at=now + timedelta(hours=24),
        )
        assert cache.clear_expired() == 1
        assert len(cache) == 1
        # Check that the entry actually removed is the expired one.
        assert cache.get(file1, "gemini") is None
        assert cache.get(file2, "gemini") is not None

    def test_clear(self):
        """Test clearing all entries."""
        cache = UploadCache()
        file = self._image()
        cache.set(file=file, provider="gemini", file_id="file-123")
        cache.set(file=file, provider="anthropic", file_id="file-456")
        assert cache.clear() == 2
        assert len(cache) == 0

    def test_get_all_for_provider(self):
        """Test getting all cached uploads for a provider."""
        cache = UploadCache()
        file1 = self._image(filename="test1.png")
        file2 = self._image(MINIMAL_PNG + b"x", "test2.png")
        file3 = self._image(MINIMAL_PNG + b"xx", "test3.png")
        cache.set(file=file1, provider="gemini", file_id="file-1")
        cache.set(file=file2, provider="gemini", file_id="file-2")
        cache.set(file=file3, provider="anthropic", file_id="file-3")
        gemini_uploads = cache.get_all_for_provider("gemini")
        anthropic_uploads = cache.get_all_for_provider("anthropic")
        assert len(gemini_uploads) == 2
        assert len(anthropic_uploads) == 1
        # Counts alone would pass even if the wrong entries came back.
        assert {u.file_id for u in gemini_uploads} == {"file-1", "file-2"}
        assert {u.file_id for u in anthropic_uploads} == {"file-3"}
class TestCachedUploadSerializer:
    """Tests for the JSON-based CachedUpload serializer (security fix)."""

    def test_serializer_uses_json_not_pickle(self):
        """Test that the serializer produces JSON output, not pickle bytes."""
        serializer = _CachedUploadSerializer()
        now = datetime.now(timezone.utc)
        cached = CachedUpload(
            file_id="file-123",
            provider="gemini",
            file_uri="files/file-123",
            content_type="image/png",
            uploaded_at=now,
            expires_at=now + timedelta(hours=48),
        )
        dumped = serializer.dumps(cached)
        # Should be a JSON string, not pickle bytes
        assert isinstance(dumped, str)
        data = json.loads(dumped)
        assert data["file_id"] == "file-123"
        assert data["provider"] == "gemini"
        assert data["__cached_upload__"] is True

    def test_serializer_roundtrip(self):
        """Test that CachedUpload survives serialization/deserialization."""
        serializer = _CachedUploadSerializer()
        now = datetime.now(timezone.utc)
        original = CachedUpload(
            file_id="file-456",
            provider="anthropic",
            file_uri=None,
            content_type="application/pdf",
            uploaded_at=now,
            expires_at=now + timedelta(hours=24),
        )
        loaded = serializer.loads(serializer.dumps(original))
        assert isinstance(loaded, CachedUpload)
        assert loaded.file_id == original.file_id
        assert loaded.provider == original.provider
        assert loaded.file_uri == original.file_uri
        assert loaded.content_type == original.content_type
        # Datetimes are the fields JSON cannot carry natively. Comparing them
        # against tz-aware originals also catches a round trip that drops the
        # timezone, because aware and naive datetimes never compare equal.
        assert loaded.uploaded_at == original.uploaded_at
        assert loaded.expires_at == original.expires_at

    def test_serializer_handles_none_expiry(self):
        """Test serializer handles CachedUpload with no expiry."""
        serializer = _CachedUploadSerializer()
        now = datetime.now(timezone.utc)
        cached = CachedUpload(
            file_id="file-789",
            provider="gemini",
            file_uri=None,
            content_type="image/jpeg",
            uploaded_at=now,
            expires_at=None,
        )
        loaded = serializer.loads(serializer.dumps(cached))
        assert isinstance(loaded, CachedUpload)
        assert loaded.file_id == "file-789"
        assert loaded.uploaded_at == now
        assert loaded.expires_at is None

    def test_serializer_rejects_invalid_data(self):
        """Test serializer returns None for invalid/corrupted data."""
        serializer = _CachedUploadSerializer()
        assert serializer.loads(None) is None
        assert serializer.loads("not valid json {{{") is None
        # Leading-style pickle opcode bytes must not be deserialized.
        assert serializer.loads(b"binary garbage \x80\x04") is None

    def test_cache_set_get_roundtrip_uses_json_serializer(self):
        """Test that the cache properly round-trips CachedUpload through JSON."""
        cache = UploadCache()
        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
        future = datetime.now(timezone.utc) + timedelta(hours=24)
        cache.set(
            file=file,
            provider="gemini",
            file_id="file-sec-test",
            file_uri="files/file-sec-test",
            expires_at=future,
        )
        result = cache.get(file, "gemini")
        assert result is not None
        assert isinstance(result, CachedUpload)
        assert result.file_id == "file-sec-test"
        assert result.file_uri == "files/file-sec-test"
        assert result.expires_at == future