using SQLite instead of .json file for logging previous task_outputs

This commit is contained in:
Lorenze Jay
2024-07-13 18:06:38 -07:00
parent 96af6027bd
commit 9eefa312ae
8 changed files with 297 additions and 149 deletions

View File

@@ -1,6 +1,10 @@
import click
import pkg_resources
+from crewai.memory.storage.kickoff_task_outputs_storage import (
+    KickoffTaskOutputsSQLiteStorage,
+)
from .create_crew import create_crew
from .train_crew import train_crew
@@ -71,5 +75,27 @@ def replay(task_id: str) -> None:
click.echo(f"An error occurred while replaying: {e}", err=True)
@crewai.command()
def log_tasks_outputs() -> None:
"""
Log your previously ran kickoff task outputs.
"""
try:
storage = KickoffTaskOutputsSQLiteStorage()
tasks = storage.load()
if not tasks:
click.echo("No task outputs found.")
return
for index, task in enumerate(tasks, 1):
click.echo(f"Task {index}: {task['task_id']}")
click.echo(f"Description: {task['expected_output']}")
click.echo("------")
except Exception as e:
click.echo(f"An error occurred while logging task outputs: {e}", err=True)
if __name__ == "__main__":
crewai()
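A minimal sketch of exercising the new subcommand through Click's test runner, so no console-script install is needed. The module path crewai.cli.cli and the dashed command name are assumptions, not taken from this commit (Click derives default command names from function names, replacing underscores with dashes):

from click.testing import CliRunner

from crewai.cli.cli import crewai  # assumed module path for the CLI group above

runner = CliRunner()
result = runner.invoke(crewai, ["log-tasks-outputs"])  # hypothetical invocation
print(result.output)  # prints "No task outputs found." against a fresh database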

View File

@@ -9,7 +9,7 @@ def replay_task_command(task_id: str) -> None:
    Args:
        task_id (str): The ID of the task to replay from.
    """
-    command = ["poetry", "run", "replay_from_task", task_id]
+    command = ["poetry", "run", "replay", task_id]
    try:
        result = subprocess.run(command, capture_output=False, text=True, check=True)

View File

@@ -32,15 +32,12 @@ from crewai.telemetry import Telemetry
from crewai.tools.agent_tools import AgentTools
from crewai.utilities import I18N, FileHandler, Logger, RPMController
from crewai.utilities.constants import (
-    CREW_TASKS_OUTPUT_FILE,
    TRAINED_AGENTS_DATA_FILE,
    TRAINING_DATA_FILE,
)
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
-from crewai.utilities.task_output_handler import (
-    ExecutionLog,
-    TaskOutputJsonHandler,
-)
+from crewai.utilities.task_output_storage_handler import TaskOutputStorageHandler
from crewai.utilities.formatter import (
    aggregate_raw_outputs_from_task_outputs,
    aggregate_raw_outputs_from_tasks,
@@ -92,6 +89,9 @@ class Crew(BaseModel):
    _logging_color: str = PrivateAttr(
        default="bold_purple",
    )
+    _task_output_handler: TaskOutputStorageHandler = PrivateAttr(
+        default_factory=TaskOutputStorageHandler
+    )
    cache: bool = Field(default=True)
    model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -151,7 +151,7 @@ class Crew(BaseModel):
        default=None,
        description="List of file paths for task execution JSON files.",
    )
-    execution_logs: List[ExecutionLog] = Field(
+    execution_logs: List[Dict[str, Any]] = Field(
        default=[],
        description="List of execution logs for tasks",
    )
@@ -397,8 +397,7 @@ class Crew(BaseModel):
    ) -> CrewOutput:
        """Starts the crew to work on its assigned tasks."""
        self._execution_span = self._telemetry.crew_execution_span(self, inputs)
-        TaskOutputJsonHandler(CREW_TASKS_OUTPUT_FILE).initialize_file()
-        TaskOutputJsonHandler(CREW_TASKS_OUTPUT_FILE).reset()
+        self._task_output_handler.reset()
        self._logging_color = "bold_purple"

        if inputs is not None:
@@ -529,10 +528,9 @@ class Crew(BaseModel):
        else:
            inputs = {}
-        log = ExecutionLog(
-            task_id=str(task.id),
-            expected_output=task.expected_output,
-            output={
+        log = {
+            "task": task,
+            "output": {
                "description": output.description,
                "summary": output.summary,
                "raw": output.raw,
@@ -541,16 +539,11 @@ class Crew(BaseModel):
"output_format": output.output_format,
"agent": output.agent,
},
task_index=task_index,
inputs=inputs,
was_replayed=was_replayed,
)
if task_index < len(self.execution_logs):
self.execution_logs[task_index] = log
else:
self.execution_logs.append(log)
TaskOutputJsonHandler(CREW_TASKS_OUTPUT_FILE).update(task_index, log)
"task_index": task_index,
"inputs": inputs,
"was_replayed": was_replayed,
}
self._task_output_handler.update(task_index, log)
def _run_sequential_process(self) -> CrewOutput:
"""Executes tasks sequentially and returns the final output."""
@@ -741,7 +734,11 @@ class Crew(BaseModel):
    def replay_from_task(
        self, task_id: str, inputs: Optional[Dict[str, Any]] = None
    ) -> CrewOutput:
-        stored_outputs = TaskOutputJsonHandler(CREW_TASKS_OUTPUT_FILE).load()
+        stored_outputs = self._task_output_handler.load()
+        # TODO: write tests for this
+        if not stored_outputs:
+            raise ValueError("No task outputs found. Run the crew with kickoff() before replaying.")

        start_index = self._find_task_index(task_id, stored_outputs)
        if start_index is None:
@@ -759,7 +756,9 @@ class Crew(BaseModel):
            self._create_manager_agent()

        for i in range(start_index):
-            stored_output = stored_outputs[i]["output"]
+            stored_output = stored_outputs[i][
+                "output"
+            ]  # for adding context to the task
            task_output = TaskOutput(
                description=stored_output["description"],
                agent=stored_output["agent"],

View File

@@ -0,0 +1,170 @@
import json
import sqlite3
from typing import Any, Dict, List, Optional

from crewai.task import Task
from crewai.utilities import Printer
from crewai.utilities.crew_json_encoder import CrewJSONEncoder
from crewai.utilities.paths import db_storage_path


class KickoffTaskOutputsSQLiteStorage:
    """
    SQLite storage for the latest kickoff task outputs.
    """

    def __init__(
        self, db_path: str = f"{db_storage_path()}/kickoff_task_outputs.db"
    ) -> None:
        self.db_path = db_path
        self._printer: Printer = Printer()
        self._initialize_db()
    def _initialize_db(self):
        """
        Initializes the SQLite database and creates the latest_kickoff_task_outputs table.
        """
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                cursor.execute(
                    """
                    CREATE TABLE IF NOT EXISTS latest_kickoff_task_outputs (
                        task_id TEXT PRIMARY KEY,
                        expected_output TEXT,
                        output JSON,
                        task_index INTEGER,
                        inputs JSON,
                        was_replayed BOOLEAN,
                        timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
                    )
                    """
                )
                conn.commit()
        except sqlite3.Error as e:
            self._printer.print(
                content=f"SAVING KICKOFF TASK OUTPUTS ERROR: An error occurred during database initialization: {e}",
                color="red",
            )
    def add(
        self,
        task: Task,
        output: Dict[str, Any],
        task_index: int,
        was_replayed: bool = False,
        inputs: Dict[str, Any] = {},
    ):
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                cursor.execute(
                    """
                    INSERT OR REPLACE INTO latest_kickoff_task_outputs
                    (task_id, expected_output, output, task_index, inputs, was_replayed)
                    VALUES (?, ?, ?, ?, ?, ?)
                    """,
                    (
                        str(task.id),
                        task.expected_output,
                        json.dumps(output, cls=CrewJSONEncoder),
                        task_index,
                        json.dumps(inputs),
                        was_replayed,
                    ),
                )
                conn.commit()
        except sqlite3.Error as e:
            self._printer.print(
                content=f"SAVING KICKOFF TASK OUTPUTS ERROR: An error occurred while saving task outputs: {e}",
                color="red",
            )
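Because task_id is the table's primary key, INSERT OR REPLACE gives add() upsert semantics: re-inserting the same task id deletes the old row and writes a fresh one (which also resets the timestamp default). A minimal stdlib sketch of that behavior:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE t (task_id TEXT PRIMARY KEY, task_index INTEGER)")
conn.execute("INSERT OR REPLACE INTO t VALUES (?, ?)", ("abc", 0))
conn.execute("INSERT OR REPLACE INTO t VALUES (?, ?)", ("abc", 5))  # replaces the row
print(conn.execute("SELECT * FROM t").fetchall())  # [('abc', 5)] — one row, not two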
    def update(
        self,
        task_index: int,
        **kwargs,
    ):
        """
        Updates an existing row in the latest_kickoff_task_outputs table based on task_index.
        """
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()

                fields = []
                values = []
                for key, value in kwargs.items():
                    fields.append(f"{key} = ?")
                    values.append(
                        json.dumps(value, cls=CrewJSONEncoder)
                        if isinstance(value, dict)
                        else value
                    )

                query = f"UPDATE latest_kickoff_task_outputs SET {', '.join(fields)} WHERE task_index = ?"
                values.append(task_index)
                cursor.execute(query, tuple(values))
                conn.commit()

                if cursor.rowcount == 0:
                    self._printer.print(
                        f"No row found with task_index {task_index}. No update performed.",
                        color="yellow",
                    )
                else:
                    self._printer.print(
                        f"Updated row with task_index {task_index}.", color="green"
                    )
        except sqlite3.Error as e:
            self._printer.print(f"UPDATE ERROR: {e}", color="red")
    def load(self) -> Optional[List[Dict[str, Any]]]:
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                cursor.execute("""
                    SELECT task_id, expected_output, output, task_index, inputs, was_replayed, timestamp
                    FROM latest_kickoff_task_outputs
                    ORDER BY task_index
                """)
                rows = cursor.fetchall()
                results = []
                for row in rows:
                    result = {
                        "task_id": row[0],
                        "expected_output": row[1],
                        "output": json.loads(row[2]),
                        "task_index": row[3],
                        "inputs": json.loads(row[4]),
                        "was_replayed": row[5],
                        "timestamp": row[6],
                    }
                    results.append(result)
                return results
        except sqlite3.Error as e:
            self._printer.print(
                content=f"LOADING KICKOFF TASK OUTPUTS ERROR: An error occurred while querying kickoff task outputs: {e}",
                color="red",
            )
            return None
    def delete_all(self):
        """
        Deletes all rows from the latest_kickoff_task_outputs table.
        """
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                cursor.execute("DELETE FROM latest_kickoff_task_outputs")
                conn.commit()
        except sqlite3.Error as e:
            self._printer.print(
                content=f"ERROR: Failed to delete all kickoff task outputs: {e}",
                color="red",
            )
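Putting the storage class together, a hedged round-trip sketch against a throwaway database path (the db_path keyword comes from the constructor above; the Agent/Task construction is illustrative, and the top-level crewai exports are assumed):

import os
import tempfile

from crewai import Agent, Task  # assumed top-level exports
from crewai.memory.storage.kickoff_task_outputs_storage import (
    KickoffTaskOutputsSQLiteStorage,
)

db = os.path.join(tempfile.mkdtemp(), "kickoff_task_outputs.db")
storage = KickoffTaskOutputsSQLiteStorage(db_path=db)

agent = Agent(role="Writer", goal="Write.", backstory="A writer.")
task = Task(description="Draft.", expected_output="Copy.", agent=agent)
storage.add(task, output={"raw": "hello"}, task_index=0, inputs={"topic": "x"})

rows = storage.load()
print(rows[0]["task_id"], rows[0]["output"])  # task id plus the deserialized output dict
storage.delete_all()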

View File

@@ -1,3 +1,2 @@
TRAINING_DATA_FILE = "training_data.pkl"
TRAINED_AGENTS_DATA_FILE = "trained_agents_data.pkl"
-CREW_TASKS_OUTPUT_FILE = "crew_tasks_output.json"

View File

@@ -1,69 +0,0 @@
import json
import os
from pydantic import BaseModel, Field
from datetime import datetime
from typing import Dict, Any, Optional, List
from crewai.utilities.crew_json_encoder import CrewJSONEncoder


class ExecutionLog(BaseModel):
    task_id: str
    expected_output: Optional[str] = None
    output: Dict[str, Any]
    timestamp: datetime = Field(default_factory=datetime.now)
    task_index: int
    inputs: Dict[str, Any] = Field(default_factory=dict)
    was_replayed: bool = False

    def __getitem__(self, key: str) -> Any:
        return getattr(self, key)


class TaskOutputJsonHandler:
    def __init__(self, file_name: str) -> None:
        self.file_path = os.path.join(os.getcwd(), file_name)

    def initialize_file(self) -> None:
        if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0:
            with open(self.file_path, "w") as file:
                json.dump([], file)

    def update(self, task_index: int, log: ExecutionLog):
        logs = self.load()
        if task_index < len(logs):
            logs[task_index] = log
        else:
            logs.append(log)
        self.save(logs)

    def save(self, logs: List[ExecutionLog]):
        with open(self.file_path, "w") as file:
            json.dump(logs, file, indent=2, cls=CrewJSONEncoder)

    def reset(self):
        """Reset the JSON file by creating an empty file."""
        with open(self.file_path, "w") as f:
            json.dump([], f)

    def load(self) -> List[ExecutionLog]:
        try:
            if (
                not os.path.exists(self.file_path)
                or os.path.getsize(self.file_path) == 0
            ):
                return []
            with open(self.file_path, "r") as file:
                return json.load(file)
        except FileNotFoundError:
            print(f"File {self.file_path} not found. Returning empty list.")
            return []
        except json.JSONDecodeError:
            print(
                f"Error decoding JSON from file {self.file_path}. Returning empty list."
            )
            return []
        except Exception as e:
            print(f"An unexpected error occurred: {e}")
            return []

View File

@@ -0,0 +1,61 @@
from pydantic import BaseModel, Field
from datetime import datetime
from typing import Dict, Any, Optional, List
from crewai.memory.storage.kickoff_task_outputs_storage import (
    KickoffTaskOutputsSQLiteStorage,
)
from crewai.task import Task


class ExecutionLog(BaseModel):
    task_id: str
    expected_output: Optional[str] = None
    output: Dict[str, Any]
    timestamp: datetime = Field(default_factory=datetime.now)
    task_index: int
    inputs: Dict[str, Any] = Field(default_factory=dict)
    was_replayed: bool = False

    def __getitem__(self, key: str) -> Any:
        return getattr(self, key)


class TaskOutputStorageHandler:
    def __init__(self) -> None:
        self.storage = KickoffTaskOutputsSQLiteStorage()

    def update(self, task_index: int, log: Dict[str, Any]):
        saved_outputs = self.load()
        if saved_outputs is None:
            raise ValueError("Logs cannot be None")

        if log.get("was_replayed", False):
            replayed = {
                "task_id": str(log["task"].id),
                "expected_output": log["task"].expected_output,
                "output": log["output"],
                "was_replayed": log["was_replayed"],
                "inputs": log["inputs"],
            }
            self.storage.update(
                task_index,
                **replayed,
            )
        else:
            self.storage.add(**log)

    def add(
        self,
        task: Task,
        output: Dict[str, Any],
        task_index: int,
        inputs: Dict[str, Any] = {},
        was_replayed: bool = False,
    ):
        self.storage.add(task, output, task_index, was_replayed, inputs)

    def reset(self):
        self.storage.delete_all()

    def load(self) -> Optional[List[Dict[str, Any]]]:
        return self.storage.load()
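A short sketch of the handler's dispatch rule above: logs flagged was_replayed patch the existing row by task_index via storage.update(), while everything else is upserted through storage.add(). The Agent/Task construction is illustrative:

from crewai import Agent, Task  # assumed top-level exports
from crewai.utilities.task_output_storage_handler import TaskOutputStorageHandler

agent = Agent(role="Writer", goal="Write.", backstory="A writer.")
task = Task(description="Draft.", expected_output="Copy.", agent=agent)

handler = TaskOutputStorageHandler()
handler.update(
    0,
    {
        "task": task,
        "output": {"raw": "v2"},
        "task_index": 0,
        "inputs": {},
        "was_replayed": True,  # routed to storage.update(...) instead of storage.add(...)
    },
)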

View File

@@ -1,6 +1,5 @@
"""Test Agent creation and execution basic functionality."""
import os
import json
from concurrent.futures import Future
from unittest import mock
@@ -19,7 +18,7 @@ from crewai.task import Task
from crewai.tasks.output_format import OutputFormat
from crewai.tasks.task_output import TaskOutput
from crewai.utilities import Logger, RPMController
-from crewai.utilities.constants import CREW_TASKS_OUTPUT_FILE
+from crewai.utilities.task_output_storage_handler import TaskOutputStorageHandler
ceo = Agent(
role="CEO",
@@ -1861,7 +1860,7 @@ def test_crew_replay_from_task_error():
@pytest.mark.vcr(filter_headers=["authorization"])
-def test_crew_task_output_file_creation():
+def test_crew_task_db_init():
agent = Agent(
role="Content Writer",
goal="Write engaging content on various topics.",
@@ -1889,46 +1888,13 @@ def test_crew_task_output_file_creation():
crew.kickoff()
-    # Check if the crew_tasks_output.json file is created
-    assert os.path.exists(CREW_TASKS_OUTPUT_FILE)
-    # Clean up the file after test
-    if os.path.exists(CREW_TASKS_OUTPUT_FILE):
-        os.remove(CREW_TASKS_OUTPUT_FILE)
@pytest.mark.vcr(filter_headers=["authorization"])
def test_replay_without_output_tasks_json():
    agent = Agent(
        role="Technical Writer",
        goal="Write detailed technical documentation.",
        backstory="You have a background in software engineering and technical writing.",
    )

    task = Task(
        description="Document the process of setting up a Python project.",
        expected_output="A step-by-step guide on setting up a Python project.",
        agent=agent,
    )

    crew = Crew(agents=[agent], tasks=[task])
    with patch.object(Task, "execute_sync") as mock_execute_task:
        mock_execute_task.return_value = TaskOutput(
            description="Document the process of setting up a Python project.",
            raw="To set up a Python project, first create a virtual environment...",
            agent="Technical Writer",
            json_dict=None,
            output_format=OutputFormat.RAW,
            pydantic=None,
            summary="Document the process of setting up a Python project...",
        )
-    if os.path.exists(CREW_TASKS_OUTPUT_FILE):
-        os.remove(CREW_TASKS_OUTPUT_FILE)

    with pytest.raises(ValueError):
        crew.replay_from_task(str(task.id))

+    # Check if this runs without raising an exception
+    try:
+        db_handler = TaskOutputStorageHandler()
+        db_handler.load()
+        assert True  # If we reach this point, no exception was raised
+    except Exception as e:
+        pytest.fail(f"An exception was raised: {str(e)}")
@pytest.mark.vcr(filter_headers=["authorization"])
@@ -2018,20 +1984,16 @@ def test_replay_task_with_context():
    ]
    crew.kickoff()

+    db_handler = TaskOutputStorageHandler()
+    assert db_handler.load() != []
-    # Check if the crew_tasks_output.json file is created
-    assert os.path.exists(CREW_TASKS_OUTPUT_FILE)

    # Replay task4 and ensure it uses task1's context properly
    with patch.object(Task, "execute_sync") as mock_replay_task:
        mock_replay_task.return_value = mock_task_output4
        replayed_output = crew.replay_from_task(str(task4.id))
        assert replayed_output.raw == "Presentation on AI advancements..."

-    # Clean up the file after test
-    if os.path.exists(CREW_TASKS_OUTPUT_FILE):
-        os.remove(CREW_TASKS_OUTPUT_FILE)
+    db_handler.reset()
def test_replay_from_task_with_context():
@@ -2056,7 +2018,7 @@ def test_replay_from_task_with_context():
crew = Crew(agents=[agent], tasks=[task1, task2], process=Process.sequential)
with patch(
"crewai.utilities.task_output_handler.TaskOutputJsonHandler.load",
"crewai.utilities.task_output_storage_handler.TaskOutputStorageHandler.load",
return_value=[
{
"task_id": str(task1.id),
@@ -2114,7 +2076,7 @@ def test_replay_with_invalid_task_id():
crew = Crew(agents=[agent], tasks=[task1, task2], process=Process.sequential)
with patch(
"crewai.utilities.task_output_handler.TaskOutputJsonHandler.load",
"crewai.utilities.task_output_storage_handler.TaskOutputStorageHandler.load",
return_value=[
{
"task_id": str(task1.id),
@@ -2176,7 +2138,7 @@ def test_replay_interpolates_inputs_properly(mock_interpolate_inputs):
crew.kickoff(inputs={"name": "John"})
with patch(
"crewai.utilities.task_output_handler.TaskOutputJsonHandler.load",
"crewai.utilities.task_output_storage_handler.TaskOutputStorageHandler.load",
return_value=[
{
"task_id": str(task1.id),
@@ -2231,7 +2193,7 @@ def test_replay_from_task_setup_context():
task1.output = context_output
crew = Crew(agents=[agent], tasks=[task1, task2], process=Process.sequential)
with patch(
"crewai.utilities.task_output_handler.TaskOutputJsonHandler.load",
"crewai.utilities.task_output_storage_handler.TaskOutputStorageHandler.load",
return_value=[
{
"task_id": str(task1.id),