mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
Fix issue #2724: Allow specifying trained data file for run command
Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -118,6 +118,10 @@ class Agent(BaseAgent):
|
|||||||
default=None,
|
default=None,
|
||||||
description="Knowledge context for the agent.",
|
description="Knowledge context for the agent.",
|
||||||
)
|
)
|
||||||
|
trained_data_file: str = Field(
|
||||||
|
default=TRAINED_AGENTS_DATA_FILE,
|
||||||
|
description="Path to the trained data file to use for task prompts.",
|
||||||
|
)
|
||||||
crew_knowledge_context: Optional[str] = Field(
|
crew_knowledge_context: Optional[str] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Knowledge context for the crew.",
|
description="Knowledge context for the crew.",
|
||||||
@@ -498,7 +502,7 @@ class Agent(BaseAgent):
|
|||||||
|
|
||||||
def _use_trained_data(self, task_prompt: str) -> str:
|
def _use_trained_data(self, task_prompt: str) -> str:
|
||||||
"""Use trained data for the agent task prompt to improve output."""
|
"""Use trained data for the agent task prompt to improve output."""
|
||||||
if data := CrewTrainingHandler(TRAINED_AGENTS_DATA_FILE).load():
|
if data := CrewTrainingHandler(self.trained_data_file).load():
|
||||||
if trained_data_output := data.get(self.role):
|
if trained_data_output := data.get(self.role):
|
||||||
task_prompt += (
|
task_prompt += (
|
||||||
"\n\nYou MUST follow these instructions: \n - "
|
"\n\nYou MUST follow these instructions: \n - "
|
||||||
|
|||||||
@@ -201,9 +201,17 @@ def install(context):
|
|||||||
|
|
||||||
|
|
||||||
@crewai.command()
|
@crewai.command()
|
||||||
def run():
|
@click.option(
|
||||||
|
"-f",
|
||||||
|
"--filename",
|
||||||
|
type=str,
|
||||||
|
default="trained_agents_data.pkl",
|
||||||
|
help="Path to a trained data file to use",
|
||||||
|
)
|
||||||
|
def run(filename: str):
|
||||||
"""Run the Crew."""
|
"""Run the Crew."""
|
||||||
run_crew()
|
click.echo(f"Running the Crew with trained data from {filename}")
|
||||||
|
run_crew(trained_data_file=filename)
|
||||||
|
|
||||||
|
|
||||||
@crewai.command()
|
@crewai.command()
|
||||||
|
|||||||
@@ -14,13 +14,16 @@ class CrewType(Enum):
|
|||||||
FLOW = "flow"
|
FLOW = "flow"
|
||||||
|
|
||||||
|
|
||||||
def run_crew() -> None:
|
def run_crew(trained_data_file: Optional[str] = None) -> None:
|
||||||
"""
|
"""
|
||||||
Run the crew or flow by running a command in the UV environment.
|
Run the crew or flow by running a command in the UV environment.
|
||||||
|
|
||||||
Starting from version 0.103.0, this command can be used to run both
|
Starting from version 0.103.0, this command can be used to run both
|
||||||
standard crews and flows. For flows, it detects the type from pyproject.toml
|
standard crews and flows. For flows, it detects the type from pyproject.toml
|
||||||
and automatically runs the appropriate command.
|
and automatically runs the appropriate command.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
trained_data_file: Optional path to a trained data file to use
|
||||||
"""
|
"""
|
||||||
crewai_version = get_crewai_version()
|
crewai_version = get_crewai_version()
|
||||||
min_required_version = "0.71.0"
|
min_required_version = "0.71.0"
|
||||||
@@ -44,17 +47,21 @@ def run_crew() -> None:
|
|||||||
click.echo(f"Running the {'Flow' if is_flow else 'Crew'}")
|
click.echo(f"Running the {'Flow' if is_flow else 'Crew'}")
|
||||||
|
|
||||||
# Execute the appropriate command
|
# Execute the appropriate command
|
||||||
execute_command(crew_type)
|
execute_command(crew_type, trained_data_file)
|
||||||
|
|
||||||
|
|
||||||
def execute_command(crew_type: CrewType) -> None:
|
def execute_command(crew_type: CrewType, trained_data_file: Optional[str] = None) -> None:
|
||||||
"""
|
"""
|
||||||
Execute the appropriate command based on crew type.
|
Execute the appropriate command based on crew type.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
crew_type: The type of crew to run
|
crew_type: The type of crew to run
|
||||||
|
trained_data_file: Optional path to a trained data file to use
|
||||||
"""
|
"""
|
||||||
command = ["uv", "run", "kickoff" if crew_type == CrewType.FLOW else "run_crew"]
|
command = ["uv", "run", "kickoff" if crew_type == CrewType.FLOW else "run_crew"]
|
||||||
|
|
||||||
|
if trained_data_file and crew_type == CrewType.STANDARD:
|
||||||
|
command.extend(["--trained-data-file", trained_data_file])
|
||||||
|
|
||||||
try:
|
try:
|
||||||
subprocess.run(command, capture_output=False, text=True, check=True)
|
subprocess.run(command, capture_output=False, text=True, check=True)
|
||||||
|
|||||||
@@ -23,7 +23,8 @@ def run():
|
|||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
{{crew_name}}().crew().kickoff(inputs=inputs)
|
filename = sys.argv[1] if len(sys.argv) > 1 else "trained_agents_data.pkl"
|
||||||
|
{{crew_name}}().crew(trained_data_file=filename).kickoff(inputs=inputs)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(f"An error occurred while running the crew: {e}")
|
raise Exception(f"An error occurred while running the crew: {e}")
|
||||||
|
|
||||||
|
|||||||
@@ -122,6 +122,10 @@ class Crew(BaseModel):
|
|||||||
tasks: List[Task] = Field(default_factory=list)
|
tasks: List[Task] = Field(default_factory=list)
|
||||||
agents: List[BaseAgent] = Field(default_factory=list)
|
agents: List[BaseAgent] = Field(default_factory=list)
|
||||||
process: Process = Field(default=Process.sequential)
|
process: Process = Field(default=Process.sequential)
|
||||||
|
trained_data_file: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Path to the trained data file to use for agent task prompts.",
|
||||||
|
)
|
||||||
verbose: bool = Field(default=False)
|
verbose: bool = Field(default=False)
|
||||||
memory: bool = Field(
|
memory: bool = Field(
|
||||||
default=False,
|
default=False,
|
||||||
@@ -1196,7 +1200,12 @@ class Crew(BaseModel):
|
|||||||
"manager_llm",
|
"manager_llm",
|
||||||
}
|
}
|
||||||
|
|
||||||
cloned_agents = [agent.copy() for agent in self.agents]
|
cloned_agents = []
|
||||||
|
for agent in self.agents:
|
||||||
|
cloned_agent = agent.copy()
|
||||||
|
if self.trained_data_file:
|
||||||
|
cloned_agent.trained_data_file = self.trained_data_file
|
||||||
|
cloned_agents.append(cloned_agent)
|
||||||
manager_agent = self.manager_agent.copy() if self.manager_agent else None
|
manager_agent = self.manager_agent.copy() if self.manager_agent else None
|
||||||
manager_llm = shallow_copy(self.manager_llm) if self.manager_llm else None
|
manager_llm = shallow_copy(self.manager_llm) if self.manager_llm else None
|
||||||
|
|
||||||
|
|||||||
@@ -1196,6 +1196,37 @@ def test_agent_use_trained_data(crew_training_handler):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("crewai.agent.CrewTrainingHandler")
|
||||||
|
def test_agent_use_custom_trained_data_file(crew_training_handler):
|
||||||
|
task_prompt = "What is 1 + 1?"
|
||||||
|
custom_file = "custom_trained_data.pkl"
|
||||||
|
agent = Agent(
|
||||||
|
role="researcher",
|
||||||
|
goal="test goal",
|
||||||
|
backstory="test backstory",
|
||||||
|
verbose=True,
|
||||||
|
trained_data_file=custom_file
|
||||||
|
)
|
||||||
|
crew_training_handler().load.return_value = {
|
||||||
|
agent.role: {
|
||||||
|
"suggestions": [
|
||||||
|
"The result of the math operation must be right.",
|
||||||
|
"Result must be better than 1.",
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result = agent._use_trained_data(task_prompt=task_prompt)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
result == "What is 1 + 1?\n\nYou MUST follow these instructions: \n"
|
||||||
|
" - The result of the math operation must be right.\n - Result must be better than 1."
|
||||||
|
)
|
||||||
|
crew_training_handler.assert_has_calls(
|
||||||
|
[mock.call(), mock.call(custom_file), mock.call().load()]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_agent_max_retry_limit():
|
def test_agent_max_retry_limit():
|
||||||
agent = Agent(
|
agent = Agent(
|
||||||
role="test role",
|
role="test role",
|
||||||
|
|||||||
Reference in New Issue
Block a user