better typing for stored_outputs and separated task_output_handler

This commit is contained in:
Lorenze Jay
2024-07-12 11:06:05 -07:00
parent 8b7040577f
commit 0e65091c43
4 changed files with 94 additions and 77 deletions

View File

@@ -55,7 +55,7 @@ def train(n_iterations: int):
"-t",
"--task_id",
type=str,
help="The task ID of the task to replay from. This will replay the task and all the tasks that were executed after it.",
help="Replay the crew from this task ID, including all subsequent tasks.",
)
def replay_from_task(task_id: str) -> None:
"""

View File

@@ -1,7 +1,6 @@
import asyncio
import json
import uuid
from datetime import datetime
from concurrent.futures import Future
from typing import Any, Dict, List, Optional, Tuple, Union
@@ -38,7 +37,10 @@ from crewai.utilities.constants import (
TRAINING_DATA_FILE,
)
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
from crewai.utilities.file_handler import TaskOutputJsonHandler
from crewai.utilities.task_output_handler import (
ExecutionLog,
TaskOutputJsonHandler,
)
from crewai.utilities.formatter import (
aggregate_raw_outputs_from_task_outputs,
aggregate_raw_outputs_from_tasks,
@@ -149,7 +151,7 @@ class Crew(BaseModel):
default=None,
description="List of file paths for task execution JSON files.",
)
execution_logs: List[Dict[str, Any]] = Field(
execution_logs: List[ExecutionLog] = Field(
default=[],
description="List of execution logs for tasks",
)
@@ -397,6 +399,7 @@ class Crew(BaseModel):
self._execution_span = self._telemetry.crew_execution_span(self, inputs)
TaskOutputJsonHandler(CREW_TASKS_OUTPUT_FILE).initialize_file()
TaskOutputJsonHandler(CREW_TASKS_OUTPUT_FILE).reset()
self._logging_color = "bold_purple"
if inputs is not None:
self._inputs = inputs
@@ -525,10 +528,11 @@ class Crew(BaseModel):
inputs = self._inputs
else:
inputs = {}
log = {
"task_id": str(task.id),
"expected_output": task.expected_output,
"output": {
log = ExecutionLog(
task_id=str(task.id),
expected_output=task.expected_output,
output={
"description": output.description,
"summary": output.summary,
"raw": output.raw,
@@ -537,11 +541,10 @@ class Crew(BaseModel):
"output_format": output.output_format,
"agent": output.agent,
},
"timestamp": datetime.now().isoformat(),
"task_index": task_index,
"inputs": inputs,
"was_replayed": was_replayed,
}
task_index=task_index,
inputs=inputs,
was_replayed=was_replayed,
)
if task_index < len(self.execution_logs):
self.execution_logs[task_index] = log
else:
@@ -620,7 +623,7 @@ class Crew(BaseModel):
self._log_task_start(task, agent_to_use)
if task.async_execution:
context = self._set_context(
context = self._get_context(
task, [last_sync_output] if last_sync_output else []
)
future = task.execute_async(
@@ -636,7 +639,7 @@ class Crew(BaseModel):
)
futures.clear()
context = self._set_context(task, task_outputs)
context = self._get_context(task, task_outputs)
task_output = task.execute_sync(
agent=agent_to_use,
context=context,
@@ -652,10 +655,10 @@ class Crew(BaseModel):
return self._create_crew_output(task_outputs)
def _prepare_task(self, task: Task, manager: Optional[BaseAgent]):
if task.agent and task.agent.allow_delegation:
self._add_delegation_tools(task)
if self.process == Process.hierarchical:
self._update_manager_tools(task, manager)
elif task.agent and task.agent.allow_delegation:
self._add_delegation_tools(task)
def _add_delegation_tools(self, task: Task):
agents_for_delegation = [agent for agent in self.agents if agent != task.agent]
@@ -676,7 +679,7 @@ class Crew(BaseModel):
if manager:
manager.tools = manager.get_delegation_tools(self.agents)
def _set_context(self, task: Task, task_outputs: List[TaskOutput]):
def _get_context(self, task: Task, task_outputs: List[TaskOutput]):
context = (
aggregate_raw_outputs_from_tasks(task.context)
if task.context
@@ -724,7 +727,7 @@ class Crew(BaseModel):
return task_outputs
def _find_task_index(
self, task_id: str, stored_outputs: List[Dict[str, Any]]
self, task_id: str, stored_outputs: List[Any]
) -> Optional[int]:
return next(
(
@@ -736,7 +739,7 @@ class Crew(BaseModel):
)
def replay_from_task(
self, task_id: str, inputs: Dict[str, Any] | None = None
self, task_id: str, inputs: Optional[Dict[str, Any]] = None
) -> CrewOutput:
stored_outputs = TaskOutputJsonHandler(CREW_TASKS_OUTPUT_FILE).load()
start_index = self._find_task_index(task_id, stored_outputs)
@@ -745,9 +748,7 @@ class Crew(BaseModel):
raise ValueError(f"Task with id {task_id} not found in the crew's tasks.")
replay_inputs = (
inputs
if inputs is not None
else stored_outputs[start_index].get("inputs", {})
inputs if inputs is not None else stored_outputs[start_index]["inputs"]
)
self._inputs = replay_inputs
@@ -771,7 +772,6 @@ class Crew(BaseModel):
self._logging_color = "bold_blue"
result = self._execute_tasks(self.tasks, self.manager_agent, start_index, True)
self._logging_color = "bold_purple"
return result
def copy(self):

View File

@@ -1,11 +1,8 @@
import os
import pickle
import json
from datetime import datetime
from typing import Dict, Any, List
from crewai.utilities.crew_json_encoder import CrewJSONEncoder
class FileHandler:
@@ -71,52 +68,3 @@ class PickleHandler:
return {} # Return an empty dictionary if the file is empty or corrupted
except Exception:
raise # Raise any other exceptions that occur during loading
class TaskOutputJsonHandler:
def __init__(self, file_name: str) -> None:
self.file_path = os.path.join(os.getcwd(), file_name)
def initialize_file(self) -> None:
if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0:
with open(self.file_path, "w") as file:
json.dump([], file)
def update(self, task_index: int, log: Dict[str, Any]):
logs = self.load()
if task_index < len(logs):
logs[task_index] = log
else:
logs.append(log)
self.save(logs)
def save(self, logs: List[Dict[str, Any]]):
with open(self.file_path, "w") as file:
json.dump(logs, file, indent=2, cls=CrewJSONEncoder)
def reset(self):
"""Reset the JSON file by creating an empty file."""
with open(self.file_path, "w") as f:
json.dump([], f)
def load(self) -> list:
try:
if (
not os.path.exists(self.file_path)
or os.path.getsize(self.file_path) == 0
):
return []
with open(self.file_path, "r") as file:
return json.load(file)
except FileNotFoundError:
print(f"File {self.file_path} not found. Returning empty list.")
return []
except json.JSONDecodeError:
print(
f"Error decoding JSON from file {self.file_path}. Returning empty list."
)
return []
except Exception as e:
print(f"An unexpected error occurred: {e}")
return []

View File

@@ -0,0 +1,69 @@
import json
import os
from pydantic import BaseModel, Field
from datetime import datetime
from typing import Dict, Any, Optional, List
from crewai.utilities.crew_json_encoder import CrewJSONEncoder
class ExecutionLog(BaseModel):
task_id: str
expected_output: Optional[str] = None
output: Dict[str, Any]
timestamp: datetime = Field(default_factory=datetime.now)
task_index: int
inputs: Dict[str, Any] = Field(default_factory=dict)
was_replayed: bool = False
def __getitem__(self, key: str) -> Any:
return getattr(self, key)
class TaskOutputJsonHandler:
def __init__(self, file_name: str) -> None:
self.file_path = os.path.join(os.getcwd(), file_name)
def initialize_file(self) -> None:
if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0:
with open(self.file_path, "w") as file:
json.dump([], file)
def update(self, task_index: int, log: ExecutionLog):
logs = self.load()
if task_index < len(logs):
logs[task_index] = log
else:
logs.append(log)
self.save(logs)
def save(self, logs: List[ExecutionLog]):
with open(self.file_path, "w") as file:
json.dump(logs, file, indent=2, cls=CrewJSONEncoder)
def reset(self):
"""Reset the JSON file by creating an empty file."""
with open(self.file_path, "w") as f:
json.dump([], f)
def load(self) -> List[ExecutionLog]:
try:
if (
not os.path.exists(self.file_path)
or os.path.getsize(self.file_path) == 0
):
return []
with open(self.file_path, "r") as file:
return json.load(file)
except FileNotFoundError:
print(f"File {self.file_path} not found. Returning empty list.")
return []
except json.JSONDecodeError:
print(
f"Error decoding JSON from file {self.file_path}. Returning empty list."
)
return []
except Exception as e:
print(f"An unexpected error occurred: {e}")
return []