Compare commits

...

2 Commits

Author SHA1 Message Date
Eduardo Chiarotti
b1d8a13f8f fix: add logic for the trained_agent data 2024-07-04 08:42:22 -03:00
Eduardo Chiarotti
2b22e5ecd3 fix: file_handler issue 2024-07-04 08:19:30 -03:00
3 changed files with 13 additions and 7 deletions

View File

@@ -1,7 +1,7 @@
import asyncio
import json
import uuid
from typing import Any, Dict, List, Optional, Union, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union
from langchain_core.callbacks import BaseCallbackHandler
from pydantic import (
@@ -28,6 +28,7 @@ from crewai.task import Task
from crewai.telemetry import Telemetry
from crewai.tools.agent_tools import AgentTools
from crewai.utilities import I18N, FileHandler, Logger, RPMController
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
from crewai.utilities.training_handler import CrewTrainingHandler
@@ -289,6 +290,9 @@ class Crew(BaseModel):
for agent in self.agents:
agent.allow_delegation = False
CrewTrainingHandler(TRAINING_DATA_FILE).initialize_file()
CrewTrainingHandler(TRAINED_AGENTS_DATA_FILE).initialize_file()
def train(self, n_iterations: int, inputs: Optional[Dict[str, Any]] = {}) -> None:
"""Trains the crew for a given number of iterations."""
self._setup_for_training()
@@ -297,14 +301,14 @@ class Crew(BaseModel):
self._train_iteration = n_iteration
self.kickoff(inputs=inputs)
training_data = CrewTrainingHandler("training_data.pkl").load()
training_data = CrewTrainingHandler(TRAINING_DATA_FILE).load()
for agent in self.agents:
result = TaskEvaluator(agent).evaluate_training_data(
training_data=training_data, agent_id=str(agent.id)
)
CrewTrainingHandler("trained_agents_data.pkl").save_trained_data(
CrewTrainingHandler(TRAINED_AGENTS_DATA_FILE).save_trained_data(
agent_id=str(agent.role), trained_data=result.model_dump()
)
@@ -585,8 +589,7 @@ class Crew(BaseModel):
self._rpm_controller.stop_rpm_counter()
if agentops:
agentops.end_session(
end_state="Success",
end_state_reason="Finished Execution"
end_state="Success", end_state_reason="Finished Execution"
)
self._telemetry.end_crew(self, output)

View File

@@ -31,9 +31,8 @@ class PickleHandler:
- file_name (str): The name of the file for saving and loading data.
"""
self.file_path = os.path.join(os.getcwd(), file_name)
self._initialize_file()
def _initialize_file(self) -> None:
def initialize_file(self) -> None:
"""
Initialize the file with an empty dictionary if it does not exist or is empty.
"""

View File

@@ -17,6 +17,10 @@ class TestPickleHandler(unittest.TestCase):
os.remove(self.file_path)
def test_initialize_file(self):
assert os.path.exists(self.file_path) is False
self.handler.initialize_file()
assert os.path.exists(self.file_path) is True
assert os.path.getsize(self.file_path) >= 0