better typing for stored_outputs and separated task_output_handler

author Lorenze Jay
date   2024-07-12 11:06:05 -07:00
parent 8b7040577f
commit 0e65091c43
4 changed files with 94 additions and 77 deletions

File: crewai CLI (replay_from_task command)

@@ -55,7 +55,7 @@ def train(n_iterations: int):
     "-t",
     "--task_id",
     type=str,
-    help="The task ID of the task to replay from. This will replay the task and all the tasks that were executed after it.",
+    help="Replay the crew from this task ID, including all subsequent tasks.",
 )
 def replay_from_task(task_id: str) -> None:
     """

File: crewai/crew.py (Crew class)

@@ -1,7 +1,6 @@
 import asyncio
 import json
 import uuid
-from datetime import datetime
 from concurrent.futures import Future
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -38,7 +37,10 @@ from crewai.utilities.constants import (
     TRAINING_DATA_FILE,
 )
 from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
-from crewai.utilities.file_handler import TaskOutputJsonHandler
+from crewai.utilities.task_output_handler import (
+    ExecutionLog,
+    TaskOutputJsonHandler,
+)
 from crewai.utilities.formatter import (
     aggregate_raw_outputs_from_task_outputs,
     aggregate_raw_outputs_from_tasks,
@@ -149,7 +151,7 @@ class Crew(BaseModel):
         default=None,
         description="List of file paths for task execution JSON files.",
     )
-    execution_logs: List[Dict[str, Any]] = Field(
+    execution_logs: List[ExecutionLog] = Field(
         default=[],
         description="List of execution logs for tasks",
     )
@@ -397,6 +399,7 @@
         self._execution_span = self._telemetry.crew_execution_span(self, inputs)
         TaskOutputJsonHandler(CREW_TASKS_OUTPUT_FILE).initialize_file()
         TaskOutputJsonHandler(CREW_TASKS_OUTPUT_FILE).reset()
+        self._logging_color = "bold_purple"
         if inputs is not None:
             self._inputs = inputs
@@ -525,10 +528,11 @@
             inputs = self._inputs
         else:
             inputs = {}
-        log = {
-            "task_id": str(task.id),
-            "expected_output": task.expected_output,
-            "output": {
+        log = ExecutionLog(
+            task_id=str(task.id),
+            expected_output=task.expected_output,
+            output={
                 "description": output.description,
                 "summary": output.summary,
                 "raw": output.raw,
@@ -537,11 +541,10 @@
                 "output_format": output.output_format,
                 "agent": output.agent,
             },
-            "timestamp": datetime.now().isoformat(),
-            "task_index": task_index,
-            "inputs": inputs,
-            "was_replayed": was_replayed,
-        }
+            task_index=task_index,
+            inputs=inputs,
+            was_replayed=was_replayed,
+        )
         if task_index < len(self.execution_logs):
             self.execution_logs[task_index] = log
         else:
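
The raw dict becomes a typed ExecutionLog, and the timestamp is no longer set by hand: the model's default_factory fills it. A minimal sketch of building one outside the crew, with illustrative field values:

from crewai.utilities.task_output_handler import ExecutionLog

log = ExecutionLog(
    task_id="d4f2c1a0",  # illustrative ID
    expected_output="A short report",
    output={"description": "...", "summary": "...", "raw": "..."},
    task_index=0,
    inputs={"topic": "AI"},
    was_replayed=False,
)
print(log.timestamp)   # filled by Field(default_factory=datetime.now)
print(log["task_id"])  # __getitem__ preserves dict-style access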
@@ -620,7 +623,7 @@
             self._log_task_start(task, agent_to_use)

             if task.async_execution:
-                context = self._set_context(
+                context = self._get_context(
                     task, [last_sync_output] if last_sync_output else []
                 )
                 future = task.execute_async(
@@ -636,7 +639,7 @@
                     )
                     futures.clear()

-                context = self._set_context(task, task_outputs)
+                context = self._get_context(task, task_outputs)
                 task_output = task.execute_sync(
                     agent=agent_to_use,
                     context=context,
@@ -652,10 +655,10 @@
         return self._create_crew_output(task_outputs)

     def _prepare_task(self, task: Task, manager: Optional[BaseAgent]):
-        if task.agent and task.agent.allow_delegation:
-            self._add_delegation_tools(task)
         if self.process == Process.hierarchical:
             self._update_manager_tools(task, manager)
+        elif task.agent and task.agent.allow_delegation:
+            self._add_delegation_tools(task)

     def _add_delegation_tools(self, task: Task):
         agents_for_delegation = [agent for agent in self.agents if agent != task.agent]
@@ -676,7 +679,7 @@
         if manager:
             manager.tools = manager.get_delegation_tools(self.agents)

-    def _set_context(self, task: Task, task_outputs: List[TaskOutput]):
+    def _get_context(self, task: Task, task_outputs: List[TaskOutput]):
         context = (
             aggregate_raw_outputs_from_tasks(task.context)
             if task.context
@@ -724,7 +727,7 @@
         return task_outputs

     def _find_task_index(
-        self, task_id: str, stored_outputs: List[Dict[str, Any]]
+        self, task_id: str, stored_outputs: List[Any]
     ) -> Optional[int]:
         return next(
             (
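
stored_outputs is loosened to List[Any] because it can hold either the plain dicts that TaskOutputJsonHandler.load() returns (json.load produces dicts) or in-memory ExecutionLog models; ExecutionLog's __getitem__ lets one subscript lookup serve both. A standalone sketch of the lookup this signature supports, not the method body verbatim:

from typing import Any, List, Optional

def find_task_index(task_id: str, stored_outputs: List[Any]) -> Optional[int]:
    # log["task_id"] works for dicts and for ExecutionLog alike.
    return next(
        (i for i, log in enumerate(stored_outputs) if log["task_id"] == task_id),
        None,
    )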
@@ -736,7 +739,7 @@
         )

     def replay_from_task(
-        self, task_id: str, inputs: Dict[str, Any] | None = None
+        self, task_id: str, inputs: Optional[Dict[str, Any]] = None
     ) -> CrewOutput:
         stored_outputs = TaskOutputJsonHandler(CREW_TASKS_OUTPUT_FILE).load()
         start_index = self._find_task_index(task_id, stored_outputs)
@@ -745,9 +748,7 @@
             raise ValueError(f"Task with id {task_id} not found in the crew's tasks.")

         replay_inputs = (
-            inputs
-            if inputs is not None
-            else stored_outputs[start_index].get("inputs", {})
+            inputs if inputs is not None else stored_outputs[start_index]["inputs"]
         )
         self._inputs = replay_inputs
@@ -771,7 +772,6 @@
         self._logging_color = "bold_blue"
         result = self._execute_tasks(self.tasks, self.manager_agent, start_index, True)
-        self._logging_color = "bold_purple"
         return result

     def copy(self):

File: crewai/utilities/file_handler.py

@@ -1,11 +1,8 @@
 import os
 import pickle
-import json
 from datetime import datetime
-from typing import Dict, Any, List
-
-from crewai.utilities.crew_json_encoder import CrewJSONEncoder

 class FileHandler:
@@ -71,52 +68,3 @@ class PickleHandler:
             return {}  # Return an empty dictionary if the file is empty or corrupted
         except Exception:
             raise  # Raise any other exceptions that occur during loading
-
-
-class TaskOutputJsonHandler:
-    def __init__(self, file_name: str) -> None:
-        self.file_path = os.path.join(os.getcwd(), file_name)
-
-    def initialize_file(self) -> None:
-        if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0:
-            with open(self.file_path, "w") as file:
-                json.dump([], file)
-
-    def update(self, task_index: int, log: Dict[str, Any]):
-        logs = self.load()
-        if task_index < len(logs):
-            logs[task_index] = log
-        else:
-            logs.append(log)
-        self.save(logs)
-
-    def save(self, logs: List[Dict[str, Any]]):
-        with open(self.file_path, "w") as file:
-            json.dump(logs, file, indent=2, cls=CrewJSONEncoder)
-
-    def reset(self):
-        """Reset the JSON file by creating an empty file."""
-        with open(self.file_path, "w") as f:
-            json.dump([], f)
-
-    def load(self) -> list:
-        try:
-            if (
-                not os.path.exists(self.file_path)
-                or os.path.getsize(self.file_path) == 0
-            ):
-                return []
-
-            with open(self.file_path, "r") as file:
-                return json.load(file)
-        except FileNotFoundError:
-            print(f"File {self.file_path} not found. Returning empty list.")
-            return []
-        except json.JSONDecodeError:
-            print(
-                f"Error decoding JSON from file {self.file_path}. Returning empty list."
-            )
-            return []
-        except Exception as e:
-            print(f"An unexpected error occurred: {e}")
-            return []

File: crewai/utilities/task_output_handler.py (new file)

@@ -0,0 +1,69 @@
+import json
+import os
+from pydantic import BaseModel, Field
+from datetime import datetime
+from typing import Dict, Any, Optional, List
+
+from crewai.utilities.crew_json_encoder import CrewJSONEncoder
+
+
+class ExecutionLog(BaseModel):
+    task_id: str
+    expected_output: Optional[str] = None
+    output: Dict[str, Any]
+    timestamp: datetime = Field(default_factory=datetime.now)
+    task_index: int
+    inputs: Dict[str, Any] = Field(default_factory=dict)
+    was_replayed: bool = False
+
+    def __getitem__(self, key: str) -> Any:
+        return getattr(self, key)
+
+
+class TaskOutputJsonHandler:
+    def __init__(self, file_name: str) -> None:
+        self.file_path = os.path.join(os.getcwd(), file_name)
+
+    def initialize_file(self) -> None:
+        if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0:
+            with open(self.file_path, "w") as file:
+                json.dump([], file)
+
+    def update(self, task_index: int, log: ExecutionLog):
+        logs = self.load()
+        if task_index < len(logs):
+            logs[task_index] = log
+        else:
+            logs.append(log)
+        self.save(logs)
+
+    def save(self, logs: List[ExecutionLog]):
+        with open(self.file_path, "w") as file:
+            json.dump(logs, file, indent=2, cls=CrewJSONEncoder)
+
+    def reset(self):
+        """Reset the JSON file by creating an empty file."""
+        with open(self.file_path, "w") as f:
+            json.dump([], f)
+
+    def load(self) -> List[ExecutionLog]:
+        try:
+            if (
+                not os.path.exists(self.file_path)
+                or os.path.getsize(self.file_path) == 0
+            ):
+                return []
+
+            with open(self.file_path, "r") as file:
+                return json.load(file)
+        except FileNotFoundError:
+            print(f"File {self.file_path} not found. Returning empty list.")
+            return []
+        except json.JSONDecodeError:
+            print(
+                f"Error decoding JSON from file {self.file_path}. Returning empty list."
+            )
+            return []
+        except Exception as e:
+            print(f"An unexpected error occurred: {e}")
+            return []
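
A minimal round trip with the new handler, continuing the ExecutionLog sketch above. The file name is illustrative, and this assumes CrewJSONEncoder can serialize Pydantic models and datetimes, as its use in save() implies. Note that load() hands back the plain dicts json.load produces, which is why replay code indexes entries instead of reading attributes:

from crewai.utilities.task_output_handler import ExecutionLog, TaskOutputJsonHandler

handler = TaskOutputJsonHandler("crew_tasks_output.json")  # illustrative file name
handler.initialize_file()

log = ExecutionLog(task_id="d4f2c1a0", output={"raw": "..."}, task_index=0)
handler.update(0, log)  # replaces the entry at task_index 0, or appends

for entry in handler.load():  # entries come back as plain dicts
    print(entry["task_id"], entry["was_replayed"])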