added cli command + code cleanup TODO: need better refactoring

This commit is contained in:
Lorenze Jay
2024-07-11 10:06:21 -07:00
parent 28929e1c5f
commit 3aa5d16a6f
7 changed files with 90 additions and 85 deletions

View File

@@ -1,8 +1,10 @@
import click
import pkg_resources
from .create_crew import create_crew
from .train_crew import train_crew
from .replay_from_task import replay_task_command
@click.group()
@@ -48,5 +50,26 @@ def train(n_iterations: int):
train_crew(n_iterations)
@crewai.command()
@click.option(
"-t",
"--task_id",
type=str,
help="The task ID of the task to replay from. This will replay the task and all the tasks that were executed after it.",
)
def replay_from_task(task_id: str) -> None:
"""
Replay the crew execution from a specific task.
Args:
task_id (str): The ID of the task to replay from.
"""
try:
click.echo(f"Replaying the crew from task {task_id}")
replay_task_command(task_id)
except Exception as e:
click.echo(f"An error occurred while replaying: {e}", err=True)
if __name__ == "__main__":
crewai()

View File

@@ -0,0 +1,24 @@
import subprocess
import click
def replay_task_command(task_id: str) -> None:
"""
Replay the crew execution from a specific task.
Args:
task_id (str): The ID of the task to replay from.
"""
command = ["poetry", "run", "replay_from_task", task_id]
try:
result = subprocess.run(command, capture_output=False, text=True, check=True)
if result.stderr:
click.echo(result.stderr, err=True)
except subprocess.CalledProcessError as e:
click.echo(f"An error occurred while replaying the task: {e}", err=True)
click.echo(e.output, err=True)
except Exception as e:
click.echo(f"An unexpected error occurred: {e}", err=True)

View File

@@ -21,3 +21,13 @@ def train():
except Exception as e:
raise Exception(f"An error occurred while training the crew: {e}")
def replay_from_task():
"""
Replay the crew execution from a specific task.
"""
try:
{{crew_name}}Crew().crew().replay_from_task(task_id=sys.argv[1])
except Exception as e:
raise Exception(f"An error occurred while replaying the crew: {e}")

View File

@@ -11,6 +11,7 @@ crewai = { extras = ["tools"], version = "^0.35.8" }
[tool.poetry.scripts]
{{folder_name}} = "{{folder_name}}.main:run"
train = "{{folder_name}}.main:train"
replay_from_task = "{{folder_name}}.main:replay_from_task"
[build-system]
requires = ["poetry-core"]

View File

