mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 16:18:30 +00:00
feat: Add Train feature for Crews (#686)
* feat: add training logic to agent and crew
* feat: add training logic to agent executor
* feat: add input parameter to cli command
* feat: add utilities for the training logic
* feat: polish code, logic and add private variables
* feat: add docstring and type hinting to executor
* feat: add constant file, add constant to code
* feat: fix name of training handler function
* feat: remove unused var
* feat: change file handler file name
* feat: Add training handler file, class and change on the code
* feat: fix name error from file
* fix: change import to adapt to logic
* feat: add training handler test
* feat: add tests for file and training_handler
* feat: add test for task evaluator function
* feat: change text to fit in-screen
* feat: add test for train function
* feat: add test for agent training_handler function
* feat: add test for agent._use_trained_data
This commit is contained in:
committed by
GitHub
parent
9e61b8325b
commit
175d5b3dd6
@@ -1,6 +1,6 @@
|
||||
from copy import deepcopy
|
||||
import os
|
||||
import uuid
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from langchain.agents.agent import RunnableAgent
|
||||
@@ -24,7 +24,9 @@ from pydantic_core import PydanticCustomError
|
||||
from crewai.agents import CacheHandler, CrewAgentExecutor, CrewAgentParser, ToolsHandler
|
||||
from crewai.memory.contextual.contextual_memory import ContextualMemory
|
||||
from crewai.utilities import I18N, Logger, Prompts, RPMController
|
||||
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
|
||||
from crewai.utilities.token_counter_callback import TokenCalcHandler, TokenProcess
|
||||
from crewai.utilities.training_handler import CrewTrainingHandler
|
||||
|
||||
|
||||
class Agent(BaseModel):
|
||||
@@ -98,8 +100,7 @@ class Agent(BaseModel):
|
||||
agent_executor: InstanceOf[CrewAgentExecutor] = Field(
|
||||
default=None, description="An instance of the CrewAgentExecutor class."
|
||||
)
|
||||
crew: Any = Field(
|
||||
default=None, description="Crew to which the agent belongs.")
|
||||
crew: Any = Field(default=None, description="Crew to which the agent belongs.")
|
||||
tools_handler: InstanceOf[ToolsHandler] = Field(
|
||||
default=None, description="An instance of the ToolsHandler class."
|
||||
)
|
||||
@@ -110,8 +111,7 @@ class Agent(BaseModel):
|
||||
default=None,
|
||||
description="Callback to be executed after each step of the agent execution.",
|
||||
)
|
||||
i18n: I18N = Field(
|
||||
default=I18N(), description="Internationalization settings.")
|
||||
i18n: I18N = Field(default=I18N(), description="Internationalization settings.")
|
||||
llm: Any = Field(
|
||||
default_factory=lambda: ChatOpenAI(
|
||||
model=os.environ.get("OPENAI_MODEL_NAME", "gpt-4o")
|
||||
@@ -172,8 +172,7 @@ class Agent(BaseModel):
|
||||
def set_agent_executor(self) -> "Agent":
|
||||
"""set agent executor is set."""
|
||||
if hasattr(self.llm, "model_name"):
|
||||
token_handler = TokenCalcHandler(
|
||||
self.llm.model_name, self._token_process)
|
||||
token_handler = TokenCalcHandler(self.llm.model_name, self._token_process)
|
||||
|
||||
# Ensure self.llm.callbacks is a list
|
||||
if not isinstance(self.llm.callbacks, list):
|
||||
@@ -236,10 +235,14 @@ class Agent(BaseModel):
|
||||
self.agent_executor.tools = parsed_tools
|
||||
self.agent_executor.task = task
|
||||
|
||||
self.agent_executor.tools_description = render_text_description(
|
||||
parsed_tools)
|
||||
self.agent_executor.tools_description = render_text_description(parsed_tools)
|
||||
self.agent_executor.tools_names = self.__tools_names(parsed_tools)
|
||||
|
||||
if self.crew._train:
|
||||
task_prompt = self._training_handler(task_prompt=task_prompt)
|
||||
else:
|
||||
task_prompt = self._use_trained_data(task_prompt=task_prompt)
|
||||
|
||||
result = self.agent_executor.invoke(
|
||||
{
|
||||
"input": task_prompt,
|
||||
@@ -335,8 +338,7 @@ class Agent(BaseModel):
|
||||
)
|
||||
|
||||
bind = self.llm.bind(stop=stop_words)
|
||||
inner_agent = agent_args | execution_prompt | bind | CrewAgentParser(
|
||||
agent=self)
|
||||
inner_agent = agent_args | execution_prompt | bind | CrewAgentParser(agent=self)
|
||||
self.agent_executor = CrewAgentExecutor(
|
||||
agent=RunnableAgent(runnable=inner_agent), **executor_args
|
||||
)
|
||||
@@ -371,7 +373,7 @@ class Agent(BaseModel):
|
||||
thoughts += action.log
|
||||
thoughts += f"\n{observation_prefix}{observation}\n{llm_prefix}"
|
||||
return thoughts
|
||||
|
||||
|
||||
def copy(self):
|
||||
"""Create a deep copy of the Agent."""
|
||||
exclude = {
|
||||
@@ -379,8 +381,8 @@ class Agent(BaseModel):
|
||||
"_logger",
|
||||
"_rpm_controller",
|
||||
"_request_within_rpm_limit",
|
||||
"_token_process",
|
||||
"agent_executor",
|
||||
"_token_process",
|
||||
"agent_executor",
|
||||
"tools",
|
||||
"tools_handler",
|
||||
"cache_handler",
|
||||
@@ -412,6 +414,30 @@ class Agent(BaseModel):
|
||||
tools_list.append(tool)
|
||||
return tools_list
|
||||
|
||||
def _training_handler(self, task_prompt: str) -> str:
    """Append this agent's recorded human feedback to the task prompt.

    Loads the contents of ``TRAINING_DATA_FILE`` through
    ``CrewTrainingHandler`` and looks up the entry stored under
    ``str(self.id)``. When such an entry exists, the ``"human_feedback"``
    value of each of its items is collected and appended to the prompt.

    Args:
        task_prompt: The prompt text built so far for the task.

    Returns:
        The prompt with the feedback section appended, or the prompt
        unchanged when nothing was loaded or this agent has no entry.
    """
    training_data = CrewTrainingHandler(TRAINING_DATA_FILE).load()
    if not training_data:
        return task_prompt

    # Entries are keyed by the string form of the agent id.
    agent_entries = training_data.get(str(self.id))
    if not agent_entries:
        return task_prompt

    feedback_lines = [entry["human_feedback"] for entry in agent_entries.values()]
    feedback_block = "\n - ".join(feedback_lines)
    task_prompt += "You MUST follow these feedbacks: \n " + feedback_block
    return task_prompt
|
||||
|
||||
def _use_trained_data(self, task_prompt: str) -> str:
    """Append stored suggestions for this agent's role to the task prompt.

    Loads the contents of ``TRAINED_AGENTS_DATA_FILE`` through
    ``CrewTrainingHandler`` and looks up the entry stored under
    ``self.role``. When a non-empty entry is found, its ``"suggestions"``
    items are joined and appended to the prompt.

    Args:
        task_prompt: The prompt text built so far for the task.

    Returns:
        The prompt with the suggestions appended, or the prompt unchanged
        when nothing was loaded or no entry exists for this role.
    """
    trained_agents = CrewTrainingHandler(TRAINED_AGENTS_DATA_FILE).load()
    # Only consult the role key when something was actually loaded.
    role_entry = trained_agents.get(self.role) if trained_agents else None
    if role_entry:
        suggestion_block = "\n - ".join(role_entry["suggestions"])
        task_prompt += "You MUST follow these feedbacks: \n " + suggestion_block
    return task_prompt
|
||||
|
||||
@staticmethod
def __tools_names(tools) -> str:
    """Return the ``name`` of every tool in *tools*, joined by ``", "``."""
    return ", ".join(tool.name for tool in tools)
|
||||
|
||||
Reference in New Issue
Block a user