mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-12 05:52:39 +00:00
Compare commits
2 Commits
main
...
devin/1772
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c1ae3da1cd | ||
|
|
614354df4c |
@@ -13,7 +13,7 @@ import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from aiocache import Cache # type: ignore[import-untyped]
|
||||
from aiocache.serializers import PickleSerializer # type: ignore[import-untyped]
|
||||
from aiocache.serializers import JsonSerializer # type: ignore[import-untyped]
|
||||
|
||||
from crewai_files.core.constants import DEFAULT_MAX_CACHE_ENTRIES, DEFAULT_TTL_SECONDS
|
||||
from crewai_files.uploaders.factory import ProviderType
|
||||
@@ -25,6 +25,62 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _CachedUploadSerializer(JsonSerializer): # type: ignore[misc]
|
||||
"""JSON-based serializer that safely handles CachedUpload dataclass.
|
||||
|
||||
Uses JSON instead of pickle to prevent insecure deserialization
|
||||
vulnerabilities (CWE-502).
|
||||
"""
|
||||
|
||||
def dumps(self, value: Any) -> str: # type: ignore[override]
|
||||
"""Serialize value to JSON string, converting CachedUpload to dict."""
|
||||
import json
|
||||
|
||||
if isinstance(value, CachedUpload):
|
||||
data = {
|
||||
"__cached_upload__": True,
|
||||
"file_id": value.file_id,
|
||||
"provider": value.provider,
|
||||
"file_uri": value.file_uri,
|
||||
"content_type": value.content_type,
|
||||
"uploaded_at": value.uploaded_at.isoformat(),
|
||||
"expires_at": value.expires_at.isoformat() if value.expires_at else None,
|
||||
}
|
||||
return json.dumps(data)
|
||||
return json.dumps(value)
|
||||
|
||||
def loads(self, value: str | bytes | None) -> Any: # type: ignore[override]
|
||||
"""Deserialize JSON string, reconstructing CachedUpload if applicable."""
|
||||
import json
|
||||
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, bytes):
|
||||
try:
|
||||
value = value.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
return None
|
||||
try:
|
||||
data = json.loads(value)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return None
|
||||
|
||||
if isinstance(data, dict) and data.get("__cached_upload__"):
|
||||
return CachedUpload(
|
||||
file_id=data["file_id"],
|
||||
provider=data["provider"],
|
||||
file_uri=data.get("file_uri"),
|
||||
content_type=data["content_type"],
|
||||
uploaded_at=datetime.fromisoformat(data["uploaded_at"]),
|
||||
expires_at=(
|
||||
datetime.fromisoformat(data["expires_at"])
|
||||
if data.get("expires_at")
|
||||
else None
|
||||
),
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
@dataclass
|
||||
class CachedUpload:
|
||||
"""Represents a cached file upload.
|
||||
@@ -123,13 +179,13 @@ class UploadCache:
|
||||
if cache_type == "redis":
|
||||
self._cache = Cache(
|
||||
Cache.REDIS,
|
||||
serializer=PickleSerializer(),
|
||||
serializer=_CachedUploadSerializer(),
|
||||
namespace=namespace,
|
||||
**cache_kwargs,
|
||||
)
|
||||
else:
|
||||
self._cache = Cache(
|
||||
serializer=PickleSerializer(),
|
||||
serializer=_CachedUploadSerializer(),
|
||||
namespace=namespace,
|
||||
)
|
||||
|
||||
|
||||
@@ -3,7 +3,11 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from crewai_files import FileBytes, ImageFile
|
||||
from crewai_files.cache.upload_cache import CachedUpload, UploadCache
|
||||
from crewai_files.cache.upload_cache import (
|
||||
CachedUpload,
|
||||
UploadCache,
|
||||
_CachedUploadSerializer,
|
||||
)
|
||||
|
||||
|
||||
# Minimal valid PNG
|
||||
@@ -208,3 +212,103 @@ class TestUploadCache:
|
||||
|
||||
assert len(gemini_uploads) == 2
|
||||
assert len(anthropic_uploads) == 1
|
||||
|
||||
|
||||
class TestCachedUploadSerializer:
|
||||
"""Tests for the JSON-based CachedUpload serializer (security fix)."""
|
||||
|
||||
def test_serializer_uses_json_not_pickle(self):
|
||||
"""Test that the serializer produces JSON output, not pickle bytes."""
|
||||
serializer = _CachedUploadSerializer()
|
||||
now = datetime.now(timezone.utc)
|
||||
cached = CachedUpload(
|
||||
file_id="file-123",
|
||||
provider="gemini",
|
||||
file_uri="files/file-123",
|
||||
content_type="image/png",
|
||||
uploaded_at=now,
|
||||
expires_at=now + timedelta(hours=48),
|
||||
)
|
||||
|
||||
dumped = serializer.dumps(cached)
|
||||
|
||||
# Should be a JSON string, not pickle bytes
|
||||
assert isinstance(dumped, str)
|
||||
import json
|
||||
|
||||
data = json.loads(dumped)
|
||||
assert data["file_id"] == "file-123"
|
||||
assert data["provider"] == "gemini"
|
||||
assert data["__cached_upload__"] is True
|
||||
|
||||
def test_serializer_roundtrip(self):
|
||||
"""Test that CachedUpload survives serialization/deserialization."""
|
||||
serializer = _CachedUploadSerializer()
|
||||
now = datetime.now(timezone.utc)
|
||||
original = CachedUpload(
|
||||
file_id="file-456",
|
||||
provider="anthropic",
|
||||
file_uri=None,
|
||||
content_type="application/pdf",
|
||||
uploaded_at=now,
|
||||
expires_at=now + timedelta(hours=24),
|
||||
)
|
||||
|
||||
dumped = serializer.dumps(original)
|
||||
loaded = serializer.loads(dumped)
|
||||
|
||||
assert isinstance(loaded, CachedUpload)
|
||||
assert loaded.file_id == original.file_id
|
||||
assert loaded.provider == original.provider
|
||||
assert loaded.file_uri == original.file_uri
|
||||
assert loaded.content_type == original.content_type
|
||||
|
||||
def test_serializer_handles_none_expiry(self):
|
||||
"""Test serializer handles CachedUpload with no expiry."""
|
||||
serializer = _CachedUploadSerializer()
|
||||
now = datetime.now(timezone.utc)
|
||||
cached = CachedUpload(
|
||||
file_id="file-789",
|
||||
provider="gemini",
|
||||
file_uri=None,
|
||||
content_type="image/jpeg",
|
||||
uploaded_at=now,
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
dumped = serializer.dumps(cached)
|
||||
loaded = serializer.loads(dumped)
|
||||
|
||||
assert isinstance(loaded, CachedUpload)
|
||||
assert loaded.expires_at is None
|
||||
|
||||
def test_serializer_rejects_invalid_data(self):
|
||||
"""Test serializer returns None for invalid/corrupted data."""
|
||||
serializer = _CachedUploadSerializer()
|
||||
|
||||
assert serializer.loads(None) is None
|
||||
assert serializer.loads("not valid json {{{") is None
|
||||
assert serializer.loads(b"binary garbage \x80\x04") is None
|
||||
|
||||
def test_cache_set_get_roundtrip_uses_json_serializer(self):
|
||||
"""Test that the cache properly round-trips CachedUpload through JSON."""
|
||||
cache = UploadCache()
|
||||
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
future = now + timedelta(hours=24)
|
||||
|
||||
cache.set(
|
||||
file=file,
|
||||
provider="gemini",
|
||||
file_id="file-sec-test",
|
||||
file_uri="files/file-sec-test",
|
||||
expires_at=future,
|
||||
)
|
||||
|
||||
result = cache.get(file, "gemini")
|
||||
|
||||
assert result is not None
|
||||
assert isinstance(result, CachedUpload)
|
||||
assert result.file_id == "file-sec-test"
|
||||
assert result.file_uri == "files/file-sec-test"
|
||||
|
||||
@@ -220,6 +220,9 @@ def _fetch_agent_card_cached(
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
# PickleSerializer is safe here: this is an in-memory cache only.
|
||||
# Data never leaves the process, so there is no untrusted deserialization risk.
|
||||
# JsonSerializer would break AgentCard (Pydantic model) serialization.
|
||||
@cached(ttl=300, serializer=PickleSerializer()) # type: ignore[untyped-decorator]
|
||||
async def _afetch_agent_card_cached(
|
||||
endpoint: str,
|
||||
|
||||
@@ -8,8 +8,8 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
import contextvars
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import contextvars
|
||||
import inspect
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
@@ -1599,16 +1599,19 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
# Initialize or retrieve agent's training data
|
||||
agent_training_data = training_data.get(agent_id, {})
|
||||
|
||||
# Use string key for JSON compatibility (JSON converts int keys to strings)
|
||||
train_key = str(train_iteration)
|
||||
|
||||
if human_feedback is not None:
|
||||
# Save initial output and human feedback
|
||||
agent_training_data[train_iteration] = {
|
||||
agent_training_data[train_key] = {
|
||||
"initial_output": result.output,
|
||||
"human_feedback": human_feedback,
|
||||
}
|
||||
else:
|
||||
# Save improved output
|
||||
if train_iteration in agent_training_data:
|
||||
agent_training_data[train_iteration]["improved_output"] = result.output
|
||||
if train_key in agent_training_data:
|
||||
agent_training_data[train_key]["improved_output"] = result.output
|
||||
else:
|
||||
if self.agent.verbose:
|
||||
self._printer.print(
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
from collections.abc import Callable, Coroutine
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import contextvars
|
||||
from datetime import datetime
|
||||
import inspect
|
||||
import json
|
||||
@@ -1492,16 +1492,19 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
|
||||
# Initialize or retrieve agent's training data
|
||||
agent_training_data = training_data.get(agent_id, {})
|
||||
|
||||
# Use string key for JSON compatibility (JSON converts int keys to strings)
|
||||
train_key = str(train_iteration)
|
||||
|
||||
if human_feedback is not None:
|
||||
# Save initial output and human feedback
|
||||
agent_training_data[train_iteration] = {
|
||||
agent_training_data[train_key] = {
|
||||
"initial_output": result.output,
|
||||
"human_feedback": human_feedback,
|
||||
}
|
||||
else:
|
||||
# Save improved output
|
||||
if train_iteration in agent_training_data:
|
||||
agent_training_data[train_iteration]["improved_output"] = result.output
|
||||
if train_key in agent_training_data:
|
||||
agent_training_data[train_key]["improved_output"] = result.output
|
||||
else:
|
||||
train_error = Text()
|
||||
train_error.append("❌ ", style="red bold")
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
from datetime import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from typing_extensions import Unpack
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LogEntry(TypedDict, total=False):
|
||||
"""TypedDict for log entry kwargs with optional fields for flexibility."""
|
||||
|
||||
@@ -123,52 +126,96 @@ class FileHandler:
|
||||
|
||||
|
||||
class PickleHandler:
|
||||
"""Handler for saving and loading data using pickle.
|
||||
"""Handler for saving and loading data using JSON serialization.
|
||||
|
||||
Note: Despite the class name (kept for backward compatibility), this handler
|
||||
uses JSON serialization instead of pickle to prevent insecure deserialization
|
||||
vulnerabilities (CWE-502).
|
||||
|
||||
Attributes:
|
||||
file_path: The path to the pickle file.
|
||||
file_path: The path to the JSON data file.
|
||||
"""
|
||||
|
||||
def __init__(self, file_name: str) -> None:
|
||||
"""Initialize the PickleHandler with the name of the file where data will be stored.
|
||||
|
||||
The file will be saved in the current directory.
|
||||
The file will be saved in the current directory. Files use JSON format
|
||||
for safe serialization. Legacy .pkl files are automatically migrated.
|
||||
|
||||
Args:
|
||||
file_name: The name of the file for saving and loading data.
|
||||
"""
|
||||
if not file_name.endswith(".pkl"):
|
||||
file_name += ".pkl"
|
||||
# Strip old .pkl extension if present and use .json
|
||||
if file_name.endswith(".pkl"):
|
||||
file_name = file_name[:-4]
|
||||
if not file_name.endswith(".json"):
|
||||
file_name += ".json"
|
||||
|
||||
self.file_path = os.path.join(os.getcwd(), file_name)
|
||||
|
||||
# Derive legacy .pkl path for migration
|
||||
self._legacy_pkl_path = self.file_path.rsplit(".json", 1)[0] + ".pkl"
|
||||
|
||||
def _migrate_legacy_pkl(self) -> dict[str, Any] | None:
|
||||
"""Attempt to migrate data from a legacy .pkl file to JSON format.
|
||||
|
||||
Returns:
|
||||
The migrated data if successful, None otherwise.
|
||||
"""
|
||||
if not os.path.exists(self._legacy_pkl_path):
|
||||
return None
|
||||
|
||||
try:
|
||||
import pickle
|
||||
|
||||
with open(self._legacy_pkl_path, "rb") as f:
|
||||
data = pickle.load(f) # noqa: S301
|
||||
|
||||
# Save as JSON
|
||||
self.save(data)
|
||||
|
||||
# Remove the old pkl file after successful migration
|
||||
os.remove(self._legacy_pkl_path)
|
||||
logger.info(
|
||||
f"Migrated legacy pickle file to JSON: {self._legacy_pkl_path} -> {self.file_path}"
|
||||
)
|
||||
return data # type: ignore[no-any-return]
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to migrate legacy pickle file {self._legacy_pkl_path}: {e}")
|
||||
return None
|
||||
|
||||
def initialize_file(self) -> None:
|
||||
"""Initialize the file with an empty dictionary and overwrite any existing data."""
|
||||
self.save({})
|
||||
|
||||
def save(self, data: Any) -> None:
|
||||
"""
|
||||
Save the data to the specified file using pickle.
|
||||
"""Save the data to the specified file using JSON.
|
||||
|
||||
Args:
|
||||
data: The data to be saved to the file.
|
||||
"""
|
||||
with open(self.file_path, "wb") as f:
|
||||
pickle.dump(obj=data, file=f)
|
||||
with open(self.file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=2, default=str)
|
||||
|
||||
def load(self) -> Any:
|
||||
"""Load the data from the specified file using pickle.
|
||||
"""Load the data from the specified file using JSON.
|
||||
|
||||
Falls back to migrating legacy .pkl files if the JSON file doesn't exist.
|
||||
|
||||
Returns:
|
||||
The data loaded from the file.
|
||||
"""
|
||||
if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0:
|
||||
return {} # Return an empty dictionary if the file does not exist or is empty
|
||||
# Try to migrate from legacy pkl file
|
||||
migrated = self._migrate_legacy_pkl()
|
||||
if migrated is not None:
|
||||
return migrated
|
||||
return {} # Return an empty dictionary if no file exists
|
||||
|
||||
with open(self.file_path, "rb") as file:
|
||||
try:
|
||||
return pickle.load(file) # noqa: S301
|
||||
except EOFError:
|
||||
return {} # Return an empty dictionary if the file is empty or corrupted
|
||||
return json.loads(file.read().decode("utf-8"))
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
return {} # Return an empty dictionary if the file is corrupted
|
||||
except Exception:
|
||||
raise # Raise any other exceptions that occur during loading
|
||||
|
||||
@@ -22,6 +22,9 @@ try:
|
||||
from aiocache import Cache
|
||||
from aiocache.serializers import PickleSerializer
|
||||
|
||||
# PickleSerializer is safe here: this is an in-memory cache only.
|
||||
# Data never leaves the process, so there is no untrusted deserialization risk.
|
||||
# JsonSerializer would break FileInput objects (Pydantic models with IO streams).
|
||||
_file_store = Cache(Cache.MEMORY, serializer=PickleSerializer())
|
||||
except ImportError:
|
||||
logger.debug(
|
||||
|
||||
@@ -27,7 +27,7 @@ class CrewTrainingHandler(PickleHandler):
|
||||
data = self.load()
|
||||
if agent_id not in data:
|
||||
data[agent_id] = {}
|
||||
data[agent_id][train_iteration] = new_data
|
||||
data[agent_id][str(train_iteration)] = new_data
|
||||
self.save(data)
|
||||
|
||||
def clear(self) -> None:
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import unittest
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from crewai.utilities.file_handler import PickleHandler
|
||||
|
||||
|
||||
@@ -10,21 +11,23 @@ class TestPickleHandler(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# Use a unique file name for each test to avoid race conditions in parallel test execution
|
||||
unique_id = str(uuid.uuid4())
|
||||
self.file_name = f"test_data_{unique_id}.pkl"
|
||||
self.file_path = os.path.join(os.getcwd(), self.file_name)
|
||||
self.file_name = f"test_data_{unique_id}"
|
||||
self.json_path = os.path.join(os.getcwd(), self.file_name + ".json")
|
||||
self.pkl_path = os.path.join(os.getcwd(), self.file_name + ".pkl")
|
||||
self.handler = PickleHandler(self.file_name)
|
||||
|
||||
def tearDown(self):
|
||||
if os.path.exists(self.file_path):
|
||||
os.remove(self.file_path)
|
||||
for path in (self.json_path, self.pkl_path):
|
||||
if os.path.exists(path):
|
||||
os.remove(path)
|
||||
|
||||
def test_initialize_file(self):
|
||||
assert os.path.exists(self.file_path) is False
|
||||
assert os.path.exists(self.json_path) is False
|
||||
|
||||
self.handler.initialize_file()
|
||||
|
||||
assert os.path.exists(self.file_path) is True
|
||||
assert os.path.getsize(self.file_path) >= 0
|
||||
assert os.path.exists(self.json_path) is True
|
||||
assert os.path.getsize(self.json_path) >= 0
|
||||
|
||||
def test_save_and_load(self):
|
||||
data = {"key": "value"}
|
||||
@@ -37,13 +40,79 @@ class TestPickleHandler(unittest.TestCase):
|
||||
assert loaded_data == {}
|
||||
|
||||
def test_load_corrupted_file(self):
|
||||
with open(self.file_path, "wb") as file:
|
||||
file.write(b"corrupted data")
|
||||
file.flush()
|
||||
os.fsync(file.fileno()) # Ensure data is written to disk
|
||||
"""Test that corrupted (non-JSON) files return empty dict gracefully."""
|
||||
with open(self.json_path, "w") as file:
|
||||
file.write("corrupted data that is not valid json")
|
||||
|
||||
with pytest.raises(Exception) as exc:
|
||||
self.handler.load()
|
||||
loaded_data = self.handler.load()
|
||||
assert loaded_data == {}
|
||||
|
||||
assert str(exc.value) == "pickle data was truncated"
|
||||
assert "<class '_pickle.UnpicklingError'>" == str(exc.type)
|
||||
def test_uses_json_format(self):
|
||||
"""Test that data is saved in JSON format, not pickle."""
|
||||
data = {"agent1": {"param1": 1, "param2": "test"}}
|
||||
self.handler.save(data)
|
||||
|
||||
# Verify the file is valid JSON
|
||||
with open(self.json_path, encoding="utf-8") as f:
|
||||
loaded = json.load(f)
|
||||
assert loaded == data
|
||||
|
||||
def test_file_extension_is_json(self):
|
||||
"""Test that the handler uses .json extension."""
|
||||
handler = PickleHandler("test_file.pkl")
|
||||
assert handler.file_path.endswith(".json")
|
||||
assert not handler.file_path.endswith(".pkl")
|
||||
|
||||
def test_no_pickle_in_saved_file(self):
|
||||
"""Test that saved files do not contain pickle data (security)."""
|
||||
data = {"key": "value", "nested": {"a": 1}}
|
||||
self.handler.save(data)
|
||||
|
||||
with open(self.json_path, "rb") as f:
|
||||
raw = f.read()
|
||||
|
||||
# Pickle files start with specific opcodes (0x80 for protocol 2+)
|
||||
assert not raw.startswith(b"\x80"), "File appears to contain pickle data"
|
||||
# Should be valid UTF-8 text (JSON)
|
||||
raw.decode("utf-8")
|
||||
|
||||
def test_migrate_legacy_pkl_file(self):
|
||||
"""Test that legacy .pkl files are automatically migrated to JSON."""
|
||||
data = {"agent1": {"param1": 1}}
|
||||
|
||||
# Create a legacy pkl file
|
||||
with open(self.pkl_path, "wb") as f:
|
||||
pickle.dump(data, f)
|
||||
|
||||
assert os.path.exists(self.pkl_path)
|
||||
assert not os.path.exists(self.json_path)
|
||||
|
||||
# Loading should migrate the pkl to json
|
||||
loaded_data = self.handler.load()
|
||||
assert loaded_data == data
|
||||
|
||||
# pkl file should be removed after migration
|
||||
assert not os.path.exists(self.pkl_path)
|
||||
# json file should now exist
|
||||
assert os.path.exists(self.json_path)
|
||||
|
||||
def test_pkl_extension_input_uses_json(self):
|
||||
"""Test that passing a .pkl filename still results in .json storage."""
|
||||
handler = PickleHandler("my_data.pkl")
|
||||
assert handler.file_path.endswith("my_data.json")
|
||||
|
||||
def test_insecure_pickle_not_loaded_directly(self):
|
||||
"""Test that arbitrary pickle files cannot be loaded directly as JSON.
|
||||
|
||||
This verifies the security fix: a malicious pickle file placed at the
|
||||
JSON path would not be deserialized via pickle.load().
|
||||
"""
|
||||
# Create a file with pickle content at the json path
|
||||
malicious_data = {"safe": True}
|
||||
with open(self.json_path, "wb") as f:
|
||||
pickle.dump(malicious_data, f)
|
||||
|
||||
# The handler should fail gracefully (corrupt JSON) rather than
|
||||
# executing pickle.load on this file
|
||||
loaded = self.handler.load()
|
||||
assert loaded == {} # Returns empty dict for corrupted JSON
|
||||
|
||||
@@ -7,13 +7,16 @@ from crewai.utilities.training_handler import CrewTrainingHandler
|
||||
|
||||
class InternalCrewTrainingHandler(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.temp_file = tempfile.NamedTemporaryFile(suffix=".pkl", delete=False)
|
||||
self.temp_file = tempfile.NamedTemporaryFile(suffix=".json", delete=False)
|
||||
self.temp_file.close()
|
||||
self.handler = CrewTrainingHandler(self.temp_file.name)
|
||||
|
||||
def tearDown(self):
|
||||
if os.path.exists(self.temp_file.name):
|
||||
os.remove(self.temp_file.name)
|
||||
# Clean up both potential file paths (.json used by handler)
|
||||
handler_path = self.handler.file_path
|
||||
for path in (self.temp_file.name, handler_path):
|
||||
if os.path.exists(path):
|
||||
os.remove(path)
|
||||
del self.handler
|
||||
|
||||
def test_save_trained_data(self):
|
||||
@@ -37,12 +40,13 @@ class InternalCrewTrainingHandler(unittest.TestCase):
|
||||
self.handler.append(train_iteration, agent_id, new_data)
|
||||
|
||||
# Assert that the new data is appended correctly to the existing agent
|
||||
# Note: JSON serializes integer keys as strings
|
||||
data = self.handler.load()
|
||||
assert agent_id in data
|
||||
assert initial_iteration in data[agent_id]
|
||||
assert train_iteration in data[agent_id]
|
||||
assert data[agent_id][initial_iteration] == initial_data
|
||||
assert data[agent_id][train_iteration] == new_data
|
||||
assert str(initial_iteration) in data[agent_id]
|
||||
assert str(train_iteration) in data[agent_id]
|
||||
assert data[agent_id][str(initial_iteration)] == initial_data
|
||||
assert data[agent_id][str(train_iteration)] == new_data
|
||||
|
||||
def test_append_new_agent(self):
|
||||
train_iteration = 1
|
||||
@@ -51,5 +55,6 @@ class InternalCrewTrainingHandler(unittest.TestCase):
|
||||
self.handler.append(train_iteration, agent_id, new_data)
|
||||
|
||||
# Assert that the new agent and data are appended correctly
|
||||
# Note: JSON serializes integer keys as strings
|
||||
data = self.handler.load()
|
||||
assert data[agent_id][train_iteration] == new_data
|
||||
assert data[agent_id][str(train_iteration)] == new_data
|
||||
|
||||
Reference in New Issue
Block a user