@@ -32,7 +32,11 @@ from crewai.tasks.task_output import TaskOutput
from crewai.telemetry import Telemetry
from crewai.tools.agent_tools import AgentTools
from crewai.utilities import I18N, FileHandler, Logger, RPMController
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
from crewai.utilities.constants import (
CREW_TASKS_OUTPUT_FILE,
TRAINED_AGENTS_DATA_FILE,
TRAINING_DATA_FILE,
)
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
from crewai.utilities.file_handler import TaskOutputJsonHandler
from crewai.utilities.formatter import (
@@ -76,7 +80,6 @@ class Crew(BaseModel):
_rpm_controller: RPMController = PrivateAttr()
_logger: Logger = PrivateAttr()
_file_handler: FileHandler = PrivateAttr()
_task_output_handler: TaskOutputJsonHandler = PrivateAttr()
_cache_handler: InstanceOf[CacheHandler] = PrivateAttr(default=CacheHandler())
_short_term_memory: Optional[InstanceOf[ShortTermMemory]] = PrivateAttr()
_long_term_memory: Optional[InstanceOf[LongTermMemory]] = PrivateAttr()
@@ -148,8 +151,6 @@ class Crew(BaseModel):
description="List of execution logs for tasks",
)
_log_file: str = PrivateAttr(default="crew_tasks_output.json")
@field_validator("id", mode="before")
@classmethod
def _deny_user_set_id(cls, v: Optional[UUID4]) -> None:
@@ -181,7 +182,6 @@ class Crew(BaseModel):
self._logger = Logger(self.verbose)
if self.output_log_file:
self._file_handler = FileHandler(self.output_log_file)
self._task_output_handler = TaskOutputJsonHandler(self._log_file)
self._rpm_controller = RPMController(max_rpm=self.max_rpm, logger=self._logger)
self._telemetry = Telemetry()
self._telemetry.set_tracer()
@@ -392,7 +392,9 @@ class Crew(BaseModel):
) -> CrewOutput:
"""Starts the crew to work on its assigned tasks."""
self._execution_span = self._telemetry.crew_execution_span(self, inputs)
self._task_output_handler.reset()
TaskOutputJsonHandler(CREW_TASKS_OUTPUT_FILE).initialize_file()
TaskOutputJsonHandler(CREW_TASKS_OUTPUT_FILE).reset()
if inputs is not None:
self._inputs = inputs
self._interpolate_inputs(inputs)
@@ -522,9 +524,7 @@ class Crew(BaseModel):
inputs = {}
log = {
"task_id": str(task.id),
"description": task.description,
"expected_output": task.expected_output,
"agent_role": task.agent.role if task.agent else "None",
"output": {
"description": output.description,
"summary": output.summary,
@@ -544,8 +544,7 @@ class Crew(BaseModel):
self.execution_logs[task_index] = log
else:
self.execution_logs.append(log)
self._task_output_handler.update(task_index, log)
TaskOutputJsonHandler(CREW_TASKS_OUTPUT_FILE).update(task_index, log)
def _run_sequential_process(self) -> CrewOutput:
"""Executes tasks sequentially and returns the final output."""
@@ -665,42 +664,6 @@ class Crew(BaseModel):
)
return task_outputs
def _get_agent(self, role: str) -> Optional[BaseAgent]:
"""
Private method to get an agent by role.
Args:
role (str): The role of the agent to retrieve.
Returns:
Optional[BaseAgent]: The agent with the specified role, or None if not found.
"""
for agent in self.agents:
if agent.role == role:
return agent
return None
def _initialize_execution(self, inputs: Optional[Dict[str, Any]]) -> None:
"""Initializes the execution by setting up necessary attributes and states."""
self._execution_span = self._telemetry.crew_execution_span(self, inputs)
# self.execution_logs = []
self._task_output_handler.reset()
if inputs is not None:
self._inputs = inputs
self._interpolate_inputs(inputs)
self._set_tasks_callbacks()
i18n = I18N(prompt_file=self.prompt_file)
for agent in self.agents:
agent.i18n = i18n
agent.crew = self # type: ignore[attr-defined]
if not agent.function_calling_llm: # type: ignore[attr-defined]
agent.function_calling_llm = self.function_calling_llm # type: ignore[attr-defined]
if agent.allow_code_execution: # type: ignore[attr-defined]
agent.tools += agent.get_code_execution_tools() # type: ignore[attr-defined]
if not agent.step_callback: # type: ignore[attr-defined]
agent.step_callback = self.step_callback # type: ignore[attr-defined]
agent.create_agent_executor()
def _find_task_index(
self, task_id: str, stored_outputs: List[Dict[str, Any]]
) -> Optional[int]:
@@ -716,7 +679,8 @@ class Crew(BaseModel):
def replay_from_task(
self, task_id: str, inputs: Dict[str, Any] | None = None
) -> CrewOutput:
stored_outputs = self._load_stored_outputs()
# stored_outputs = self._load_stored_outputs()
stored_outputs = TaskOutputJsonHandler(CREW_TASKS_OUTPUT_FILE).load()
start_index = self._find_task_index(task_id, stored_outputs)
if start_index is None:
@@ -851,25 +815,9 @@ class Crew(BaseModel):
)
self.manager_agent = manager
def _load_stored_outputs(self) -> List[Dict]:
try:
with open(self._log_file, "r") as f:
return json.load(f)
except FileNotFoundError:
self._logger.log(
"warning",
f"Log file {self._log_file} not found. Starting with empty logs.",
)
return []
except json.JSONDecodeError:
self._logger.log(
"error",
f"Failed to parse log file {self._log_file}. Starting with empty logs.",
)
return []
def _run_hierarchical_process(self) -> CrewOutput:
"""Creates and assigns a manager agent to make sure the crew completes the tasks."""
self.execution_logs = []
i18n = I18N(prompt_file=self.prompt_file)
if self.manager_agent is not None:
self.manager_agent.allow_delegation = True

View File

@@ -1,2 +1,3 @@
TRAINING_DATA_FILE = "training_data.pkl"
TRAINED_AGENTS_DATA_FILE = "trained_agents_data.pkl"
CREW_TASKS_OUTPUT_FILE = "crew_tasks_output.json"

View File

@@ -82,23 +82,6 @@ class TaskOutputJsonHandler:
with open(self.file_path, "w") as file:
json.dump([], file)
def append(self, log) -> None:
if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0:
# Initialize the file with an empty list if it doesn't exist or is empty
with open(self.file_path, "w") as file:
json.dump([], file)
with open(self.file_path, "r+") as file:
try:
file_data = json.load(file)
except json.JSONDecodeError:
# If the file contains invalid JSON, initialize it with an empty list
file_data = []
file_data.append(log)
file.seek(0)
json.dump(file_data, file, indent=2, cls=CrewJSONEncoder)
file.truncate()
def update(self, task_index: int, log: Dict[str, Any]):
logs = self.load()
if task_index < len(logs):
@@ -117,8 +100,23 @@ class TaskOutputJsonHandler:
json.dump([], f)
def load(self) -> list:
if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0:
return []
try:
if (
not os.path.exists(self.file_path)
or os.path.getsize(self.file_path) == 0
):
return []
with open(self.file_path, "r") as file:
return json.load(file)
with open(self.file_path, "r") as file:
return json.load(file)
except FileNotFoundError:
print(f"File {self.file_path} not found. Returning empty list.")
return []
except json.JSONDecodeError:
print(
f"Error decoding JSON from file {self.file_path}. Returning empty list."
)
return []
except Exception as e:
print(f"An unexpected error occurred: {e}")
return []