mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-02 05:38:12 +00:00
Add crew trained agents file support (#6012)
* Add crew trained agents file support * Add crew trained agents file support
This commit is contained in:
@@ -187,7 +187,7 @@ flowchart TD
|
||||
- **Filename Requirement:** Ensure that the filename ends with `.pkl`. The code will raise a `ValueError` if this condition is not met.
|
||||
- **Error Handling:** The code handles subprocess errors and unexpected exceptions, providing error messages to the user.
|
||||
- Trained guidance is applied at prompt time; it does not modify your Python/YAML agent configuration.
|
||||
- Agents automatically load trained suggestions from a file named `trained_agents_data.pkl` located in the current working directory. If you trained to a different filename, either rename it to `trained_agents_data.pkl` before running, or adjust the loader in code.
|
||||
- Agents automatically load trained suggestions from a file named `trained_agents_data.pkl` located in the current working directory. If you trained to a different filename, pass that path with `Crew(trained_agents_file="my_custom_trained.pkl")`, set `CREWAI_TRAINED_AGENTS_FILE`, or use `crewai run -f my_custom_trained.pkl`.
|
||||
- You can change the output filename when calling `crewai train` with `-f/--filename`. Absolute paths are supported if you want to save outside the CWD.
|
||||
|
||||
It is important to note that the training process may take some time, depending on the complexity of your agents and will also require your feedback on each iteration.
|
||||
|
||||
@@ -1219,9 +1219,17 @@ class Agent(BaseAgent):
|
||||
|
||||
def _use_trained_data(self, task_prompt: str) -> str:
|
||||
"""Use trained data for the agent task prompt to improve output."""
|
||||
trained_file = os.getenv(
|
||||
CREWAI_TRAINED_AGENTS_FILE_ENV, TRAINED_AGENTS_DATA_FILE
|
||||
crew_trained_agents_file = (
|
||||
getattr(self.crew, "trained_agents_file", None)
|
||||
if self.crew and not isinstance(self.crew, str)
|
||||
else None
|
||||
)
|
||||
trained_file = (
|
||||
os.fspath(crew_trained_agents_file)
|
||||
if crew_trained_agents_file
|
||||
else os.getenv(CREWAI_TRAINED_AGENTS_FILE_ENV, TRAINED_AGENTS_DATA_FILE)
|
||||
)
|
||||
|
||||
if data := CrewTrainingHandler(trained_file).load():
|
||||
if trained_data_output := data.get(self.role):
|
||||
task_prompt += (
|
||||
|
||||
@@ -179,6 +179,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
max_rpm: Maximum number of requests per minute for the crew execution to
|
||||
be respected.
|
||||
prompt_file: Path to the prompt json file to be used for the crew.
|
||||
trained_agents_file: Path to trained agent suggestions loaded during inference.
|
||||
id: A unique identifier for the crew instance.
|
||||
task_callback: Callback to be executed after each task for every agents
|
||||
execution.
|
||||
@@ -303,6 +304,13 @@ class Crew(FlowTrackable, BaseModel):
|
||||
default=None,
|
||||
description="Path to the prompt json file to be used for the crew.",
|
||||
)
|
||||
trained_agents_file: str | Path | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Path to a trained-agents pickle produced by train(). "
|
||||
"When set, agents load suggestions from this file during inference."
|
||||
),
|
||||
)
|
||||
output_log_file: bool | str | None = Field(
|
||||
default=None,
|
||||
description="Path to the log file to be saved",
|
||||
|
||||
@@ -1067,6 +1067,62 @@ def test_agent_use_trained_data_honors_env_var(crew_training_handler, monkeypatc
|
||||
)
|
||||
|
||||
|
||||
@patch("crewai.agent.core.CrewTrainingHandler")
|
||||
def test_agent_use_trained_data_prefers_crew_trained_agents_file(
|
||||
crew_training_handler, monkeypatch
|
||||
):
|
||||
monkeypatch.setenv("CREWAI_TRAINED_AGENTS_FILE", "env_trained.pkl")
|
||||
agent = Agent(
|
||||
role="researcher",
|
||||
goal="test goal",
|
||||
backstory="test backstory",
|
||||
)
|
||||
task = Task(
|
||||
description="Research the topic",
|
||||
expected_output="A short report",
|
||||
agent=agent,
|
||||
)
|
||||
crew = Crew(agents=[agent], tasks=[task], trained_agents_file="crew_trained.pkl")
|
||||
agent.crew = crew
|
||||
crew_training_handler.return_value.load.return_value = {}
|
||||
|
||||
agent._use_trained_data(task_prompt="What is 1 + 1?")
|
||||
|
||||
crew_training_handler.assert_has_calls(
|
||||
[mock.call("crew_trained.pkl"), mock.call().load()]
|
||||
)
|
||||
|
||||
|
||||
@patch("crewai.agent.core.CrewTrainingHandler")
|
||||
def test_agent_use_trained_data_accepts_crew_trained_agents_file_path(
|
||||
crew_training_handler, tmp_path
|
||||
):
|
||||
agent = Agent(
|
||||
role="researcher",
|
||||
goal="test goal",
|
||||
backstory="test backstory",
|
||||
)
|
||||
task = Task(
|
||||
description="Research the topic",
|
||||
expected_output="A short report",
|
||||
agent=agent,
|
||||
)
|
||||
trained_agents_file = tmp_path / "crew_trained.pkl"
|
||||
crew = Crew(
|
||||
agents=[agent],
|
||||
tasks=[task],
|
||||
trained_agents_file=trained_agents_file,
|
||||
)
|
||||
agent.crew = crew
|
||||
crew_training_handler.return_value.load.return_value = {}
|
||||
|
||||
agent._use_trained_data(task_prompt="What is 1 + 1?")
|
||||
|
||||
crew_training_handler.assert_has_calls(
|
||||
[mock.call(str(trained_agents_file)), mock.call().load()]
|
||||
)
|
||||
|
||||
|
||||
def test_agent_use_trained_data_skips_load_when_file_missing(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv(
|
||||
"CREWAI_TRAINED_AGENTS_FILE", str(tmp_path / "does_not_exist.pkl")
|
||||
|
||||
@@ -3010,6 +3010,23 @@ def test__setup_for_training(researcher, writer):
|
||||
assert agent.allow_delegation is False
|
||||
|
||||
|
||||
def test_crew_trained_agents_file_is_preserved_on_copy(researcher):
|
||||
task = Task(
|
||||
description="Come up with a list of 5 interesting ideas to explore for an article",
|
||||
expected_output="5 bullet points with a paragraph for each idea.",
|
||||
agent=researcher,
|
||||
)
|
||||
crew = Crew(
|
||||
agents=[researcher],
|
||||
tasks=[task],
|
||||
trained_agents_file="custom_trained_agents.pkl",
|
||||
)
|
||||
|
||||
cloned_crew = crew.copy()
|
||||
|
||||
assert cloned_crew.trained_agents_file == "custom_trained_agents.pkl"
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_replay_feature(researcher, writer):
|
||||
list_ideas = Task(
|
||||
|
||||
Reference in New Issue
Block a user