Fix issue #2724: Allow specifying trained data file for run command

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2025-04-29 21:20:58 +00:00
parent 25c8155609
commit 2f66aa0efc
6 changed files with 68 additions and 8 deletions

View File

@@ -118,6 +118,10 @@ class Agent(BaseAgent):
default=None,
description="Knowledge context for the agent.",
)
trained_data_file: str = Field(
default=TRAINED_AGENTS_DATA_FILE,
description="Path to the trained data file to use for task prompts.",
)
crew_knowledge_context: Optional[str] = Field(
default=None,
description="Knowledge context for the crew.",
@@ -498,7 +502,7 @@ class Agent(BaseAgent):
def _use_trained_data(self, task_prompt: str) -> str:
"""Use trained data for the agent task prompt to improve output."""
if data := CrewTrainingHandler(TRAINED_AGENTS_DATA_FILE).load():
if data := CrewTrainingHandler(self.trained_data_file).load():
if trained_data_output := data.get(self.role):
task_prompt += (
"\n\nYou MUST follow these instructions: \n - "

View File

@@ -201,9 +201,17 @@ def install(context):
@crewai.command()
def run():
@click.option(
"-f",
"--filename",
type=str,
default="trained_agents_data.pkl",
help="Path to a trained data file to use",
)
def run(filename: str):
"""Run the Crew."""
run_crew()
click.echo(f"Running the Crew with trained data from {filename}")
run_crew(trained_data_file=filename)
@crewai.command()

View File

@@ -14,13 +14,16 @@ class CrewType(Enum):
FLOW = "flow"
def run_crew() -> None:
def run_crew(trained_data_file: Optional[str] = None) -> None:
"""
Run the crew or flow by running a command in the UV environment.
Starting from version 0.103.0, this command can be used to run both
standard crews and flows. For flows, it detects the type from pyproject.toml
and automatically runs the appropriate command.
Args:
trained_data_file: Optional path to a trained data file to use
"""
crewai_version = get_crewai_version()
min_required_version = "0.71.0"
@@ -44,17 +47,21 @@ def run_crew() -> None:
click.echo(f"Running the {'Flow' if is_flow else 'Crew'}")
# Execute the appropriate command
execute_command(crew_type)
execute_command(crew_type, trained_data_file)
def execute_command(crew_type: CrewType) -> None:
def execute_command(crew_type: CrewType, trained_data_file: Optional[str] = None) -> None:
"""
Execute the appropriate command based on crew type.
Args:
crew_type: The type of crew to run
trained_data_file: Optional path to a trained data file to use
"""
command = ["uv", "run", "kickoff" if crew_type == CrewType.FLOW else "run_crew"]
if trained_data_file and crew_type == CrewType.STANDARD:
command.extend(["--trained-data-file", trained_data_file])
try:
subprocess.run(command, capture_output=False, text=True, check=True)

View File

@@ -23,7 +23,8 @@ def run():
}
try:
{{crew_name}}().crew().kickoff(inputs=inputs)
filename = sys.argv[1] if len(sys.argv) > 1 else "trained_agents_data.pkl"
{{crew_name}}().crew(trained_data_file=filename).kickoff(inputs=inputs)
except Exception as e:
raise Exception(f"An error occurred while running the crew: {e}")

View File

@@ -122,6 +122,10 @@ class Crew(BaseModel):
tasks: List[Task] = Field(default_factory=list)
agents: List[BaseAgent] = Field(default_factory=list)
process: Process = Field(default=Process.sequential)
trained_data_file: Optional[str] = Field(
default=None,
description="Path to the trained data file to use for agent task prompts.",
)
verbose: bool = Field(default=False)
memory: bool = Field(
default=False,
@@ -1196,7 +1200,12 @@ class Crew(BaseModel):
"manager_llm",
}
cloned_agents = [agent.copy() for agent in self.agents]
cloned_agents = []
for agent in self.agents:
cloned_agent = agent.copy()
if self.trained_data_file:
cloned_agent.trained_data_file = self.trained_data_file
cloned_agents.append(cloned_agent)
manager_agent = self.manager_agent.copy() if self.manager_agent else None
manager_llm = shallow_copy(self.manager_llm) if self.manager_llm else None

View File

@@ -1196,6 +1196,37 @@ def test_agent_use_trained_data(crew_training_handler):
)
@patch("crewai.agent.CrewTrainingHandler")
def test_agent_use_custom_trained_data_file(crew_training_handler):
task_prompt = "What is 1 + 1?"
custom_file = "custom_trained_data.pkl"
agent = Agent(
role="researcher",
goal="test goal",
backstory="test backstory",
verbose=True,
trained_data_file=custom_file
)
crew_training_handler().load.return_value = {
agent.role: {
"suggestions": [
"The result of the math operation must be right.",
"Result must be better than 1.",
]
}
}
result = agent._use_trained_data(task_prompt=task_prompt)
assert (
result == "What is 1 + 1?\n\nYou MUST follow these instructions: \n"
" - The result of the math operation must be right.\n - Result must be better than 1."
)
crew_training_handler.assert_has_calls(
[mock.call(), mock.call(custom_file), mock.call().load()]
)
def test_agent_max_retry_limit():
agent = Agent(
role="test role",