Files
crewAI/tests/utilities/test_file_handler.py
Eduardo Chiarotti 175d5b3dd6 feat: Add Train feature for Crews (#686)
* feat: add training logic to agent and crew

* feat: add training logic to agent executor

* feat: add input parameter to cli command

* feat: add utilities for the training logic

* feat: polish code, logic and add private variables

* feat: add docstring and type hinting to executor

* feat: add constant file, add constant to code

* feat: fix name of training handler function

* feat: remove unused var

* feat: change file handler file name

* feat: Add training handler file, class and change on the code

* feat: fix name error from file

* fix: change import to adapt to logic

* feat: add training handler test

* feat: add tests for file and training_handler

* feat: add test for task evaluator function

* feat: change text to fit in-screen

* feat: add test for train function

* feat: add test for agent training_handler function

* feat: add test for agent._use_trained_data
2024-06-27 02:22:34 -03:00

42 lines
1.2 KiB
Python

import os
import pickle
import unittest

import pytest

from crewai.utilities.file_handler import PickleHandler
class TestPickleHandler(unittest.TestCase):
    """Exercise ``PickleHandler`` against a scratch pickle file in the working directory."""

    def setUp(self):
        """Build a handler for ``test_data.pkl`` and record where its file is expected."""
        self.file_name = "test_data.pkl"
        # The tests look for the handler's file at <cwd>/<file_name>.
        self.file_path = os.path.join(os.getcwd(), self.file_name)
        self.handler = PickleHandler(self.file_name)

    def tearDown(self):
        """Delete the scratch file so each test starts from a clean state."""
        if os.path.exists(self.file_path):
            os.remove(self.file_path)

    def test_initialize_file(self):
        """The scratch file must exist once the handler has been constructed."""
        assert os.path.exists(self.file_path)
        # getsize() never returns a negative number, so this only confirms that
        # the path can be stat'ed; it makes no claim about the file's contents.
        assert os.path.getsize(self.file_path) >= 0

    def test_save_and_load(self):
        """A dict passed to ``save`` must come back unchanged from ``load``."""
        data = {"key": "value"}
        self.handler.save(data)

        assert self.handler.load() == data

    def test_load_empty_file(self):
        """``load`` on a fresh handler, with nothing saved, must return an empty dict."""
        assert self.handler.load() == {}

    def test_load_corrupted_file(self):
        """``load`` must raise ``pickle.UnpicklingError`` when the file is not a valid pickle."""
        with open(self.file_path, "wb") as file:
            file.write(b"corrupted data")

        # Expect the specific exception class rather than a bare ``Exception``.
        with pytest.raises(pickle.UnpicklingError) as exc:
            self.handler.load()

        # These checks sit outside the ``with`` block: anything placed after the
        # raising call inside it would never execute.
        # Compare the class by identity instead of by its repr string, which
        # depends on the module name of the pickle implementation in use.
        assert exc.type is pickle.UnpicklingError
        # NOTE(review): this message text comes from the pickle implementation
        # and may differ on non-CPython interpreters -- confirm if those matter.
        assert str(exc.value) == "pickle data was truncated"