From 0e65091c43b0d5cd758053c62d7b559a9cbdfec0 Mon Sep 17 00:00:00 2001 From: Lorenze Jay Date: Fri, 12 Jul 2024 11:06:05 -0700 Subject: [PATCH] better typing for stored_outputs and separated task_output_handler --- src/crewai/cli/cli.py | 2 +- src/crewai/crew.py | 46 +++++++------- src/crewai/utilities/file_handler.py | 54 +--------------- src/crewai/utilities/task_output_handler.py | 69 +++++++++++++++++++++ 4 files changed, 94 insertions(+), 77 deletions(-) create mode 100644 src/crewai/utilities/task_output_handler.py diff --git a/src/crewai/cli/cli.py b/src/crewai/cli/cli.py index c84d9afe6..70fb24730 100644 --- a/src/crewai/cli/cli.py +++ b/src/crewai/cli/cli.py @@ -55,7 +55,7 @@ def train(n_iterations: int): "-t", "--task_id", type=str, - help="The task ID of the task to replay from. This will replay the task and all the tasks that were executed after it.", + help="Replay the crew from this task ID, including all subsequent tasks.", ) def replay_from_task(task_id: str) -> None: """ diff --git a/src/crewai/crew.py b/src/crewai/crew.py index 6a6749f7e..b7effb04c 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -1,7 +1,6 @@ import asyncio import json import uuid -from datetime import datetime from concurrent.futures import Future from typing import Any, Dict, List, Optional, Tuple, Union @@ -38,7 +37,10 @@ from crewai.utilities.constants import ( TRAINING_DATA_FILE, ) from crewai.utilities.evaluators.task_evaluator import TaskEvaluator -from crewai.utilities.file_handler import TaskOutputJsonHandler +from crewai.utilities.task_output_handler import ( + ExecutionLog, + TaskOutputJsonHandler, +) from crewai.utilities.formatter import ( aggregate_raw_outputs_from_task_outputs, aggregate_raw_outputs_from_tasks, @@ -149,7 +151,7 @@ class Crew(BaseModel): default=None, description="List of file paths for task execution JSON files.", ) - execution_logs: List[Dict[str, Any]] = Field( + execution_logs: List[ExecutionLog] = Field( default=[], description="List of execution logs for tasks", ) @@ -397,6 +399,7 @@ class Crew(BaseModel): self._execution_span = self._telemetry.crew_execution_span(self, inputs) TaskOutputJsonHandler(CREW_TASKS_OUTPUT_FILE).initialize_file() TaskOutputJsonHandler(CREW_TASKS_OUTPUT_FILE).reset() + self._logging_color = "bold_purple" if inputs is not None: self._inputs = inputs @@ -525,10 +528,11 @@ class Crew(BaseModel): inputs = self._inputs else: inputs = {} - log = { - "task_id": str(task.id), - "expected_output": task.expected_output, - "output": { + + log = ExecutionLog( + task_id=str(task.id), + expected_output=task.expected_output, + output={ "description": output.description, "summary": output.summary, "raw": output.raw, @@ -537,11 +541,10 @@ class Crew(BaseModel): "output_format": output.output_format, "agent": output.agent, }, - "timestamp": datetime.now().isoformat(), - "task_index": task_index, - "inputs": inputs, - "was_replayed": was_replayed, - } + task_index=task_index, + inputs=inputs, + was_replayed=was_replayed, + ) if task_index < len(self.execution_logs): self.execution_logs[task_index] = log else: @@ -620,7 +623,7 @@ class Crew(BaseModel): self._log_task_start(task, agent_to_use) if task.async_execution: - context = self._set_context( + context = self._get_context( task, [last_sync_output] if last_sync_output else [] ) future = task.execute_async( @@ -636,7 +639,7 @@ class Crew(BaseModel): ) futures.clear() - context = self._set_context(task, task_outputs) + context = self._get_context(task, task_outputs) task_output = task.execute_sync( agent=agent_to_use, context=context, @@ -652,10 +655,10 @@ class Crew(BaseModel): return self._create_crew_output(task_outputs) def _prepare_task(self, task: Task, manager: Optional[BaseAgent]): - if task.agent and task.agent.allow_delegation: - self._add_delegation_tools(task) if self.process == Process.hierarchical: self._update_manager_tools(task, manager) + elif task.agent and task.agent.allow_delegation: + self._add_delegation_tools(task) def _add_delegation_tools(self, task: Task): agents_for_delegation = [agent for agent in self.agents if agent != task.agent] @@ -676,7 +679,7 @@ class Crew(BaseModel): if manager: manager.tools = manager.get_delegation_tools(self.agents) - def _set_context(self, task: Task, task_outputs: List[TaskOutput]): + def _get_context(self, task: Task, task_outputs: List[TaskOutput]): context = ( aggregate_raw_outputs_from_tasks(task.context) if task.context @@ -724,7 +727,7 @@ class Crew(BaseModel): return task_outputs def _find_task_index( - self, task_id: str, stored_outputs: List[Dict[str, Any]] + self, task_id: str, stored_outputs: List[Any] ) -> Optional[int]: return next( ( @@ -736,7 +739,7 @@ class Crew(BaseModel): ) def replay_from_task( - self, task_id: str, inputs: Dict[str, Any] | None = None + self, task_id: str, inputs: Optional[Dict[str, Any]] = None ) -> CrewOutput: stored_outputs = TaskOutputJsonHandler(CREW_TASKS_OUTPUT_FILE).load() start_index = self._find_task_index(task_id, stored_outputs) @@ -745,9 +748,7 @@ class Crew(BaseModel): raise ValueError(f"Task with id {task_id} not found in the crew's tasks.") replay_inputs = ( - inputs - if inputs is not None - else stored_outputs[start_index].get("inputs", {}) + inputs if inputs is not None else stored_outputs[start_index]["inputs"] ) self._inputs = replay_inputs @@ -771,7 +772,6 @@ class Crew(BaseModel): self._logging_color = "bold_blue" result = self._execute_tasks(self.tasks, self.manager_agent, start_index, True) - self._logging_color = "bold_purple" return result def copy(self): diff --git a/src/crewai/utilities/file_handler.py b/src/crewai/utilities/file_handler.py index 8f9b727a2..68c33241d 100644 --- a/src/crewai/utilities/file_handler.py +++ b/src/crewai/utilities/file_handler.py @@ -1,11 +1,8 @@ import os import pickle -import json + from datetime import datetime -from typing import Dict, Any, List - -from crewai.utilities.crew_json_encoder import CrewJSONEncoder class FileHandler: @@ -71,52 +68,3 @@ class PickleHandler: return {} # Return an empty dictionary if the file is empty or corrupted except Exception: raise # Raise any other exceptions that occur during loading - - -class TaskOutputJsonHandler: - def __init__(self, file_name: str) -> None: - self.file_path = os.path.join(os.getcwd(), file_name) - - def initialize_file(self) -> None: - if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0: - with open(self.file_path, "w") as file: - json.dump([], file) - - def update(self, task_index: int, log: Dict[str, Any]): - logs = self.load() - if task_index < len(logs): - logs[task_index] = log - else: - logs.append(log) - self.save(logs) - - def save(self, logs: List[Dict[str, Any]]): - with open(self.file_path, "w") as file: - json.dump(logs, file, indent=2, cls=CrewJSONEncoder) - - def reset(self): - """Reset the JSON file by creating an empty file.""" - with open(self.file_path, "w") as f: - json.dump([], f) - - def load(self) -> list: - try: - if ( - not os.path.exists(self.file_path) - or os.path.getsize(self.file_path) == 0 - ): - return [] - - with open(self.file_path, "r") as file: - return json.load(file) - except FileNotFoundError: - print(f"File {self.file_path} not found. Returning empty list.") - return [] - except json.JSONDecodeError: - print( - f"Error decoding JSON from file {self.file_path}. Returning empty list." - ) - return [] - except Exception as e: - print(f"An unexpected error occurred: {e}") - return [] diff --git a/src/crewai/utilities/task_output_handler.py b/src/crewai/utilities/task_output_handler.py new file mode 100644 index 000000000..7caec6774 --- /dev/null +++ b/src/crewai/utilities/task_output_handler.py @@ -0,0 +1,69 @@ +import json +import os + +from pydantic import BaseModel, Field +from datetime import datetime +from typing import Dict, Any, Optional, List +from crewai.utilities.crew_json_encoder import CrewJSONEncoder + + +class ExecutionLog(BaseModel): + task_id: str + expected_output: Optional[str] = None + output: Dict[str, Any] + timestamp: datetime = Field(default_factory=datetime.now) + task_index: int + inputs: Dict[str, Any] = Field(default_factory=dict) + was_replayed: bool = False + + def __getitem__(self, key: str) -> Any: + return getattr(self, key) + + +class TaskOutputJsonHandler: + def __init__(self, file_name: str) -> None: + self.file_path = os.path.join(os.getcwd(), file_name) + + def initialize_file(self) -> None: + if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0: + with open(self.file_path, "w") as file: + json.dump([], file) + + def update(self, task_index: int, log: ExecutionLog): + logs = self.load() + if task_index < len(logs): + logs[task_index] = log + else: + logs.append(log) + self.save(logs) + + def save(self, logs: List[ExecutionLog]): + with open(self.file_path, "w") as file: + json.dump(logs, file, indent=2, cls=CrewJSONEncoder) + + def reset(self): + """Reset the JSON file by creating an empty file.""" + with open(self.file_path, "w") as f: + json.dump([], f) + + def load(self) -> List[ExecutionLog]: + try: + if ( + not os.path.exists(self.file_path) + or os.path.getsize(self.file_path) == 0 + ): + return [] + + with open(self.file_path, "r") as file: + return json.load(file) + except FileNotFoundError: + print(f"File {self.file_path} not found. Returning empty list.") + return [] + except json.JSONDecodeError: + print( + f"Error decoding JSON from file {self.file_path}. Returning empty list." + ) + return [] + except Exception as e: + print(f"An unexpected error occurred: {e}") + return []