mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-05 01:02:37 +00:00
Fixes #4746 - Security: Insecure Pickle Deserialization enables Arbitrary Code Execution - Replace pickle.load/dump with json.load/dump in PickleHandler (file_handler.py) - Add backward compatibility to auto-migrate legacy .pkl files to .json - Replace PickleSerializer with JSON-based _CachedUploadSerializer in upload_cache.py - Replace PickleSerializer with JsonSerializer in file_store.py and agent_card.py - Update and add comprehensive security tests for all changes Co-Authored-By: João <joao@crewai.com>
315 lines
10 KiB
Python
315 lines
10 KiB
Python
"""Tests for upload cache."""
|
|
|
|
import json
from datetime import datetime, timedelta, timezone

from crewai_files import FileBytes, ImageFile
from crewai_files.cache.upload_cache import (
    CachedUpload,
    UploadCache,
    _CachedUploadSerializer,
)
|
|
|
|
|
|
# Minimal valid PNG used as a fixture payload by the tests below.
# The bytes are the 8-byte PNG signature followed by IHDR, IDAT and IEND
# chunks; the IHDR fields encode an 8x8 image.
MINIMAL_PNG = (
    b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x08\x00\x00\x00\x08"
    b"\x01\x00\x00\x00\x00\xf9Y\xab\xcd\x00\x00\x00\nIDATx\x9cc`\x00\x00"
    b"\x00\x02\x00\x01\xe2!\xbc3\x00\x00\x00\x00IEND\xaeB`\x82"
)
|
|
|
|
|
|
class TestCachedUpload:
    """Tests for CachedUpload dataclass.

    Covers construction with explicit field values and the three
    ``is_expired()`` cases: expiry in the future, expiry in the past,
    and no expiry at all.
    """

    def test_cached_upload_creation(self):
        """Test creating a cached upload."""
        now = datetime.now(timezone.utc)
        cached = CachedUpload(
            file_id="file-123",
            provider="gemini",
            file_uri="files/file-123",
            content_type="image/png",
            uploaded_at=now,
            expires_at=now + timedelta(hours=48),
        )

        # The values passed to the constructor must be readable back as-is.
        assert cached.file_id == "file-123"
        assert cached.provider == "gemini"
        assert cached.file_uri == "files/file-123"
        assert cached.content_type == "image/png"

    def test_is_expired_false(self):
        """Test is_expired returns False for non-expired upload."""
        future = datetime.now(timezone.utc) + timedelta(hours=24)
        cached = CachedUpload(
            file_id="file-123",
            provider="gemini",
            file_uri=None,
            content_type="image/png",
            uploaded_at=datetime.now(timezone.utc),
            expires_at=future,
        )

        assert cached.is_expired() is False

    def test_is_expired_true(self):
        """Test is_expired returns True for expired upload."""
        past = datetime.now(timezone.utc) - timedelta(hours=1)
        cached = CachedUpload(
            file_id="file-123",
            provider="gemini",
            file_uri=None,
            content_type="image/png",
            # Uploaded two hours ago, expired one hour ago.
            uploaded_at=datetime.now(timezone.utc) - timedelta(hours=2),
            expires_at=past,
        )

        assert cached.is_expired() is True

    def test_is_expired_no_expiry(self):
        """Test is_expired returns False when no expiry set."""
        cached = CachedUpload(
            file_id="file-123",
            provider="anthropic",
            file_uri=None,
            content_type="image/png",
            uploaded_at=datetime.now(timezone.utc),
            expires_at=None,
        )

        assert cached.is_expired() is False
