diff --git a/src/crewai/crew.py b/src/crewai/crew.py index 4af3304e1..ac373a0b6 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -1,7 +1,7 @@ import asyncio import json import uuid -from typing import Any, Dict, List, Optional, Union, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union from langchain_core.callbacks import BaseCallbackHandler from pydantic import ( @@ -28,6 +28,7 @@ from crewai.task import Task from crewai.telemetry import Telemetry from crewai.tools.agent_tools import AgentTools from crewai.utilities import I18N, FileHandler, Logger, RPMController +from crewai.utilities.constants import TRAINING_DATA_FILE from crewai.utilities.evaluators.task_evaluator import TaskEvaluator from crewai.utilities.training_handler import CrewTrainingHandler @@ -289,6 +290,8 @@ class Crew(BaseModel): for agent in self.agents: agent.allow_delegation = False + CrewTrainingHandler(TRAINING_DATA_FILE).initialize_file() + def train(self, n_iterations: int, inputs: Optional[Dict[str, Any]] = {}) -> None: """Trains the crew for a given number of iterations.""" self._setup_for_training() @@ -297,7 +300,7 @@ class Crew(BaseModel): self._train_iteration = n_iteration self.kickoff(inputs=inputs) - training_data = CrewTrainingHandler("training_data.pkl").load() + training_data = CrewTrainingHandler(TRAINING_DATA_FILE).load() for agent in self.agents: result = TaskEvaluator(agent).evaluate_training_data( @@ -585,8 +588,7 @@ class Crew(BaseModel): self._rpm_controller.stop_rpm_counter() if agentops: agentops.end_session( - end_state="Success", - end_state_reason="Finished Execution" + end_state="Success", end_state_reason="Finished Execution" ) self._telemetry.end_crew(self, output) diff --git a/src/crewai/utilities/file_handler.py b/src/crewai/utilities/file_handler.py index 0f929e962..4af101471 100644 --- a/src/crewai/utilities/file_handler.py +++ b/src/crewai/utilities/file_handler.py @@ -31,9 +31,8 @@ class PickleHandler: - file_name (str): The name of the file for saving and loading data. """ self.file_path = os.path.join(os.getcwd(), file_name) - self._initialize_file() - def _initialize_file(self) -> None: + def initialize_file(self) -> None: """ Initialize the file with an empty dictionary if it does not exist or is empty. """ diff --git a/tests/utilities/test_file_handler.py b/tests/utilities/test_file_handler.py index 1983c37d9..4a1038a9b 100644 --- a/tests/utilities/test_file_handler.py +++ b/tests/utilities/test_file_handler.py @@ -17,6 +17,10 @@ class TestPickleHandler(unittest.TestCase): os.remove(self.file_path) def test_initialize_file(self): + assert os.path.exists(self.file_path) is False + + self.handler.initialize_file() + assert os.path.exists(self.file_path) is True assert os.path.getsize(self.file_path) >= 0