fix(agent): honor custom trained-agents file at inference

This commit is contained in:
Greyson LaLonde
2026-04-28 22:09:34 +08:00
committed by GitHub
parent a29977f4f6
commit 4e9331a2c8
6 changed files with 121 additions and 12 deletions

View File

@@ -8,6 +8,7 @@ import concurrent.futures
import contextvars
from datetime import datetime
import json
import os
from pathlib import Path
import time
from typing import (
@@ -93,7 +94,11 @@ from crewai.utilities.agent_utils import (
parse_tools,
render_text_description_and_args,
)
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
from crewai.utilities.constants import (
CREWAI_TRAINED_AGENTS_FILE_ENV,
TRAINED_AGENTS_DATA_FILE,
TRAINING_DATA_FILE,
)
from crewai.utilities.converter import Converter, ConverterError
from crewai.utilities.env import get_env_context
from crewai.utilities.guardrail import process_guardrail, serialize_guardrail_for_json
@@ -1181,7 +1186,10 @@ class Agent(BaseAgent):
def _use_trained_data(self, task_prompt: str) -> str:
"""Use trained data for the agent task prompt to improve output."""
if data := CrewTrainingHandler(TRAINED_AGENTS_DATA_FILE).load():
trained_file = os.getenv(
CREWAI_TRAINED_AGENTS_FILE_ENV, TRAINED_AGENTS_DATA_FILE
)
if data := CrewTrainingHandler(trained_file).load():
if trained_data_output := data.get(self.role):
task_prompt += (
"\n\nYou MUST follow these instructions: \n - "

View File

@@ -351,9 +351,22 @@ def install(context: click.Context) -> None:
@crewai.command()
def run() -> None:
@click.option(
"-f",
"--filename",
"trained_agents_file",
type=str,
default=None,
help=(
"Path to a trained-agents pickle (produced by `crewai train -f`). "
"When set, agents load suggestions from this file instead of the "
"default trained_agents_data.pkl. Equivalent to setting "
"CREWAI_TRAINED_AGENTS_FILE."
),
)
def run(trained_agents_file: str | None) -> None:
"""Run the Crew."""
run_crew()
run_crew(trained_agents_file=trained_agents_file)
@crewai.command()

View File

@@ -5,6 +5,7 @@ import click
from packaging import version
from crewai.cli.utils import build_env_with_all_tool_credentials, read_toml
from crewai.utilities.constants import CREWAI_TRAINED_AGENTS_FILE_ENV
from crewai.utilities.version import get_crewai_version
@@ -13,13 +14,18 @@ class CrewType(Enum):
FLOW = "flow"
def run_crew() -> None:
"""
Run the crew or flow by running a command in the UV environment.
def run_crew(trained_agents_file: str | None = None) -> None:
"""Run the crew or flow by running a command in the UV environment.
Starting from version 0.103.0, this command can be used to run both
standard crews and flows. For flows, it detects the type from pyproject.toml
and automatically runs the appropriate command.
Args:
trained_agents_file: Optional path to a trained-agents pickle produced
by ``crewai train -f``. When set, exported as
``CREWAI_TRAINED_AGENTS_FILE`` so agents load suggestions from this
file instead of the default ``trained_agents_data.pkl``.
"""
crewai_version = get_crewai_version()
min_required_version = "0.71.0"
@@ -43,19 +49,24 @@ def run_crew() -> None:
click.echo(f"Running the {'Flow' if is_flow else 'Crew'}")
# Execute the appropriate command
execute_command(crew_type)
execute_command(crew_type, trained_agents_file=trained_agents_file)
def execute_command(crew_type: CrewType) -> None:
"""
Execute the appropriate command based on crew type.
def execute_command(
crew_type: CrewType, trained_agents_file: str | None = None
) -> None:
"""Execute the appropriate command based on crew type.
Args:
crew_type: The type of crew to run
crew_type: The type of crew to run.
trained_agents_file: Optional trained-agents pickle path forwarded to
the subprocess via the ``CREWAI_TRAINED_AGENTS_FILE`` env var.
"""
command = ["uv", "run", "kickoff" if crew_type == CrewType.FLOW else "run_crew"]
env = build_env_with_all_tool_credentials()
if trained_agents_file:
env[CREWAI_TRAINED_AGENTS_FILE_ENV] = trained_agents_file
try:
subprocess.run(command, capture_output=False, text=True, check=True, env=env) # noqa: S603

View File

@@ -7,6 +7,7 @@ from crewai.utilities.printer import PrinterColor
TRAINING_DATA_FILE: Final[str] = "training_data.pkl"
TRAINED_AGENTS_DATA_FILE: Final[str] = "trained_agents_data.pkl"
CREWAI_TRAINED_AGENTS_FILE_ENV: Final[str] = "CREWAI_TRAINED_AGENTS_FILE"
KNOWLEDGE_DIRECTORY: Final[str] = "knowledge"
MAX_FILE_NAME_LENGTH: Final[int] = 255
EMITTER_COLOR: Final[PrinterColor] = "bold_blue"

View File

@@ -1064,6 +1064,23 @@ def test_agent_use_trained_data(crew_training_handler):
)
@patch("crewai.agent.core.CrewTrainingHandler")
def test_agent_use_trained_data_honors_env_var(crew_training_handler, monkeypatch):
monkeypatch.setenv("CREWAI_TRAINED_AGENTS_FILE", "my_custom_trained.pkl")
agent = Agent(
role="researcher",
goal="test goal",
backstory="test backstory",
)
crew_training_handler.return_value.load.return_value = {}
agent._use_trained_data(task_prompt="What is 1 + 1?")
crew_training_handler.assert_has_calls(
[mock.call("my_custom_trained.pkl"), mock.call().load()]
)
def test_agent_max_retry_limit():
agent = Agent(
role="test role",

View File

@@ -0,0 +1,59 @@
"""Tests for the ``crewai run`` command and its subprocess plumbing."""
from unittest import mock
from click.testing import CliRunner
import pytest
from crewai.cli.cli import run
from crewai.cli.run_crew import CrewType, execute_command
@pytest.fixture
def runner() -> CliRunner:
return CliRunner()
@mock.patch("crewai.cli.cli.run_crew")
def test_run_passes_filename_to_run_crew(run_crew_mock: mock.Mock, runner: CliRunner) -> None:
result = runner.invoke(run, ["-f", "my_custom_trained.pkl"])
run_crew_mock.assert_called_once_with(trained_agents_file="my_custom_trained.pkl")
assert result.exit_code == 0
@mock.patch("crewai.cli.cli.run_crew")
def test_run_without_filename_passes_none(run_crew_mock: mock.Mock, runner: CliRunner) -> None:
result = runner.invoke(run)
run_crew_mock.assert_called_once_with(trained_agents_file=None)
assert result.exit_code == 0
@mock.patch("crewai.cli.run_crew.subprocess.run")
@mock.patch(
"crewai.cli.run_crew.build_env_with_all_tool_credentials",
return_value={"EXISTING": "value"},
)
def test_execute_command_sets_env_var_when_filename_provided(
_build_env: mock.Mock, subprocess_run: mock.Mock
) -> None:
execute_command(CrewType.STANDARD, trained_agents_file="my_custom_trained.pkl")
_, kwargs = subprocess_run.call_args
assert kwargs["env"]["CREWAI_TRAINED_AGENTS_FILE"] == "my_custom_trained.pkl"
assert kwargs["env"]["EXISTING"] == "value"
@mock.patch("crewai.cli.run_crew.subprocess.run")
@mock.patch(
"crewai.cli.run_crew.build_env_with_all_tool_credentials",
return_value={"EXISTING": "value"},
)
def test_execute_command_omits_env_var_when_filename_absent(
_build_env: mock.Mock, subprocess_run: mock.Mock
) -> None:
execute_command(CrewType.STANDARD)
_, kwargs = subprocess_run.call_args
assert "CREWAI_TRAINED_AGENTS_FILE" not in kwargs["env"]