mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
feat: Add Train feature for Crews (#686)
* feat: add training logic to agent and crew * feat: add training logic to agent executor * feat: add input parameter to cli command * feat: add utilities for the training logic * feat: polish code, logic and add private variables * feat: add docstring and type hinting to executor * feat: add constant file, add constant to code * feat: fix name of training handler function * feat: remove unused var * feat: change file handler file name * feat: Add training handler file, class and change on the code * feat: fix name error from file * fix: change import to adapt to logic * feat: add training handler test * feat: add tests for file and training_handler * feat: add test for task evaluator function * feat: change text to fit in-screen * feat: add test for train function * feat: add test for agent training_handler function * feat: add test for agent._use_trained_data
This commit is contained in:
committed by
GitHub
parent
9e61b8325b
commit
175d5b3dd6
@@ -18,8 +18,10 @@ from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
||||
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
|
||||
from crewai.tools.tool_usage import ToolUsage, ToolUsageErrorException
|
||||
from crewai.utilities import I18N
|
||||
from crewai.utilities.constants import TRAINING_DATA_FILE
|
||||
from crewai.utilities.converter import ConverterError
|
||||
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
|
||||
from crewai.utilities.training_handler import CrewTrainingHandler
|
||||
|
||||
|
||||
class CrewAgentExecutor(AgentExecutor):
|
||||
@@ -246,12 +248,17 @@ class CrewAgentExecutor(AgentExecutor):
|
||||
# If the tool chosen is the finishing tool, then we end and return.
|
||||
if isinstance(output, AgentFinish):
|
||||
if self.should_ask_for_human_input:
|
||||
human_feedback = self._ask_human_input(output.return_values["output"])
|
||||
|
||||
if self.crew._train:
|
||||
self._handle_crew_training_output(output, human_feedback)
|
||||
|
||||
# Making sure we only ask for it once, so disabling for the next thought loop
|
||||
self.should_ask_for_human_input = False
|
||||
human_feedback = self._ask_human_input(output.return_values["output"])
|
||||
action = AgentAction(
|
||||
tool="Human Input", tool_input=human_feedback, log=output.log
|
||||
)
|
||||
|
||||
yield AgentStep(
|
||||
action=action,
|
||||
observation=self._i18n.slice("human_feedback").format(
|
||||
@@ -261,6 +268,9 @@ class CrewAgentExecutor(AgentExecutor):
|
||||
return
|
||||
|
||||
else:
|
||||
if self.crew._train:
|
||||
self._handle_crew_training_output(output)
|
||||
|
||||
yield output
|
||||
return
|
||||
|
||||
@@ -305,3 +315,30 @@ class CrewAgentExecutor(AgentExecutor):
|
||||
return input(
|
||||
self._i18n.slice("getting_input").format(final_answer=final_answer)
|
||||
)
|
||||
|
||||
def _handle_crew_training_output(
|
||||
self, output: AgentFinish, human_feedback: str | None = None
|
||||
) -> None:
|
||||
"""Function to handle the process of the training data."""
|
||||
agent_id = str(self.crew_agent.id)
|
||||
|
||||
if (
|
||||
training_data := CrewTrainingHandler(TRAINING_DATA_FILE).load()
|
||||
and not self.should_ask_for_human_input
|
||||
):
|
||||
if training_data.get(agent_id):
|
||||
training_data[agent_id][self.crew._train_iteration][
|
||||
"improved_output"
|
||||
] = output.return_values["output"]
|
||||
CrewTrainingHandler(TRAINING_DATA_FILE).save(training_data)
|
||||
|
||||
if self.should_ask_for_human_input and human_feedback is not None:
|
||||
training_data = {
|
||||
"initial_output": output.return_values["output"],
|
||||
"human_feedback": human_feedback,
|
||||
"agent": agent_id,
|
||||
"agent_role": self.crew_agent.role,
|
||||
}
|
||||
CrewTrainingHandler(TRAINING_DATA_FILE).append(
|
||||
self.crew._train_iteration, agent_id, training_data
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user