mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-16 12:28:30 +00:00
Compare commits
3 Commits
lg-trigger
...
devin/1745
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0360988835 | ||
|
|
fa39ce9db2 | ||
|
|
2f66aa0efc |
@@ -118,6 +118,10 @@ class Agent(BaseAgent):
|
||||
default=None,
|
||||
description="Knowledge context for the agent.",
|
||||
)
|
||||
trained_data_file: Optional[str] = Field(
|
||||
default=TRAINED_AGENTS_DATA_FILE,
|
||||
description="Path to the trained data file to use for task prompts.",
|
||||
)
|
||||
crew_knowledge_context: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Knowledge context for the crew.",
|
||||
@@ -497,13 +501,24 @@ class Agent(BaseAgent):
|
||||
return task_prompt
|
||||
|
||||
def _use_trained_data(self, task_prompt: str) -> str:
|
||||
"""Use trained data for the agent task prompt to improve output."""
|
||||
if data := CrewTrainingHandler(TRAINED_AGENTS_DATA_FILE).load():
|
||||
"""
|
||||
Use trained data from a specified file for the agent task prompt.
|
||||
|
||||
Uses the 'trained_data_file' attribute as the source of training instructions.
|
||||
|
||||
Args:
|
||||
task_prompt: The original task prompt to enhance.
|
||||
|
||||
Returns:
|
||||
Enhanced task prompt with training instructions if available.
|
||||
"""
|
||||
if self.trained_data_file and (data := CrewTrainingHandler(self.trained_data_file).load()):
|
||||
if trained_data_output := data.get(self.role):
|
||||
task_prompt += (
|
||||
"\n\nYou MUST follow these instructions: \n - "
|
||||
+ "\n - ".join(trained_data_output["suggestions"])
|
||||
)
|
||||
if "suggestions" in trained_data_output:
|
||||
task_prompt += (
|
||||
"\n\nYou MUST follow these instructions: \n - "
|
||||
+ "\n - ".join(trained_data_output["suggestions"])
|
||||
)
|
||||
return task_prompt
|
||||
|
||||
def _render_text_description(self, tools: List[Any]) -> str:
|
||||
|
||||
@@ -201,9 +201,17 @@ def install(context):
|
||||
|
||||
|
||||
@crewai.command()
|
||||
def run():
|
||||
@click.option(
|
||||
"-f",
|
||||
"--trained-data-file",
|
||||
type=str,
|
||||
default="trained_agents_data.pkl",
|
||||
help="Path to a trained data file to use for agent task prompts",
|
||||
)
|
||||
def run(trained_data_file: str):
|
||||
"""Run the Crew."""
|
||||
run_crew()
|
||||
click.echo(f"Running the Crew with agent training data file: {trained_data_file}")
|
||||
run_crew(trained_data_file=trained_data_file)
|
||||
|
||||
|
||||
@crewai.command()
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import subprocess
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
@@ -14,13 +15,16 @@ class CrewType(Enum):
|
||||
FLOW = "flow"
|
||||
|
||||
|
||||
def run_crew() -> None:
|
||||
def run_crew(trained_data_file: Optional[str] = None) -> None:
|
||||
"""
|
||||
Run the crew or flow by running a command in the UV environment.
|
||||
|
||||
Starting from version 0.103.0, this command can be used to run both
|
||||
standard crews and flows. For flows, it detects the type from pyproject.toml
|
||||
and automatically runs the appropriate command.
|
||||
|
||||
Args:
|
||||
trained_data_file: Optional path to a trained data file to use
|
||||
"""
|
||||
crewai_version = get_crewai_version()
|
||||
min_required_version = "0.71.0"
|
||||
@@ -44,17 +48,35 @@ def run_crew() -> None:
|
||||
click.echo(f"Running the {'Flow' if is_flow else 'Crew'}")
|
||||
|
||||
# Execute the appropriate command
|
||||
execute_command(crew_type)
|
||||
execute_command(crew_type, trained_data_file)
|
||||
|
||||
|
||||
def execute_command(crew_type: CrewType) -> None:
|
||||
def execute_command(crew_type: CrewType, trained_data_file: Optional[str] = None) -> None:
|
||||
"""
|
||||
Execute the appropriate command based on crew type.
|
||||
|
||||
Args:
|
||||
crew_type: The type of crew to run
|
||||
trained_data_file: Optional path to a trained data file to use
|
||||
"""
|
||||
command = ["uv", "run", "kickoff" if crew_type == CrewType.FLOW else "run_crew"]
|
||||
|
||||
if trained_data_file and crew_type == CrewType.STANDARD:
|
||||
if not trained_data_file.endswith('.pkl'):
|
||||
click.secho(
|
||||
f"Error: Trained data file '{trained_data_file}' must have a .pkl extension.",
|
||||
fg="red",
|
||||
)
|
||||
return
|
||||
|
||||
if not os.path.exists(trained_data_file):
|
||||
click.secho(
|
||||
f"Error: Trained data file '{trained_data_file}' does not exist.",
|
||||
fg="red",
|
||||
)
|
||||
return
|
||||
|
||||
command.extend(["--trained-data-file", trained_data_file])
|
||||
|
||||
try:
|
||||
subprocess.run(command, capture_output=False, text=True, check=True)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#!/usr/bin/env python
|
||||
import argparse
|
||||
import sys
|
||||
import warnings
|
||||
|
||||
@@ -17,13 +18,23 @@ def run():
|
||||
"""
|
||||
Run the crew.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="Run the crew")
|
||||
parser.add_argument(
|
||||
"--trained-data-file",
|
||||
"-f",
|
||||
type=str,
|
||||
default="trained_agents_data.pkl",
|
||||
help="Path to a trained data file to use for agent task prompts"
|
||||
)
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
inputs = {
|
||||
'topic': 'AI LLMs',
|
||||
'current_year': str(datetime.now().year)
|
||||
}
|
||||
|
||||
try:
|
||||
{{crew_name}}().crew().kickoff(inputs=inputs)
|
||||
{{crew_name}}().crew(trained_data_file=args.trained_data_file).kickoff(inputs=inputs)
|
||||
except Exception as e:
|
||||
raise Exception(f"An error occurred while running the crew: {e}")
|
||||
|
||||
|
||||
@@ -122,6 +122,10 @@ class Crew(BaseModel):
|
||||
tasks: List[Task] = Field(default_factory=list)
|
||||
agents: List[BaseAgent] = Field(default_factory=list)
|
||||
process: Process = Field(default=Process.sequential)
|
||||
trained_data_file: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Path to the trained data file to use for agent task prompts.",
|
||||
)
|
||||
verbose: bool = Field(default=False)
|
||||
memory: bool = Field(
|
||||
default=False,
|
||||
@@ -1196,7 +1200,12 @@ class Crew(BaseModel):
|
||||
"manager_llm",
|
||||
}
|
||||
|
||||
cloned_agents = [agent.copy() for agent in self.agents]
|
||||
cloned_agents = []
|
||||
for agent in self.agents:
|
||||
cloned_agent = agent.copy()
|
||||
if self.trained_data_file:
|
||||
cloned_agent.trained_data_file = self.trained_data_file
|
||||
cloned_agents.append(cloned_agent)
|
||||
manager_agent = self.manager_agent.copy() if self.manager_agent else None
|
||||
manager_llm = shallow_copy(self.manager_llm) if self.manager_llm else None
|
||||
|
||||
|
||||
@@ -1196,6 +1196,102 @@ def test_agent_use_trained_data(crew_training_handler):
|
||||
)
|
||||
|
||||
|
||||
@patch("crewai.agent.CrewTrainingHandler")
|
||||
def test_agent_use_custom_trained_data_file(crew_training_handler):
|
||||
task_prompt = "What is 1 + 1?"
|
||||
custom_file = "custom_trained_data.pkl"
|
||||
agent = Agent(
|
||||
role="researcher",
|
||||
goal="test goal",
|
||||
backstory="test backstory",
|
||||
verbose=True,
|
||||
trained_data_file=custom_file
|
||||
)
|
||||
crew_training_handler().load.return_value = {
|
||||
agent.role: {
|
||||
"suggestions": [
|
||||
"The result of the math operation must be right.",
|
||||
"Result must be better than 1.",
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
result = agent._use_trained_data(task_prompt=task_prompt)
|
||||
|
||||
assert (
|
||||
result == "What is 1 + 1?\n\nYou MUST follow these instructions: \n"
|
||||
" - The result of the math operation must be right.\n - Result must be better than 1."
|
||||
)
|
||||
crew_training_handler.assert_has_calls(
|
||||
[mock.call(), mock.call(custom_file), mock.call().load()]
|
||||
)
|
||||
|
||||
|
||||
@patch("crewai.agent.CrewTrainingHandler")
|
||||
def test_agent_with_none_trained_data_file(crew_training_handler):
|
||||
task_prompt = "What is 1 + 1?"
|
||||
agent = Agent(
|
||||
role="researcher",
|
||||
goal="test goal",
|
||||
backstory="test backstory",
|
||||
verbose=True,
|
||||
trained_data_file=None
|
||||
)
|
||||
|
||||
result = agent._use_trained_data(task_prompt=task_prompt)
|
||||
|
||||
assert result == task_prompt
|
||||
crew_training_handler.assert_not_called()
|
||||
|
||||
|
||||
@patch("crewai.agent.CrewTrainingHandler")
|
||||
def test_agent_with_missing_role_in_trained_data(crew_training_handler):
|
||||
task_prompt = "What is 1 + 1?"
|
||||
agent = Agent(
|
||||
role="researcher",
|
||||
goal="test goal",
|
||||
backstory="test backstory",
|
||||
verbose=True,
|
||||
trained_data_file="trained_agents_data.pkl"
|
||||
)
|
||||
crew_training_handler().load.return_value = {
|
||||
"other_role": {
|
||||
"suggestions": ["This should not be used."]
|
||||
}
|
||||
}
|
||||
|
||||
result = agent._use_trained_data(task_prompt=task_prompt)
|
||||
|
||||
assert result == task_prompt
|
||||
crew_training_handler.assert_has_calls(
|
||||
[mock.call(), mock.call("trained_agents_data.pkl"), mock.call().load()]
|
||||
)
|
||||
|
||||
|
||||
@patch("crewai.agent.CrewTrainingHandler")
|
||||
def test_agent_with_missing_suggestions_in_trained_data(crew_training_handler):
|
||||
task_prompt = "What is 1 + 1?"
|
||||
agent = Agent(
|
||||
role="researcher",
|
||||
goal="test goal",
|
||||
backstory="test backstory",
|
||||
verbose=True,
|
||||
trained_data_file="trained_agents_data.pkl"
|
||||
)
|
||||
crew_training_handler().load.return_value = {
|
||||
"researcher": {
|
||||
"other_key": ["This should not be used."]
|
||||
}
|
||||
}
|
||||
|
||||
result = agent._use_trained_data(task_prompt=task_prompt)
|
||||
|
||||
assert result == task_prompt
|
||||
crew_training_handler.assert_has_calls(
|
||||
[mock.call(), mock.call("trained_agents_data.pkl"), mock.call().load()]
|
||||
)
|
||||
|
||||
|
||||
def test_agent_max_retry_limit():
|
||||
agent = Agent(
|
||||
role="test role",
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import sys
|
||||
from time import sleep
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@@ -285,6 +286,9 @@ def test_gemini_models(model):
|
||||
],
|
||||
)
|
||||
def test_gemma3(model):
|
||||
if sys.version_info.major == 3 and sys.version_info.minor == 11:
|
||||
pytest.skip("Skipping test_gemma3 on Python 3.11 due to segmentation fault")
|
||||
|
||||
llm = LLM(model=model)
|
||||
result = llm.call("What is the capital of France?")
|
||||
assert isinstance(result, str)
|
||||
|
||||
Reference in New Issue
Block a user