diff --git a/docs/en/concepts/training.mdx b/docs/en/concepts/training.mdx index 6eff75772..468e1acef 100644 --- a/docs/en/concepts/training.mdx +++ b/docs/en/concepts/training.mdx @@ -187,7 +187,7 @@ flowchart TD - **Filename Requirement:** Ensure that the filename ends with `.pkl`. The code will raise a `ValueError` if this condition is not met. - **Error Handling:** The code handles subprocess errors and unexpected exceptions, providing error messages to the user. - Trained guidance is applied at prompt time; it does not modify your Python/YAML agent configuration. -- Agents automatically load trained suggestions from a file named `trained_agents_data.pkl` located in the current working directory. If you trained to a different filename, either rename it to `trained_agents_data.pkl` before running, or adjust the loader in code. +- Agents automatically load trained suggestions from a file named `trained_agents_data.pkl` located in the current working directory. If you trained to a different filename, pass that path with `Crew(trained_agents_file="my_custom_trained.pkl")`, set `CREWAI_TRAINED_AGENTS_FILE`, or use `crewai run -f my_custom_trained.pkl`. - You can change the output filename when calling `crewai train` with `-f/--filename`. Absolute paths are supported if you want to save outside the CWD. It is important to note that the training process may take some time, depending on the complexity of your agents and will also require your feedback on each iteration. diff --git a/lib/crewai/src/crewai/agent/core.py b/lib/crewai/src/crewai/agent/core.py index 73020f115..2686d66ff 100644 --- a/lib/crewai/src/crewai/agent/core.py +++ b/lib/crewai/src/crewai/agent/core.py @@ -1219,9 +1219,17 @@ class Agent(BaseAgent): def _use_trained_data(self, task_prompt: str) -> str: """Use trained data for the agent task prompt to improve output.""" - trained_file = os.getenv( - CREWAI_TRAINED_AGENTS_FILE_ENV, TRAINED_AGENTS_DATA_FILE + crew_trained_agents_file = ( + getattr(self.crew, "trained_agents_file", None) + if self.crew and not isinstance(self.crew, str) + else None ) + trained_file = ( + os.fspath(crew_trained_agents_file) + if crew_trained_agents_file + else os.getenv(CREWAI_TRAINED_AGENTS_FILE_ENV, TRAINED_AGENTS_DATA_FILE) + ) + if data := CrewTrainingHandler(trained_file).load(): if trained_data_output := data.get(self.role): task_prompt += ( diff --git a/lib/crewai/src/crewai/crew.py b/lib/crewai/src/crewai/crew.py index 1221c10f6..b2cebd3ed 100644 --- a/lib/crewai/src/crewai/crew.py +++ b/lib/crewai/src/crewai/crew.py @@ -179,6 +179,7 @@ class Crew(FlowTrackable, BaseModel): max_rpm: Maximum number of requests per minute for the crew execution to be respected. prompt_file: Path to the prompt json file to be used for the crew. + trained_agents_file: Path to trained agent suggestions loaded during inference. id: A unique identifier for the crew instance. task_callback: Callback to be executed after each task for every agents execution. @@ -303,6 +304,13 @@ class Crew(FlowTrackable, BaseModel): default=None, description="Path to the prompt json file to be used for the crew.", ) + trained_agents_file: str | Path | None = Field( + default=None, + description=( + "Path to a trained-agents pickle produced by train(). " + "When set, agents load suggestions from this file during inference." + ), + ) output_log_file: bool | str | None = Field( default=None, description="Path to the log file to be saved", diff --git a/lib/crewai/tests/agents/test_agent.py b/lib/crewai/tests/agents/test_agent.py index 25c8b4040..89c1689cf 100644 --- a/lib/crewai/tests/agents/test_agent.py +++ b/lib/crewai/tests/agents/test_agent.py @@ -1067,6 +1067,62 @@ def test_agent_use_trained_data_honors_env_var(crew_training_handler, monkeypatc ) +@patch("crewai.agent.core.CrewTrainingHandler") +def test_agent_use_trained_data_prefers_crew_trained_agents_file( + crew_training_handler, monkeypatch +): + monkeypatch.setenv("CREWAI_TRAINED_AGENTS_FILE", "env_trained.pkl") + agent = Agent( + role="researcher", + goal="test goal", + backstory="test backstory", + ) + task = Task( + description="Research the topic", + expected_output="A short report", + agent=agent, + ) + crew = Crew(agents=[agent], tasks=[task], trained_agents_file="crew_trained.pkl") + agent.crew = crew + crew_training_handler.return_value.load.return_value = {} + + agent._use_trained_data(task_prompt="What is 1 + 1?") + + crew_training_handler.assert_has_calls( + [mock.call("crew_trained.pkl"), mock.call().load()] + ) + + +@patch("crewai.agent.core.CrewTrainingHandler") +def test_agent_use_trained_data_accepts_crew_trained_agents_file_path( + crew_training_handler, tmp_path +): + agent = Agent( + role="researcher", + goal="test goal", + backstory="test backstory", + ) + task = Task( + description="Research the topic", + expected_output="A short report", + agent=agent, + ) + trained_agents_file = tmp_path / "crew_trained.pkl" + crew = Crew( + agents=[agent], + tasks=[task], + trained_agents_file=trained_agents_file, + ) + agent.crew = crew + crew_training_handler.return_value.load.return_value = {} + + agent._use_trained_data(task_prompt="What is 1 + 1?") + + crew_training_handler.assert_has_calls( + [mock.call(str(trained_agents_file)), mock.call().load()] + ) + + def test_agent_use_trained_data_skips_load_when_file_missing(tmp_path, monkeypatch): monkeypatch.setenv( "CREWAI_TRAINED_AGENTS_FILE", str(tmp_path / "does_not_exist.pkl") diff --git a/lib/crewai/tests/test_crew.py b/lib/crewai/tests/test_crew.py index 2a09733dc..8ce25774e 100644 --- a/lib/crewai/tests/test_crew.py +++ b/lib/crewai/tests/test_crew.py @@ -3010,6 +3010,23 @@ def test__setup_for_training(researcher, writer): assert agent.allow_delegation is False +def test_crew_trained_agents_file_is_preserved_on_copy(researcher): + task = Task( + description="Come up with a list of 5 interesting ideas to explore for an article", + expected_output="5 bullet points with a paragraph for each idea.", + agent=researcher, + ) + crew = Crew( + agents=[researcher], + tasks=[task], + trained_agents_file="custom_trained_agents.pkl", + ) + + cloned_crew = crew.copy() + + assert cloned_crew.trained_agents_file == "custom_trained_agents.pkl" + + @pytest.mark.vcr() def test_replay_feature(researcher, writer): list_ideas = Task(