mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 09:08:31 +00:00
Fix issue #2724: Allow specifying trained data file for run command
Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -1196,6 +1196,37 @@ def test_agent_use_trained_data(crew_training_handler):
|
||||
)
|
||||
|
||||
|
||||
@patch("crewai.agent.CrewTrainingHandler")
|
||||
def test_agent_use_custom_trained_data_file(crew_training_handler):
|
||||
task_prompt = "What is 1 + 1?"
|
||||
custom_file = "custom_trained_data.pkl"
|
||||
agent = Agent(
|
||||
role="researcher",
|
||||
goal="test goal",
|
||||
backstory="test backstory",
|
||||
verbose=True,
|
||||
trained_data_file=custom_file
|
||||
)
|
||||
crew_training_handler().load.return_value = {
|
||||
agent.role: {
|
||||
"suggestions": [
|
||||
"The result of the math operation must be right.",
|
||||
"Result must be better than 1.",
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
result = agent._use_trained_data(task_prompt=task_prompt)
|
||||
|
||||
assert (
|
||||
result == "What is 1 + 1?\n\nYou MUST follow these instructions: \n"
|
||||
" - The result of the math operation must be right.\n - Result must be better than 1."
|
||||
)
|
||||
crew_training_handler.assert_has_calls(
|
||||
[mock.call(), mock.call(custom_file), mock.call().load()]
|
||||
)
|
||||
|
||||
|
||||
def test_agent_max_retry_limit():
|
||||
agent = Agent(
|
||||
role="test role",
|
||||
|
||||
Reference in New Issue
Block a user