feat: Add Train feature for Crews (#686)

* feat: add training logic to agent and crew

* feat: add training logic to agent executor

* feat: add input parameter to cli command

* feat: add utilities for the training logic

* feat: polish code, logic and add private variables

* feat: add docstring and type hinting to executor

* feat: add constant file, add constant to code

* feat: fix name of training handler function

* feat: remove unused var

* feat: change file handler file name

* feat: Add training handler file, class and change on the code

* feat: fix name error from file

* fix: change import to adapt to logic

* feat: add training handler test

* feat: add tests for file and training_handler

* feat: add test for task evaluator function

* feat: change text to fit in-screen

* feat: add test for train function

* feat: add test for agent training_handler function

* feat: add test for agent._use_trained_data
This commit is contained in:
Eduardo Chiarotti
2024-06-27 02:22:34 -03:00
committed by GitHub
parent 9e61b8325b
commit 175d5b3dd6
15 changed files with 564 additions and 45 deletions

View File

@@ -1,6 +1,6 @@
from copy import deepcopy
import os
import uuid
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple
from langchain.agents.agent import RunnableAgent
@@ -24,7 +24,9 @@ from pydantic_core import PydanticCustomError
from crewai.agents import CacheHandler, CrewAgentExecutor, CrewAgentParser, ToolsHandler
from crewai.memory.contextual.contextual_memory import ContextualMemory
from crewai.utilities import I18N, Logger, Prompts, RPMController
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
from crewai.utilities.token_counter_callback import TokenCalcHandler, TokenProcess
from crewai.utilities.training_handler import CrewTrainingHandler
class Agent(BaseModel):
@@ -98,8 +100,7 @@ class Agent(BaseModel):
agent_executor: InstanceOf[CrewAgentExecutor] = Field(
default=None, description="An instance of the CrewAgentExecutor class."
)
crew: Any = Field(
default=None, description="Crew to which the agent belongs.")
crew: Any = Field(default=None, description="Crew to which the agent belongs.")
tools_handler: InstanceOf[ToolsHandler] = Field(
default=None, description="An instance of the ToolsHandler class."
)
@@ -110,8 +111,7 @@ class Agent(BaseModel):
default=None,
description="Callback to be executed after each step of the agent execution.",
)
i18n: I18N = Field(
default=I18N(), description="Internationalization settings.")
i18n: I18N = Field(default=I18N(), description="Internationalization settings.")
llm: Any = Field(
default_factory=lambda: ChatOpenAI(
model=os.environ.get("OPENAI_MODEL_NAME", "gpt-4o")
@@ -172,8 +172,7 @@ class Agent(BaseModel):
def set_agent_executor(self) -> "Agent":
"""set agent executor is set."""
if hasattr(self.llm, "model_name"):
token_handler = TokenCalcHandler(
self.llm.model_name, self._token_process)
token_handler = TokenCalcHandler(self.llm.model_name, self._token_process)
# Ensure self.llm.callbacks is a list
if not isinstance(self.llm.callbacks, list):
@@ -236,10 +235,14 @@ class Agent(BaseModel):
self.agent_executor.tools = parsed_tools
self.agent_executor.task = task
self.agent_executor.tools_description = render_text_description(
parsed_tools)
self.agent_executor.tools_description = render_text_description(parsed_tools)
self.agent_executor.tools_names = self.__tools_names(parsed_tools)
if self.crew._train:
task_prompt = self._training_handler(task_prompt=task_prompt)
else:
task_prompt = self._use_trained_data(task_prompt=task_prompt)
result = self.agent_executor.invoke(
{
"input": task_prompt,
@@ -335,8 +338,7 @@ class Agent(BaseModel):
)
bind = self.llm.bind(stop=stop_words)
inner_agent = agent_args | execution_prompt | bind | CrewAgentParser(
agent=self)
inner_agent = agent_args | execution_prompt | bind | CrewAgentParser(agent=self)
self.agent_executor = CrewAgentExecutor(
agent=RunnableAgent(runnable=inner_agent), **executor_args
)
@@ -371,7 +373,7 @@ class Agent(BaseModel):
thoughts += action.log
thoughts += f"\n{observation_prefix}{observation}\n{llm_prefix}"
return thoughts
def copy(self):
"""Create a deep copy of the Agent."""
exclude = {
@@ -379,8 +381,8 @@ class Agent(BaseModel):
"_logger",
"_rpm_controller",
"_request_within_rpm_limit",
"_token_process",
"agent_executor",
"_token_process",
"agent_executor",
"tools",
"tools_handler",
"cache_handler",
@@ -412,6 +414,30 @@ class Agent(BaseModel):
tools_list.append(tool)
return tools_list
def _training_handler(self, task_prompt: str) -> str:
    """Append the human feedback stored for this agent to the task prompt.

    The contents of TRAINING_DATA_FILE are loaded through CrewTrainingHandler
    and looked up under this agent's id converted to a string. Every record
    found there contributes its "human_feedback" value to a feedback section
    that is appended to the prompt.

    Args:
        task_prompt: The prompt built for the task being executed.

    Returns:
        The prompt with the feedback section appended, or the prompt
        unchanged when nothing was loaded or this agent has no records.
    """
    stored = CrewTrainingHandler(TRAINING_DATA_FILE).load()
    if not stored:
        return task_prompt

    # Records are keyed by the string form of the agent id.
    records = stored.get(str(self.id))
    if not records:
        return task_prompt

    human_feedbacks = [record["human_feedback"] for record in records.values()]
    feedback_section = "You MUST follow these feedbacks: \n " + "\n - ".join(
        human_feedbacks
    )
    return task_prompt + feedback_section
def _use_trained_data(self, task_prompt: str) -> str:
    """Append previously trained suggestions for this agent's role to the prompt.

    The contents of TRAINED_AGENTS_DATA_FILE are loaded through
    CrewTrainingHandler and looked up under ``self.role``. The "suggestions"
    of the matching entry are joined into a section appended to the prompt.

    Args:
        task_prompt: The prompt built for the task being executed.

    Returns:
        The prompt with the suggestions section appended, or the prompt
        unchanged when nothing was loaded or the role has no entry.
    """
    stored = CrewTrainingHandler(TRAINED_AGENTS_DATA_FILE).load()
    if not stored:
        return task_prompt

    # Trained data is keyed by the agent's role rather than by its id.
    role_entry = stored.get(self.role)
    if not role_entry:
        return task_prompt

    suggestions_section = "You MUST follow these feedbacks: \n " + "\n - ".join(
        role_entry["suggestions"]
    )
    return task_prompt + suggestions_section
@staticmethod
def __tools_names(tools) -> str:
    """Return the ``name`` of every tool in ``tools`` as one comma-separated string."""
    names = [tool.name for tool in tools]
    return ", ".join(names)