mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-03 05:08:29 +00:00
* feat: add training logic to agent and crew * feat: add training logic to agent executor * feat: add input parameter to cli command * feat: add utilities for the training logic * feat: polish code, logic and add private variables * feat: add docstring and type hinting to executor * feat: add constant file, add constant to code * feat: fix name of training handler function * feat: remove unused var * feat: change file handler file name * feat: Add training handler file, class and change on the code * feat: fix name error from file * fix: change import to adapt to logic * feat: add training handler test * feat: add tests for file and training_handler * feat: add test for task evaluator function * feat: change text to fit in-screen * feat: add test for train function * feat: add test for agent training_handler function * feat: add test for agent._use_trained_data
43 lines
1.4 KiB
Python
43 lines
1.4 KiB
Python
import os
|
|
import unittest
|
|
|
|
from crewai.utilities.training_handler import CrewTrainingHandler
|
|
|
|
|
|
class TestCrewTrainingHandler(unittest.TestCase):
    """Tests for ``CrewTrainingHandler`` persistence of per-agent training data.

    Each test gets a fresh handler bound to ``FILE_NAME``. The backing file is
    removed in ``tearDown`` so that no test sees data written by another.
    """

    # Name passed to the handler under test and removed again in tearDown.
    # Where the handler actually resolves this name is up to
    # CrewTrainingHandler; tearDown removes the same bare name, as before.
    FILE_NAME = "trained_data.pkl"

    def setUp(self):
        self.handler = CrewTrainingHandler(self.FILE_NAME)

    def tearDown(self):
        # A test that fails before anything is saved leaves no file behind.
        # Tolerate that, so the real failure is not buried under a
        # FileNotFoundError raised from cleanup.
        try:
            os.remove(self.FILE_NAME)
        except FileNotFoundError:
            pass
        del self.handler

    def test_save_trained_data(self):
        """save_trained_data stores the given data under the agent id."""
        agent_id = "agent1"
        trained_data = {"param1": 1, "param2": 2}
        self.handler.save_trained_data(agent_id, trained_data)

        # Reload through the handler to confirm the data was persisted.
        data = self.handler.load()
        self.assertEqual(data[agent_id], trained_data)

    def test_append_existing_agent(self):
        """append adds a new iteration to an agent that already has data."""
        agent_id = "agent1"
        previous_iteration = 0
        previous_data = {"param1": 1, "param2": 2}
        # Seed the agent first so append() really takes the "agent already
        # present" path. Without this, the test was indistinguishable from
        # test_append_new_agent.
        self.handler.save_trained_data(
            agent_id, {previous_iteration: previous_data}
        )

        train_iteration = 1
        new_data = {"param3": 3, "param4": 4}
        self.handler.append(train_iteration, agent_id, new_data)

        data = self.handler.load()
        self.assertEqual(data[agent_id][train_iteration], new_data)
        # Appending must not discard the agent's earlier iterations.
        self.assertEqual(data[agent_id][previous_iteration], previous_data)

    def test_append_new_agent(self):
        """append creates an entry for an unseen agent without touching others."""
        # Another agent's data is already stored; it must survive the append.
        other_agent_id = "agent1"
        other_data = {"param1": 1, "param2": 2}
        self.handler.save_trained_data(other_agent_id, other_data)

        train_iteration = 1
        agent_id = "agent2"
        new_data = {"param5": 5, "param6": 6}
        self.handler.append(train_iteration, agent_id, new_data)

        data = self.handler.load()
        self.assertEqual(data[agent_id][train_iteration], new_data)
        self.assertEqual(data[other_agent_id], other_data)