|
|
|
|
|
|
class TestUploadCache:
    """Tests for UploadCache class.

    Each test builds a fresh ``UploadCache()`` and exercises one operation:
    set/get, remove, remove_by_file_id, clear_expired, clear, and
    get_all_for_provider.
    """

    def test_cache_creation(self):
        """Test creating an empty cache."""
        cache = UploadCache()

        assert len(cache) == 0

    def test_set_and_get(self):
        """Test setting and getting cached uploads."""
        cache = UploadCache()
        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))

        cache.set(
            file=file,
            provider="gemini",
            file_id="file-123",
            file_uri="files/file-123",
        )

        result = cache.get(file, "gemini")

        assert result is not None
        assert result.file_id == "file-123"
        assert result.provider == "gemini"

    def test_get_missing(self):
        """Test getting non-existent entry returns None."""
        cache = UploadCache()
        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))

        result = cache.get(file, "gemini")

        assert result is None

    def test_get_different_provider(self):
        """Test getting with different provider returns None."""
        cache = UploadCache()
        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))

        cache.set(file=file, provider="gemini", file_id="file-123")

        result = cache.get(file, "anthropic")  # Different provider

        assert result is None

    def test_remove(self):
        """Test removing cached entry."""
        cache = UploadCache()
        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))

        cache.set(file=file, provider="gemini", file_id="file-123")
        removed = cache.remove(file, "gemini")

        assert removed is True
        assert cache.get(file, "gemini") is None

    def test_remove_missing(self):
        """Test removing non-existent entry returns False."""
        cache = UploadCache()
        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))

        removed = cache.remove(file, "gemini")

        assert removed is False

    def test_remove_by_file_id(self):
        """Test removing by file ID."""
        cache = UploadCache()
        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))

        cache.set(file=file, provider="gemini", file_id="file-123")
        removed = cache.remove_by_file_id("file-123", "gemini")

        assert removed is True
        assert len(cache) == 0

    def test_clear_expired(self):
        """Test clearing expired entries."""
        cache = UploadCache()
        file1 = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test1.png"))
        # One extra byte makes the payload differ from file1 -- presumably so
        # the two files occupy distinct cache entries (verify against
        # UploadCache's keying).
        file2 = ImageFile(
            source=FileBytes(data=MINIMAL_PNG + b"x", filename="test2.png")
        )

        # Add one expired and one valid entry
        past = datetime.now(timezone.utc) - timedelta(hours=1)
        future = datetime.now(timezone.utc) + timedelta(hours=24)

        cache.set(file=file1, provider="gemini", file_id="expired", expires_at=past)
        cache.set(file=file2, provider="gemini", file_id="valid", expires_at=future)

        removed = cache.clear_expired()

        # Only the expired entry is dropped; the valid one stays retrievable.
        assert removed == 1
        assert len(cache) == 1
        assert cache.get(file2, "gemini") is not None

    def test_clear(self):
        """Test clearing all entries."""
        cache = UploadCache()
        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))

        # The same file cached under two providers counts as two entries.
        cache.set(file=file, provider="gemini", file_id="file-123")
        cache.set(file=file, provider="anthropic", file_id="file-456")

        cleared = cache.clear()

        assert cleared == 2
        assert len(cache) == 0

    def test_get_all_for_provider(self):
        """Test getting all cached uploads for a provider."""
        cache = UploadCache()
        # Three payloads that differ by their trailing bytes.
        file1 = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test1.png"))
        file2 = ImageFile(
            source=FileBytes(data=MINIMAL_PNG + b"x", filename="test2.png")
        )
        file3 = ImageFile(
            source=FileBytes(data=MINIMAL_PNG + b"xx", filename="test3.png")
        )

        cache.set(file=file1, provider="gemini", file_id="file-1")
        cache.set(file=file2, provider="gemini", file_id="file-2")
        cache.set(file=file3, provider="anthropic", file_id="file-3")

        gemini_uploads = cache.get_all_for_provider("gemini")
        anthropic_uploads = cache.get_all_for_provider("anthropic")

        assert len(gemini_uploads) == 2
        assert len(anthropic_uploads) == 1
|
|
|
|
|
|
class TestCachedUploadSerializer:
    """Tests for the JSON-based CachedUpload serializer (security fix).

    Checks that ``_CachedUploadSerializer`` emits JSON text, round-trips
    ``CachedUpload`` values, returns None for unparseable input, and that
    ``UploadCache`` set/get hands back a ``CachedUpload``.
    """

    def test_serializer_uses_json_not_pickle(self):
        """Test that the serializer produces JSON output, not pickle bytes."""
        serializer = _CachedUploadSerializer()
        now = datetime.now(timezone.utc)
        cached = CachedUpload(
            file_id="file-123",
            provider="gemini",
            file_uri="files/file-123",
            content_type="image/png",
            uploaded_at=now,
            expires_at=now + timedelta(hours=48),
        )

        dumped = serializer.dumps(cached)

        # Should be a JSON string, not pickle bytes
        assert isinstance(dumped, str)

        # Parsing with the stdlib JSON decoder proves the payload is real JSON.
        data = json.loads(dumped)
        assert data["file_id"] == "file-123"
        assert data["provider"] == "gemini"
        assert data["__cached_upload__"] is True

    def test_serializer_roundtrip(self):
        """Test that CachedUpload survives serialization/deserialization."""
        serializer = _CachedUploadSerializer()
        now = datetime.now(timezone.utc)
        original = CachedUpload(
            file_id="file-456",
            provider="anthropic",
            file_uri=None,
            content_type="application/pdf",
            uploaded_at=now,
            expires_at=now + timedelta(hours=24),
        )

        dumped = serializer.dumps(original)
        loaded = serializer.loads(dumped)

        assert isinstance(loaded, CachedUpload)
        assert loaded.file_id == original.file_id
        assert loaded.provider == original.provider
        assert loaded.file_uri == original.file_uri
        assert loaded.content_type == original.content_type

    def test_serializer_handles_none_expiry(self):
        """Test serializer handles CachedUpload with no expiry."""
        serializer = _CachedUploadSerializer()
        now = datetime.now(timezone.utc)
        cached = CachedUpload(
            file_id="file-789",
            provider="gemini",
            file_uri=None,
            content_type="image/jpeg",
            uploaded_at=now,
            expires_at=None,
        )

        dumped = serializer.dumps(cached)
        loaded = serializer.loads(dumped)

        assert isinstance(loaded, CachedUpload)
        assert loaded.expires_at is None

    def test_serializer_rejects_invalid_data(self):
        """Test serializer returns None for invalid/corrupted data."""
        serializer = _CachedUploadSerializer()

        assert serializer.loads(None) is None
        assert serializer.loads("not valid json {{{") is None
        # b"\x80\x04" is how a pickle protocol-4 stream begins; such bytes
        # must be rejected rather than deserialized.
        assert serializer.loads(b"binary garbage \x80\x04") is None

    def test_cache_set_get_roundtrip_uses_json_serializer(self):
        """Test that the cache properly round-trips CachedUpload through JSON."""
        cache = UploadCache()
        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))

        now = datetime.now(timezone.utc)
        future = now + timedelta(hours=24)

        cache.set(
            file=file,
            provider="gemini",
            file_id="file-sec-test",
            file_uri="files/file-sec-test",
            expires_at=future,
        )

        result = cache.get(file, "gemini")

        assert result is not None
        assert isinstance(result, CachedUpload)
        assert result.file_id == "file-sec-test"
        assert result.file_uri == "files/file-sec-test"
|