mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
fix: file_handler issue (#869)
* fix: file_handler issue * fix: add logic for the trained_agent data
This commit is contained in:
committed by
GitHub
parent
cc9e30ac23
commit
3d78ad4fff
@@ -1,7 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Dict, List, Optional, Union, Tuple
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from langchain_core.callbacks import BaseCallbackHandler
|
from langchain_core.callbacks import BaseCallbackHandler
|
||||||
from pydantic import (
|
from pydantic import (
|
||||||
@@ -28,6 +28,7 @@ from crewai.task import Task
|
|||||||
from crewai.telemetry import Telemetry
|
from crewai.telemetry import Telemetry
|
||||||
from crewai.tools.agent_tools import AgentTools
|
from crewai.tools.agent_tools import AgentTools
|
||||||
from crewai.utilities import I18N, FileHandler, Logger, RPMController
|
from crewai.utilities import I18N, FileHandler, Logger, RPMController
|
||||||
|
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
|
||||||
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
|
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
|
||||||
from crewai.utilities.training_handler import CrewTrainingHandler
|
from crewai.utilities.training_handler import CrewTrainingHandler
|
||||||
|
|
||||||
@@ -289,6 +290,9 @@ class Crew(BaseModel):
|
|||||||
for agent in self.agents:
|
for agent in self.agents:
|
||||||
agent.allow_delegation = False
|
agent.allow_delegation = False
|
||||||
|
|
||||||
|
CrewTrainingHandler(TRAINING_DATA_FILE).initialize_file()
|
||||||
|
CrewTrainingHandler(TRAINED_AGENTS_DATA_FILE).initialize_file()
|
||||||
|
|
||||||
def train(self, n_iterations: int, inputs: Optional[Dict[str, Any]] = {}) -> None:
|
def train(self, n_iterations: int, inputs: Optional[Dict[str, Any]] = {}) -> None:
|
||||||
"""Trains the crew for a given number of iterations."""
|
"""Trains the crew for a given number of iterations."""
|
||||||
self._setup_for_training()
|
self._setup_for_training()
|
||||||
@@ -297,14 +301,14 @@ class Crew(BaseModel):
|
|||||||
self._train_iteration = n_iteration
|
self._train_iteration = n_iteration
|
||||||
self.kickoff(inputs=inputs)
|
self.kickoff(inputs=inputs)
|
||||||
|
|
||||||
training_data = CrewTrainingHandler("training_data.pkl").load()
|
training_data = CrewTrainingHandler(TRAINING_DATA_FILE).load()
|
||||||
|
|
||||||
for agent in self.agents:
|
for agent in self.agents:
|
||||||
result = TaskEvaluator(agent).evaluate_training_data(
|
result = TaskEvaluator(agent).evaluate_training_data(
|
||||||
training_data=training_data, agent_id=str(agent.id)
|
training_data=training_data, agent_id=str(agent.id)
|
||||||
)
|
)
|
||||||
|
|
||||||
CrewTrainingHandler("trained_agents_data.pkl").save_trained_data(
|
CrewTrainingHandler(TRAINED_AGENTS_DATA_FILE).save_trained_data(
|
||||||
agent_id=str(agent.role), trained_data=result.model_dump()
|
agent_id=str(agent.role), trained_data=result.model_dump()
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -585,8 +589,7 @@ class Crew(BaseModel):
|
|||||||
self._rpm_controller.stop_rpm_counter()
|
self._rpm_controller.stop_rpm_counter()
|
||||||
if agentops:
|
if agentops:
|
||||||
agentops.end_session(
|
agentops.end_session(
|
||||||
end_state="Success",
|
end_state="Success", end_state_reason="Finished Execution"
|
||||||
end_state_reason="Finished Execution"
|
|
||||||
)
|
)
|
||||||
self._telemetry.end_crew(self, output)
|
self._telemetry.end_crew(self, output)
|
||||||
|
|
||||||
|
|||||||
@@ -31,9 +31,8 @@ class PickleHandler:
|
|||||||
- file_name (str): The name of the file for saving and loading data.
|
- file_name (str): The name of the file for saving and loading data.
|
||||||
"""
|
"""
|
||||||
self.file_path = os.path.join(os.getcwd(), file_name)
|
self.file_path = os.path.join(os.getcwd(), file_name)
|
||||||
self._initialize_file()
|
|
||||||
|
|
||||||
def _initialize_file(self) -> None:
|
def initialize_file(self) -> None:
|
||||||
"""
|
"""
|
||||||
Initialize the file with an empty dictionary if it does not exist or is empty.
|
Initialize the file with an empty dictionary if it does not exist or is empty.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -17,6 +17,10 @@ class TestPickleHandler(unittest.TestCase):
|
|||||||
os.remove(self.file_path)
|
os.remove(self.file_path)
|
||||||
|
|
||||||
def test_initialize_file(self):
|
def test_initialize_file(self):
|
||||||
|
assert os.path.exists(self.file_path) is False
|
||||||
|
|
||||||
|
self.handler.initialize_file()
|
||||||
|
|
||||||
assert os.path.exists(self.file_path) is True
|
assert os.path.exists(self.file_path) is True
|
||||||
assert os.path.getsize(self.file_path) >= 0
|
assert os.path.getsize(self.file_path) >= 0
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user