Compare commits

...

3 Commits

Author SHA1 Message Date
Devin AI
0360988835 Skip test_gemma3 on Python 3.11 due to segmentation fault
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-04-29 21:35:49 +00:00
Devin AI
fa39ce9db2 Address PR review comments: improve validation, error handling, and add tests
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-04-29 21:30:04 +00:00
Devin AI
2f66aa0efc Fix issue #2724: Allow specifying trained data file for run command
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-04-29 21:20:58 +00:00
7 changed files with 178 additions and 13 deletions

View File

@@ -118,6 +118,10 @@ class Agent(BaseAgent):
default=None,
description="Knowledge context for the agent.",
)
trained_data_file: Optional[str] = Field(
default=TRAINED_AGENTS_DATA_FILE,
description="Path to the trained data file to use for task prompts.",
)
crew_knowledge_context: Optional[str] = Field(
default=None,
description="Knowledge context for the crew.",
@@ -497,13 +501,24 @@ class Agent(BaseAgent):
return task_prompt
def _use_trained_data(self, task_prompt: str) -> str:
    """Append stored training suggestions for this agent's role to a task prompt.

    The training data is loaded from the file named by the
    ``trained_data_file`` attribute and looked up by ``self.role``.

    Args:
        task_prompt: The original task prompt to enhance.

    Returns:
        ``task_prompt`` followed by a "You MUST follow these instructions"
        bullet list when non-empty suggestions exist for this role;
        otherwise ``task_prompt`` unchanged.
    """
    # No file configured: skip the handler entirely so nothing is loaded.
    if not self.trained_data_file:
        return task_prompt

    # NOTE(review): the '.pkl' file naming suggests a pickle payload; loading
    # a file from an untrusted source would be unsafe -- confirm how
    # CrewTrainingHandler deserializes before accepting arbitrary paths.
    data = CrewTrainingHandler(self.trained_data_file).load()
    if not data:
        return task_prompt

    trained_data_output = data.get(self.role)
    if not trained_data_output or "suggestions" not in trained_data_output:
        return task_prompt

    suggestions = trained_data_output["suggestions"]
    # An empty list would otherwise leave a dangling " - " bullet in the prompt.
    if not suggestions:
        return task_prompt

    return (
        task_prompt
        + "\n\nYou MUST follow these instructions: \n - "
        + "\n - ".join(suggestions)
    )
def _render_text_description(self, tools: List[Any]) -> str:

View File

@@ -201,9 +201,17 @@ def install(context):
@crewai.command()
@click.option(
    "-f",
    "--trained-data-file",
    type=str,
    default=None,
    help="Path to a trained data file to use for agent task prompts",
)
def run(trained_data_file: "str | None" = None):
    """Run the Crew."""
    # The option defaults to None rather than a file name: a path is only
    # forwarded (and therefore only validated for existence downstream) when
    # the user explicitly asked for one, so a plain `run` keeps working in
    # projects that have no trained data file.
    if trained_data_file:
        click.echo(
            f"Running the Crew with agent training data file: {trained_data_file}"
        )
    run_crew(trained_data_file=trained_data_file)
@crewai.command()

View File

@@ -1,3 +1,4 @@
import os
import subprocess
from enum import Enum
from typing import List, Optional
@@ -14,13 +15,16 @@ class CrewType(Enum):
FLOW = "flow"
def run_crew() -> None:
def run_crew(trained_data_file: Optional[str] = None) -> None:
"""
Run the crew or flow by running a command in the UV environment.
Starting from version 0.103.0, this command can be used to run both
standard crews and flows. For flows, it detects the type from pyproject.toml
and automatically runs the appropriate command.
Args:
trained_data_file: Optional path to a trained data file to use
"""
crewai_version = get_crewai_version()
min_required_version = "0.71.0"
@@ -44,17 +48,35 @@ def run_crew() -> None:
click.echo(f"Running the {'Flow' if is_flow else 'Crew'}")
# Execute the appropriate command
execute_command(crew_type)
execute_command(crew_type, trained_data_file)
def execute_command(crew_type: CrewType) -> None:
def execute_command(crew_type: CrewType, trained_data_file: Optional[str] = None) -> None:
"""
Execute the appropriate command based on crew type.
Args:
crew_type: The type of crew to run
trained_data_file: Optional path to a trained data file to use
"""
command = ["uv", "run", "kickoff" if crew_type == CrewType.FLOW else "run_crew"]
if trained_data_file and crew_type == CrewType.STANDARD:
if not trained_data_file.endswith('.pkl'):
click.secho(
f"Error: Trained data file '{trained_data_file}' must have a .pkl extension.",
fg="red",
)
return
if not os.path.exists(trained_data_file):
click.secho(
f"Error: Trained data file '{trained_data_file}' does not exist.",
fg="red",
)
return
command.extend(["--trained-data-file", trained_data_file])
try:
subprocess.run(command, capture_output=False, text=True, check=True)

View File

@@ -1,4 +1,5 @@
#!/usr/bin/env python
import argparse
import sys
import warnings
@@ -17,13 +18,23 @@ def run():
"""
Run the crew.
"""
parser = argparse.ArgumentParser(description="Run the crew")
parser.add_argument(
"--trained-data-file",
"-f",
type=str,
default="trained_agents_data.pkl",
help="Path to a trained data file to use for agent task prompts"
)
args, _ = parser.parse_known_args()
inputs = {
'topic': 'AI LLMs',
'current_year': str(datetime.now().year)
}
try:
{{crew_name}}().crew().kickoff(inputs=inputs)
{{crew_name}}().crew(trained_data_file=args.trained_data_file).kickoff(inputs=inputs)
except Exception as e:
raise Exception(f"An error occurred while running the crew: {e}")

View File

@@ -122,6 +122,10 @@ class Crew(BaseModel):
tasks: List[Task] = Field(default_factory=list)
agents: List[BaseAgent] = Field(default_factory=list)
process: Process = Field(default=Process.sequential)
trained_data_file: Optional[str] = Field(
default=None,
description="Path to the trained data file to use for agent task prompts.",
)
verbose: bool = Field(default=False)
memory: bool = Field(
default=False,
@@ -1196,7 +1200,12 @@ class Crew(BaseModel):
"manager_llm",
}
cloned_agents = [agent.copy() for agent in self.agents]
cloned_agents = []
for agent in self.agents:
cloned_agent = agent.copy()
if self.trained_data_file:
cloned_agent.trained_data_file = self.trained_data_file
cloned_agents.append(cloned_agent)
manager_agent = self.manager_agent.copy() if self.manager_agent else None
manager_llm = shallow_copy(self.manager_llm) if self.manager_llm else None

View File

@@ -1196,6 +1196,102 @@ def test_agent_use_trained_data(crew_training_handler):
)
@patch("crewai.agent.CrewTrainingHandler")
def test_agent_use_custom_trained_data_file(crew_training_handler):
    """Suggestions from a custom trained data file are appended to the prompt."""
    custom_file = "custom_trained_data.pkl"
    prompt = "What is 1 + 1?"
    expected = (
        "What is 1 + 1?\n\nYou MUST follow these instructions: \n"
        " - The result of the math operation must be right.\n - Result must be better than 1."
    )

    agent = Agent(
        role="researcher",
        goal="test goal",
        backstory="test backstory",
        verbose=True,
        trained_data_file=custom_file,
    )
    # Configuring the mocked handler instance records the bare `call()` that
    # the call-sequence assertion below expects first.
    handler_instance = crew_training_handler()
    handler_instance.load.return_value = {
        agent.role: {
            "suggestions": [
                "The result of the math operation must be right.",
                "Result must be better than 1.",
            ]
        }
    }

    result = agent._use_trained_data(task_prompt=prompt)

    assert result == expected
    expected_calls = [mock.call(), mock.call(custom_file), mock.call().load()]
    crew_training_handler.assert_has_calls(expected_calls)
