WIP: replay working with async. need to add tests

Lorenze Jay
2024-07-08 21:59:00 -07:00
parent 5c04c63127
commit 92fca9bbe9
4 changed files with 201 additions and 113 deletions

src/crewai/crew.py

@@ -35,6 +35,7 @@ from crewai.utilities import I18N, FileHandler, Logger, RPMController
 from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
 from crewai.utilities.crew_json_encoder import CrewJSONEncoder
 from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
+from crewai.utilities.file_handler import TaskOutputJsonHandler
 from crewai.utilities.formatter import aggregate_raw_outputs_from_task_outputs
 from crewai.utilities.training_handler import CrewTrainingHandler
@@ -73,6 +74,7 @@ class Crew(BaseModel):
     _rpm_controller: RPMController = PrivateAttr()
     _logger: Logger = PrivateAttr()
     _file_handler: FileHandler = PrivateAttr()
+    _task_output_handler: TaskOutputJsonHandler = PrivateAttr()
     _cache_handler: InstanceOf[CacheHandler] = PrivateAttr(default=CacheHandler())
     _short_term_memory: Optional[InstanceOf[ShortTermMemory]] = PrivateAttr()
     _long_term_memory: Optional[InstanceOf[LongTermMemory]] = PrivateAttr()
@@ -176,6 +178,7 @@ class Crew(BaseModel):
         self._logger = Logger(self.verbose)
         if self.output_log_file:
             self._file_handler = FileHandler(self.output_log_file)
+        self._task_output_handler = TaskOutputJsonHandler(self._log_file)
         self._rpm_controller = RPMController(max_rpm=self.max_rpm, logger=self._logger)
         self._telemetry = Telemetry()
         self._telemetry.set_tracer()
@@ -329,9 +332,10 @@
     ) -> CrewOutput:
         """Starts the crew to work on its assigned tasks."""
         self._execution_span = self._telemetry.crew_execution_span(self, inputs)
+        self.execution_logs = []
         if inputs is not None:
             self._interpolate_inputs(inputs)
-            self._interpolate_inputs(inputs)
+            # self._interpolate_inputs(inputs)
         self._set_tasks_callbacks()

         i18n = I18N(prompt_file=self.prompt_file)
@@ -355,7 +359,7 @@
         metrics = []
         if self.process == Process.sequential:
-            result = self._run_sequential_process()
+            result = self._run_sequential_process(inputs)
         elif self.process == Process.hierarchical:
             result, manager_metrics = self._run_hierarchical_process()  # type: ignore # Incompatible types in assignment (expression has type "str | dict[str, Any]", variable has type "str")
             metrics.append(manager_metrics)
@@ -448,25 +452,52 @@
         return results

-    def _store_execution_log(self, task, output, task_index):
-        print("output passeed in", output)
+    def _store_execution_log(self, task, output, task_index, inputs=None):
         log = {
             "task_id": str(task.id),
             "description": task.description,
+            "expected_output": task.expected_output,
             "agent_role": task.agent.role if task.agent else "None",
-            "output": output,
+            "output": {
+                "description": task.description,
+                "summary": task.description,
+                "raw_output": output.raw_output,
+                "pydantic_output": output.pydantic_output,
+                "json_output": output.json_output,
+                "agent": task.agent.role if task.agent else "None",
+            },
             "timestamp": datetime.now().isoformat(),
             "task_index": task_index,
+            # "output_py": output.pydantic_output,
+            "inputs": inputs,
+            # "task": task.model_dump(),
         }
         self.execution_logs.append(log)
-        print("execution_logs", self.execution_logs)
+        self._task_output_handler.append(log)

-    def _run_sequential_process(self) -> CrewOutput:
+    def _run_sequential_process(
+        self, inputs: Dict[str, Any] | None = None
+    ) -> CrewOutput:
         """Executes tasks sequentially and returns the final output."""
+        self.execution_logs = []
+        task_outputs = self._execute_tasks(self.tasks, inputs=inputs)
+        final_string_output = aggregate_raw_outputs_from_task_outputs(task_outputs)
+        self._finish_execution(final_string_output)
+        self.save_execution_logs()
+        token_usage = self.calculate_usage_metrics()
+        return self._format_output(task_outputs, token_usage)
+
+    def _execute_tasks(
+        self,
+        tasks,
+        start_index=0,
+        is_replay=False,
+        inputs: Dict[str, Any] | None = None,
+    ):
         task_outputs: List[TaskOutput] = []
         futures: List[Tuple[Task, Future[TaskOutput]]] = []
-        # execution_logs: List[Dict[str, Any]] = []
-        for task_index, task in enumerate(self.tasks):
+        for task_index, task in enumerate(tasks[start_index:], start=start_index):
             if task.agent and task.agent.allow_delegation:
                 agents_for_delegation = [
                     agent for agent in self.agents if agent != task.agent
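For reference, a record appended by the new _store_execution_log would serialize roughly as follows. This is a hedged sketch: the field values are invented for illustration, and pydantic_output/json_output stay None unless the task produced them.

# Illustrative only: approximate shape of one record written by
# _store_execution_log above. Values are made up, not from a real run.
{
    "task_id": "b24f0f47-0000-0000-0000-000000000000",  # hypothetical UUID string
    "description": "Research the topic",
    "expected_output": "A bullet-point summary",
    "agent_role": "Researcher",
    "output": {
        "description": "Research the topic",
        "summary": "Research the topic",  # WIP: summary currently mirrors description
        "raw_output": "1. Finding one...",
        "pydantic_output": None,
        "json_output": None,
        "agent": "Researcher",
    },
    "timestamp": "2024-07-08T21:59:00",
    "task_index": 0,
    "inputs": {"topic": "AI"},
}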
@@ -475,9 +506,15 @@
                 task.tools += task.agent.get_delegation_tools(agents_for_delegation)

             role = task.agent.role if task.agent is not None else "None"
-            self._logger.log("debug", f"== Working Agent: {role}", color="bold_purple")
+            log_prefix = "== Replaying from" if is_replay else "=="
+            log_color = "bold_blue" if is_replay else "bold_purple"
             self._logger.log(
-                "info", f"== Starting Task: {task.description}", color="bold_purple"
+                "debug", f"{log_prefix} Working Agent: {role}", color=log_color
+            )
+            self._logger.log(
+                "info",
+                f"{log_prefix} {'Replaying' if is_replay else 'Starting'} Task: {task.description}",
+                color=log_color,
             )

             if self.output_log_file:
@@ -492,19 +529,10 @@
                 )
                 futures.append((task, future))
             else:
-                # Before executing a synchronous task, wait for all async tasks to complete
                 if futures:
-                    print("futures for sync task", futures)
-                    # Clear task_outputs before processing async tasks
-                    task_outputs = []
-                    for future_task, future in futures:
-                        task_output = future.result()
-                        task_outputs.append(task_output)
-                        self._store_execution_log(future_task, task_output, task_index)
-                        self._process_task_result(future_task, task_output)
-
-                    # Clear the futures list after processing all async results
+                    task_outputs = self._process_async_tasks(
+                        futures, task_index, inputs
+                    )
                     futures.clear()

                 context = aggregate_raw_outputs_from_task_outputs(task_outputs)
@@ -513,26 +541,12 @@
                 )
                 task_outputs = [task_output]
                 self._process_task_result(task, task_output)
-                self._store_execution_log(task, task_output, task_index)
+                self._store_execution_log(task, task_output, task_index, inputs)

         if futures:
-            print("there are some async tasks we need to eecute in the future", futures)
-            # Clear task_outputs before processing async tasks
-            task_outputs = self._process_async_tasks(futures, len(self.tasks))
-            print("task_outputs from futures", task_outputs)
-            # task_outputs = []
-            # for future_task, future in futures:
-            #     task_output = future.result()
-            #     task_outputs.append(task_output)
-            #     self._process_task_result(future_task, task_output)
+            task_outputs = self._process_async_tasks(futures, len(tasks), inputs)

-        final_string_output = aggregate_raw_outputs_from_task_outputs(task_outputs)
-        self._finish_execution(final_string_output)
-        print("self.execution_logs", self.execution_logs)
-        self.save_execution_logs()
-        token_usage = self.calculate_usage_metrics()
-        return self._format_output(task_outputs, token_usage)
+        return task_outputs

     def _process_task_result(self, task: Task, output: TaskOutput) -> None:
         role = task.agent.role if task.agent is not None else "None"
@@ -544,42 +558,19 @@
         self,
         futures: List[Tuple[Task, Future[TaskOutput]]],
         task_index: int,
+        inputs: Dict[str, Any] | None = None,
     ) -> List[TaskOutput]:
         task_outputs = []
         for future_task, future in futures:
             task_output = future.result()
             task_outputs.append(task_output)
             self._process_task_result(future_task, task_output)
-            self._store_execution_log(future_task, task_output, task_index)
+            self._store_execution_log(future_task, task_output, task_index, inputs)
         return task_outputs

-    def _create_execution_log(
-        self, task: Task, output: TaskOutput, task_index: int
-    ) -> Dict[str, Any]:
-        return {
-            "task_id": str(task.id),
-            "task_index": task_index,
-            "task_description": task.description,
-            "agent_role": task.agent.role if task.agent else "None",
-            "output": output.raw_output,
-            "timestamp": datetime.now().isoformat(),
-            "task": task.model_dump(),
-        }
-
-    def replay_from_task(self, task_id: UUID4, use_stored_logs: bool = False):
-        """Replay execution from a specific task and continue through subsequent tasks."""
-        task_outputs: List[TaskOutput] = []
-        futures: List[Tuple[Task, Future[TaskOutput]]] = []
-        execution_logs: List[Dict[str, Any]] = []
-        if use_stored_logs:
-            self.load_execution_logs()
-
-        # Load the task outputs from the crew_tasks_output.json file
-        with open("crew_tasks_output.json", "r") as f:
-            stored_outputs = json.load(f)
-
-        # Find the index of the task with the given task_id
+    def replay_from_task(self, task_id: str):
+        stored_outputs = self._load_stored_outputs()
         start_index = next(
             (
                 index
@@ -588,47 +579,107 @@
             ),
             None,
         )

         if start_index is None:
-            raise ValueError(f"Task with id {task_id} not found in the task outputs.")
+            raise ValueError(f"Task with id {task_id} not found in the crew's tasks.")

-        # Run the tasks sequentially starting from the task_id
-        for task_index, stored_output in enumerate(stored_outputs[start_index:]):
-            task = Task(**stored_output["task"])
-            if task.async_execution:
-                context = aggregate_raw_outputs_from_task_outputs(stored_outputs)
-                future = task.execute_async(
-                    agent=task.agent, context=context, tools=task.tools
-                )
-                futures.append((task, future))
-            else:
-                # Before executing a synchronous task, wait for all async tasks to complete
-                if futures:
-                    print("futures for sync task", futures)
-                    # Clear task_outputs before processing async tasks
-                    task_outputs = []
-                    for future_task, future in futures:
-                        task_output = future.result()
-                        task_outputs.append(task_output)
-                        execution_logs.append(
-                            self._create_execution_log(
-                                future_task, task_output, task_index
-                            )
-                        )
-                        self._process_task_result(future_task, task_output)
-
-                    # Clear the futures list after processing all async results
-                    futures.clear()
-
-                context = aggregate_raw_outputs_from_task_outputs(task_outputs)
-                task_output = task.execute_sync(
-                    agent=task.agent, context=context, tools=task.tools
-                )
-                task_outputs = [task_output]
-                self._process_task_result(task, task_output)
-                execution_logs.append(
-                    self._create_execution_log(task, task_output, task_index)
-                )
+        # Create a map of task ID to stored output
+        stored_output_map: Dict[str, dict] = {
+            log["task_id"]: log["output"] for log in stored_outputs
+        }
+
+        task_outputs: List[
+            TaskOutput
+        ] = []  # propagate the old outputs first to add context, then fill in the new task outputs from the replay start onward
+        futures: List[Tuple[Task, Future[TaskOutput]]] = []
+        context = ""
+
+        inputs = stored_outputs[start_index].get("inputs", {})
+        if inputs is not None:
+            self._interpolate_inputs(inputs)
+
+        for task_index, task in enumerate(self.tasks):
+            if task_index < start_index:
+                # Use stored output for tasks before the replay point
+                if task.id in stored_output_map:
+                    stored_output = stored_output_map[task.id]
+                    task_output = TaskOutput(
+                        description=stored_output["description"],
+                        raw_output=stored_output["raw_output"],
+                        pydantic_output=stored_output["pydantic_output"],
+                        json_output=stored_output["json_output"],
+                        agent=stored_output["agent"],
+                    )
+                    task_outputs.append(task_output)
+                    context += (
+                        f"\nTask {task_index + 1} Output:\n{task_output.raw_output}"
+                    )
+            else:
+                role = task.agent.role if task.agent is not None else "None"
+                log_color = "bold_blue"
+                self._logger.log(
+                    "debug", f"Replaying Working Agent: {role}", color=log_color
+                )
+                self._logger.log(
+                    "info",
+                    f"Replaying Task: {task.description}",
+                    color=log_color,
+                )

+                if self.output_log_file:
+                    self._file_handler.log(
+                        agent=role, task=task.description, status="started"
+                    )
+
+                # Execute task for replay and subsequent tasks
+                if task.async_execution:
+                    future = task.execute_async(
+                        agent=task.agent, context=context, tools=task.tools
+                    )
+                    futures.append((task, future))
+                else:
+                    if futures:
+                        async_outputs = self._process_async_tasks(
+                            futures, task_index, inputs
+                        )
+                        task_outputs.extend(async_outputs)
+                        for output in async_outputs:
+                            context += (
+                                f"\nTask {task_index + 1} Output:\n{output.raw_output}"
+                            )
+                        futures.clear()
+
+                    task_output = task.execute_sync(
+                        agent=task.agent, context=context, tools=task.tools
+                    )
+                    task_outputs.append(task_output)
+                    self._process_task_result(task, task_output)
+                    self._store_execution_log(task, task_output, task_index, inputs)
+                    context += (
+                        f"\nTask {task_index + 1} Output:\n{task_output.raw_output}"
+                    )
+
+        # Process any remaining async tasks
+        if futures:
+            async_outputs = self._process_async_tasks(futures, len(self.tasks), inputs)
+            task_outputs.extend(async_outputs)
+
+        # Calculate usage metrics
+        token_usage = self.calculate_usage_metrics()
+
+        # Format and return the final output
+        return self._format_output(task_outputs, token_usage)
+
+    def _load_stored_outputs(self) -> List[Dict]:
+        try:
+            with open(self._log_file, "r") as f:
+                return json.load(f)
+        except FileNotFoundError:
+            self._logger.log(
+                "warning",
+                f"Log file {self._log_file} not found. Starting with empty logs.",
+            )
+            return []
+        except json.JSONDecodeError:
+            self._logger.log(
+                "error",
+                f"Failed to parse log file {self._log_file}. Starting with empty logs.",
+            )
+            return []
     def save_execution_logs(self, filename: str | None = None):
         """Save execution logs to a file."""
@@ -795,7 +846,7 @@
         """
         return CrewOutput(
             output=output,
-            tasks_output=[task.output for task in self.tasks if task],
+            tasks_output=[task.output for task in self.tasks if task and task.output],
             token_usage=token_usage,
         )
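End to end, the flow this commit is working toward looks roughly like the sketch below. It is a hypothetical driver script, not part of the diff: the agent and task definitions are placeholders, and it assumes kickoff() persists each task's log via TaskOutputJsonHandler so replay_from_task() can read the same file back.

# Hypothetical usage sketch of the replay flow above; only kickoff() and
# replay_from_task() come from this commit, everything else is placeholder.
from crewai import Agent, Crew, Task

researcher = Agent(role="Researcher", goal="Find facts", backstory="...")
summarize = Task(
    description="Summarize findings on {topic}",
    expected_output="A short summary",
    agent=researcher,
)
crew = Crew(agents=[researcher], tasks=[summarize])

crew.kickoff(inputs={"topic": "AI"})  # each task is logged via _store_execution_log

# Re-run from a given task: earlier tasks are rebuilt from stored outputs
# and fed in as context instead of being executed again.
crew.replay_from_task(task_id=str(summarize.id))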

src/crewai/utilities/crew_json_encoder.py

@@ -1,15 +1,10 @@
-import json
 from datetime import datetime
+import json
 from uuid import UUID
-from pydantic import BaseModel
+
+from openai import BaseModel


 class CrewJSONEncoder(json.JSONEncoder):
-    """
-    Custom JSON Encoder for Crew related objects.
-    """
-
     def default(self, obj):
         if isinstance(obj, datetime):
             return obj.isoformat()
@@ -17,6 +12,6 @@ class CrewJSONEncoder(json.JSONEncoder):
             return str(obj)
         if isinstance(obj, BaseModel):
             return obj.model_dump()
-        if isinstance(obj, set):
-            return list(obj)
-        return super().default(obj)
+        if hasattr(obj, "__dict__"):
+            return obj.__dict__
+        return str(obj)
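A minimal sketch of what the widened fallback buys: objects json cannot serialize natively now round-trip through __dict__, with nested datetimes handled recursively by the datetime branch. The Example class is hypothetical.

# Minimal sketch: CrewJSONEncoder now falls back to __dict__ (then str)
# instead of raising for unknown types. Example is a hypothetical class.
import json
from datetime import datetime

from crewai.utilities.crew_json_encoder import CrewJSONEncoder

class Example:
    def __init__(self):
        self.raw_output = "done"
        self.finished_at = datetime.now()  # encoded via the datetime branch

print(json.dumps({"output": Example()}, cls=CrewJSONEncoder, indent=2))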

src/crewai/utilities/file_handler.py

@@ -1,6 +1,9 @@
 import os
 import pickle
 from datetime import datetime
+import json
+
+from crewai.utilities.crew_json_encoder import CrewJSONEncoder


 class FileHandler:
@@ -66,3 +69,37 @@ class PickleHandler:
             return {}  # Return an empty dictionary if the file is empty or corrupted
         except Exception:
             raise  # Raise any other exceptions that occur during loading
+
+
+class TaskOutputJsonHandler:
+    def __init__(self, file_name: str) -> None:
+        self.file_path = os.path.join(os.getcwd(), file_name)
+
+    def initialize_file(self) -> None:
+        if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0:
+            with open(self.file_path, "w") as file:
+                json.dump([], file)
+
+    def append(self, log) -> None:
+        if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0:
+            # Initialize the file with an empty list if it doesn't exist or is empty
+            with open(self.file_path, "w") as file:
+                json.dump([], file)
+
+        with open(self.file_path, "r+") as file:
+            try:
+                file_data = json.load(file)
+            except json.JSONDecodeError:
+                # If the file contains invalid JSON, initialize it with an empty list
+                file_data = []
+            file_data.append(log)
+            file.seek(0)
+            json.dump(file_data, file, indent=2, cls=CrewJSONEncoder)
+            file.truncate()
+
+    def load(self) -> list:
+        if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0:
+            return []
+        with open(self.file_path, "r") as file:
+            return json.load(file)
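A minimal usage sketch of the new handler. The file name here is only an example; the Crew passes in its own log file name.

# Minimal sketch: append two records, then read them back.
# "crew_tasks_output.json" is an example name, not mandated by the diff.
from crewai.utilities.file_handler import TaskOutputJsonHandler

handler = TaskOutputJsonHandler("crew_tasks_output.json")
handler.initialize_file()  # ensures the file contains a JSON list
handler.append({"task_id": "1", "output": "first"})
handler.append({"task_id": "2", "output": "second"})
print(handler.load())  # -> list of the two dicts, oldest first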

src/crewai/utilities/printer.py

@@ -8,6 +8,8 @@ class Printer:
             self._print_bold_green(content)
         elif color == "bold_purple":
             self._print_bold_purple(content)
+        elif color == "bold_blue":
+            self._print_bold_blue(content)
         else:
             print(content)
@@ -22,3 +24,6 @@ class Printer:
     def _print_red(self, content):
         print("\033[91m {}\033[00m".format(content))
+
+    def _print_bold_blue(self, content):
+        print("\033[1m\033[94m {}\033[00m".format(content))