mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-04 16:52:37 +00:00
fix(agent): honor custom trained-agents file at inference
This commit is contained in:
@@ -8,6 +8,7 @@ import concurrent.futures
|
||||
import contextvars
|
||||
from datetime import datetime
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
import time
|
||||
from typing import (
|
||||
@@ -93,7 +94,11 @@ from crewai.utilities.agent_utils import (
|
||||
parse_tools,
|
||||
render_text_description_and_args,
|
||||
)
|
||||
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
|
||||
from crewai.utilities.constants import (
|
||||
CREWAI_TRAINED_AGENTS_FILE_ENV,
|
||||
TRAINED_AGENTS_DATA_FILE,
|
||||
TRAINING_DATA_FILE,
|
||||
)
|
||||
from crewai.utilities.converter import Converter, ConverterError
|
||||
from crewai.utilities.env import get_env_context
|
||||
from crewai.utilities.guardrail import process_guardrail, serialize_guardrail_for_json
|
||||
@@ -1181,7 +1186,10 @@ class Agent(BaseAgent):
|
||||
|
||||
def _use_trained_data(self, task_prompt: str) -> str:
|
||||
"""Use trained data for the agent task prompt to improve output."""
|
||||
if data := CrewTrainingHandler(TRAINED_AGENTS_DATA_FILE).load():
|
||||
trained_file = os.getenv(
|
||||
CREWAI_TRAINED_AGENTS_FILE_ENV, TRAINED_AGENTS_DATA_FILE
|
||||
)
|
||||
if data := CrewTrainingHandler(trained_file).load():
|
||||
if trained_data_output := data.get(self.role):
|
||||
task_prompt += (
|
||||
"\n\nYou MUST follow these instructions: \n - "
|
||||
|
||||
@@ -351,9 +351,22 @@ def install(context: click.Context) -> None:
|
||||
|
||||
|
||||
@crewai.command()
|
||||
def run() -> None:
|
||||
@click.option(
|
||||
"-f",
|
||||
"--filename",
|
||||
"trained_agents_file",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"Path to a trained-agents pickle (produced by `crewai train -f`). "
|
||||
"When set, agents load suggestions from this file instead of the "
|
||||
"default trained_agents_data.pkl. Equivalent to setting "
|
||||
"CREWAI_TRAINED_AGENTS_FILE."
|
||||
),
|
||||
)
|
||||
def run(trained_agents_file: str | None) -> None:
|
||||
"""Run the Crew."""
|
||||
run_crew()
|
||||
run_crew(trained_agents_file=trained_agents_file)
|
||||
|
||||
|
||||
@crewai.command()
|
||||
|
||||
@@ -5,6 +5,7 @@ import click
|
||||
from packaging import version
|
||||
|
||||
from crewai.cli.utils import build_env_with_all_tool_credentials, read_toml
|
||||
from crewai.utilities.constants import CREWAI_TRAINED_AGENTS_FILE_ENV
|
||||
from crewai.utilities.version import get_crewai_version
|
||||
|
||||
|
||||
@@ -13,13 +14,18 @@ class CrewType(Enum):
|
||||
FLOW = "flow"
|
||||
|
||||
|
||||
def run_crew() -> None:
|
||||
"""
|
||||
Run the crew or flow by running a command in the UV environment.
|
||||
def run_crew(trained_agents_file: str | None = None) -> None:
|
||||
"""Run the crew or flow by running a command in the UV environment.
|
||||
|
||||
Starting from version 0.103.0, this command can be used to run both
|
||||
standard crews and flows. For flows, it detects the type from pyproject.toml
|
||||
and automatically runs the appropriate command.
|
||||
|
||||
Args:
|
||||
trained_agents_file: Optional path to a trained-agents pickle produced
|
||||
by ``crewai train -f``. When set, exported as
|
||||
``CREWAI_TRAINED_AGENTS_FILE`` so agents load suggestions from this
|
||||
file instead of the default ``trained_agents_data.pkl``.
|
||||
"""
|
||||
crewai_version = get_crewai_version()
|
||||
min_required_version = "0.71.0"
|
||||
@@ -43,19 +49,24 @@ def run_crew() -> None:
|
||||
click.echo(f"Running the {'Flow' if is_flow else 'Crew'}")
|
||||
|
||||
# Execute the appropriate command
|
||||
execute_command(crew_type)
|
||||
execute_command(crew_type, trained_agents_file=trained_agents_file)
|
||||
|
||||
|
||||
def execute_command(crew_type: CrewType) -> None:
|
||||
"""
|
||||
Execute the appropriate command based on crew type.
|
||||
def execute_command(
|
||||
crew_type: CrewType, trained_agents_file: str | None = None
|
||||
) -> None:
|
||||
"""Execute the appropriate command based on crew type.
|
||||
|
||||
Args:
|
||||
crew_type: The type of crew to run
|
||||
crew_type: The type of crew to run.
|
||||
trained_agents_file: Optional trained-agents pickle path forwarded to
|
||||
the subprocess via the ``CREWAI_TRAINED_AGENTS_FILE`` env var.
|
||||
"""
|
||||
command = ["uv", "run", "kickoff" if crew_type == CrewType.FLOW else "run_crew"]
|
||||
|
||||
env = build_env_with_all_tool_credentials()
|
||||
if trained_agents_file:
|
||||
env[CREWAI_TRAINED_AGENTS_FILE_ENV] = trained_agents_file
|
||||
|
||||
try:
|
||||
subprocess.run(command, capture_output=False, text=True, check=True, env=env) # noqa: S603
|
||||
|
||||
@@ -7,6 +7,7 @@ from crewai.utilities.printer import PrinterColor
|
||||
|
||||
TRAINING_DATA_FILE: Final[str] = "training_data.pkl"
|
||||
TRAINED_AGENTS_DATA_FILE: Final[str] = "trained_agents_data.pkl"
|
||||
CREWAI_TRAINED_AGENTS_FILE_ENV: Final[str] = "CREWAI_TRAINED_AGENTS_FILE"
|
||||
KNOWLEDGE_DIRECTORY: Final[str] = "knowledge"
|
||||
MAX_FILE_NAME_LENGTH: Final[int] = 255
|
||||
EMITTER_COLOR: Final[PrinterColor] = "bold_blue"
|
||||
|
||||
@@ -1064,6 +1064,23 @@ def test_agent_use_trained_data(crew_training_handler):
|
||||
)
|
||||
|
||||
|
||||
@patch("crewai.agent.core.CrewTrainingHandler")
|
||||
def test_agent_use_trained_data_honors_env_var(crew_training_handler, monkeypatch):
|
||||
monkeypatch.setenv("CREWAI_TRAINED_AGENTS_FILE", "my_custom_trained.pkl")
|
||||
agent = Agent(
|
||||
role="researcher",
|
||||
goal="test goal",
|
||||
backstory="test backstory",
|
||||
)
|
||||
crew_training_handler.return_value.load.return_value = {}
|
||||
|
||||
agent._use_trained_data(task_prompt="What is 1 + 1?")
|
||||
|
||||
crew_training_handler.assert_has_calls(
|
||||
[mock.call("my_custom_trained.pkl"), mock.call().load()]
|
||||
)
|
||||
|
||||
|
||||
def test_agent_max_retry_limit():
|
||||
agent = Agent(
|
||||
role="test role",
|
||||
|
||||
59
lib/crewai/tests/cli/test_run_crew.py
Normal file
59
lib/crewai/tests/cli/test_run_crew.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""Tests for the ``crewai run`` command and its subprocess plumbing."""
|
||||
|
||||
from unittest import mock
|
||||
|
||||
from click.testing import CliRunner
|
||||
import pytest
|
||||
|
||||
from crewai.cli.cli import run
|
||||
from crewai.cli.run_crew import CrewType, execute_command
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner() -> CliRunner:
|
||||
return CliRunner()
|
||||
|
||||
|
||||
@mock.patch("crewai.cli.cli.run_crew")
|
||||
def test_run_passes_filename_to_run_crew(run_crew_mock: mock.Mock, runner: CliRunner) -> None:
|
||||
result = runner.invoke(run, ["-f", "my_custom_trained.pkl"])
|
||||
|
||||
run_crew_mock.assert_called_once_with(trained_agents_file="my_custom_trained.pkl")
|
||||
assert result.exit_code == 0
|
||||
|
||||
|
||||
@mock.patch("crewai.cli.cli.run_crew")
|
||||
def test_run_without_filename_passes_none(run_crew_mock: mock.Mock, runner: CliRunner) -> None:
|
||||
result = runner.invoke(run)
|
||||
|
||||
run_crew_mock.assert_called_once_with(trained_agents_file=None)
|
||||
assert result.exit_code == 0
|
||||
|
||||
|
||||
@mock.patch("crewai.cli.run_crew.subprocess.run")
|
||||
@mock.patch(
|
||||
"crewai.cli.run_crew.build_env_with_all_tool_credentials",
|
||||
return_value={"EXISTING": "value"},
|
||||
)
|
||||
def test_execute_command_sets_env_var_when_filename_provided(
|
||||
_build_env: mock.Mock, subprocess_run: mock.Mock
|
||||
) -> None:
|
||||
execute_command(CrewType.STANDARD, trained_agents_file="my_custom_trained.pkl")
|
||||
|
||||
_, kwargs = subprocess_run.call_args
|
||||
assert kwargs["env"]["CREWAI_TRAINED_AGENTS_FILE"] == "my_custom_trained.pkl"
|
||||
assert kwargs["env"]["EXISTING"] == "value"
|
||||
|
||||
|
||||
@mock.patch("crewai.cli.run_crew.subprocess.run")
|
||||
@mock.patch(
|
||||
"crewai.cli.run_crew.build_env_with_all_tool_credentials",
|
||||
return_value={"EXISTING": "value"},
|
||||
)
|
||||
def test_execute_command_omits_env_var_when_filename_absent(
|
||||
_build_env: mock.Mock, subprocess_run: mock.Mock
|
||||
) -> None:
|
||||
execute_command(CrewType.STANDARD)
|
||||
|
||||
_, kwargs = subprocess_run.call_args
|
||||
assert "CREWAI_TRAINED_AGENTS_FILE" not in kwargs["env"]
|
||||
Reference in New Issue
Block a user