Compare commits

...

2 Commits

Author SHA1 Message Date
Eduardo Chiarotti
b1d8a13f8f fix: add logic for the trained_agent data 2024-07-04 08:42:22 -03:00
Eduardo Chiarotti
2b22e5ecd3 fix: file_handler issue 2024-07-04 08:19:30 -03:00
3 changed files with 13 additions and 7 deletions

View File

@@ -1,7 +1,7 @@
import asyncio import asyncio
import json import json
import uuid import uuid
from typing import Any, Dict, List, Optional, Union, Tuple from typing import Any, Dict, List, Optional, Tuple, Union
from langchain_core.callbacks import BaseCallbackHandler from langchain_core.callbacks import BaseCallbackHandler
from pydantic import ( from pydantic import (
@@ -28,6 +28,7 @@ from crewai.task import Task
from crewai.telemetry import Telemetry from crewai.telemetry import Telemetry
from crewai.tools.agent_tools import AgentTools from crewai.tools.agent_tools import AgentTools
from crewai.utilities import I18N, FileHandler, Logger, RPMController from crewai.utilities import I18N, FileHandler, Logger, RPMController
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
from crewai.utilities.training_handler import CrewTrainingHandler from crewai.utilities.training_handler import CrewTrainingHandler
@@ -289,6 +290,9 @@ class Crew(BaseModel):
for agent in self.agents: for agent in self.agents:
agent.allow_delegation = False agent.allow_delegation = False
CrewTrainingHandler(TRAINING_DATA_FILE).initialize_file()
CrewTrainingHandler(TRAINED_AGENTS_DATA_FILE).initialize_file()
def train(self, n_iterations: int, inputs: Optional[Dict[str, Any]] = {}) -> None: def train(self, n_iterations: int, inputs: Optional[Dict[str, Any]] = {}) -> None:
"""Trains the crew for a given number of iterations.""" """Trains the crew for a given number of iterations."""
self._setup_for_training() self._setup_for_training()
@@ -297,14 +301,14 @@ class Crew(BaseModel):
self._train_iteration = n_iteration self._train_iteration = n_iteration
self.kickoff(inputs=inputs) self.kickoff(inputs=inputs)
training_data = CrewTrainingHandler("training_data.pkl").load() training_data = CrewTrainingHandler(TRAINING_DATA_FILE).load()
for agent in self.agents: for agent in self.agents:
result = TaskEvaluator(agent).evaluate_training_data( result = TaskEvaluator(agent).evaluate_training_data(
training_data=training_data, agent_id=str(agent.id) training_data=training_data, agent_id=str(agent.id)
) )
CrewTrainingHandler("trained_agents_data.pkl").save_trained_data( CrewTrainingHandler(TRAINED_AGENTS_DATA_FILE).save_trained_data(
agent_id=str(agent.role), trained_data=result.model_dump() agent_id=str(agent.role), trained_data=result.model_dump()
) )
@@ -585,8 +589,7 @@ class Crew(BaseModel):
self._rpm_controller.stop_rpm_counter() self._rpm_controller.stop_rpm_counter()
if agentops: if agentops:
agentops.end_session( agentops.end_session(
end_state="Success", end_state="Success", end_state_reason="Finished Execution"
end_state_reason="Finished Execution"
) )
self._telemetry.end_crew(self, output) self._telemetry.end_crew(self, output)

View File

@@ -31,9 +31,8 @@ class PickleHandler:
- file_name (str): The name of the file for saving and loading data. - file_name (str): The name of the file for saving and loading data.
""" """
self.file_path = os.path.join(os.getcwd(), file_name) self.file_path = os.path.join(os.getcwd(), file_name)
self._initialize_file()
def _initialize_file(self) -> None: def initialize_file(self) -> None:
""" """
Initialize the file with an empty dictionary if it does not exist or is empty. Initialize the file with an empty dictionary if it does not exist or is empty.
""" """

View File

@@ -17,6 +17,10 @@ class TestPickleHandler(unittest.TestCase):
os.remove(self.file_path) os.remove(self.file_path)
def test_initialize_file(self): def test_initialize_file(self):
assert os.path.exists(self.file_path) is False
self.handler.initialize_file()
assert os.path.exists(self.file_path) is True assert os.path.exists(self.file_path) is True
assert os.path.getsize(self.file_path) >= 0 assert os.path.getsize(self.file_path) >= 0