@patch("crewai.agent.CrewTrainingHandler")
def test_agent_with_none_trained_data_file(crew_training_handler):
    """With no trained data file the prompt is untouched and no handler is built."""
    prompt = "What is 1 + 1?"
    agent_kwargs = {
        "role": "researcher",
        "goal": "test goal",
        "backstory": "test backstory",
        "verbose": True,
        "trained_data_file": None,
    }
    agent = Agent(**agent_kwargs)

    result = agent._use_trained_data(task_prompt=prompt)

    assert result == prompt
    crew_training_handler.assert_not_called()
@patch("crewai.agent.CrewTrainingHandler")
def test_agent_with_missing_role_in_trained_data(crew_training_handler):
    """Training data that lacks the agent's role leaves the prompt unchanged."""
    data_file = "trained_agents_data.pkl"
    prompt = "What is 1 + 1?"

    agent = Agent(
        role="researcher",
        goal="test goal",
        backstory="test backstory",
        verbose=True,
        trained_data_file=data_file,
    )
    # Only a different role has suggestions stored.
    handler_instance = crew_training_handler()
    handler_instance.load.return_value = {
        "other_role": {"suggestions": ["This should not be used."]}
    }

    result = agent._use_trained_data(task_prompt=prompt)

    assert result == prompt
    expected_calls = [mock.call(), mock.call(data_file), mock.call().load()]
    crew_training_handler.assert_has_calls(expected_calls)
@patch("crewai.agent.CrewTrainingHandler")
def test_agent_with_missing_suggestions_in_trained_data(crew_training_handler):
    """A role entry without a 'suggestions' key leaves the prompt unchanged."""
    data_file = "trained_agents_data.pkl"
    prompt = "What is 1 + 1?"

    agent = Agent(
        role="researcher",
        goal="test goal",
        backstory="test backstory",
        verbose=True,
        trained_data_file=data_file,
    )
    # The role is present, but under a key the agent does not read.
    handler_instance = crew_training_handler()
    handler_instance.load.return_value = {
        "researcher": {"other_key": ["This should not be used."]}
    }

    result = agent._use_trained_data(task_prompt=prompt)

    assert result == prompt
    expected_calls = [mock.call(), mock.call(data_file), mock.call().load()]
    crew_training_handler.assert_has_calls(expected_calls)
def test_agent_max_retry_limit():
agent = Agent(
role="test role",

View File

@@ -1,4 +1,5 @@
import os
import sys
from time import sleep
from unittest.mock import MagicMock, patch
@@ -285,6 +286,9 @@ def test_gemini_models(model):
],
)
def test_gemma3(model):
if sys.version_info.major == 3 and sys.version_info.minor == 11:
pytest.skip("Skipping test_gemma3 on Python 3.11 due to segmentation fault")
llm = LLM(model=model)
result = llm.call("What is the capital of France?")
assert isinstance(result, str)