Fix SQLite log handling issue causing ValueError: Logs cannot be None in tests (#1899)

* Fix SQLite log handling issue causing ValueError: Logs cannot be None in tests

- Add proper error handling in SQLite storage operations
- Set up isolated test environment with temporary storage directory
- Ensure consistent error messages across all database operations

Co-Authored-By: Joe Moura <joao@crewai.com>

* fix: Sort imports in conftest.py

Co-Authored-By: Joe Moura <joao@crewai.com>

* fix: Convert TokenProcess counters to instance variables to fix callback tracking

Co-Authored-By: Joe Moura <joao@crewai.com>

* refactor: Replace print statements with logging and improve error handling

- Add proper logging setup in kickoff_task_outputs_storage.py
- Replace self._printer.print() with logger calls
- Use appropriate log levels (error/warning)
- Add directory validation in test environment setup
- Maintain consistent error messages with DatabaseError format

Co-Authored-By: Joe Moura <joao@crewai.com>

* fix: Comprehensive improvements to database and token handling

- Fix SQLite database path handling in storage classes
- Add proper directory creation and error handling
- Improve token tracking with robust type checking
- Convert TokenProcess counters to instance variables
- Add standardized database error handling
- Set up isolated test environment with temporary storage

Resolves test failures in PR #1899

Co-Authored-By: Joe Moura <joao@crewai.com>

---------

Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Co-authored-by: Joe Moura <joao@crewai.com>
Co-authored-by: João Moura <joaomdmoura@gmail.com>
This commit is contained in:
devin-ai-integration[bot]
2025-01-16 11:18:54 -03:00
committed by GitHub
parent 294f2cc3a9
commit 42311d9c7a
7 changed files with 193 additions and 49 deletions

View File

@@ -2,11 +2,12 @@ from crewai.types.usage_metrics import UsageMetrics
class TokenProcess: class TokenProcess:
total_tokens: int = 0 def __init__(self):
prompt_tokens: int = 0 self.total_tokens: int = 0
cached_prompt_tokens: int = 0 self.prompt_tokens: int = 0
completion_tokens: int = 0 self.cached_prompt_tokens: int = 0
successful_requests: int = 0 self.completion_tokens: int = 0
self.successful_requests: int = 0
def sum_prompt_tokens(self, tokens: int): def sum_prompt_tokens(self, tokens: int):
self.prompt_tokens = self.prompt_tokens + tokens self.prompt_tokens = self.prompt_tokens + tokens

View File

@@ -222,6 +222,19 @@ class LLM:
].message ].message
text_response = response_message.content or "" text_response = response_message.content or ""
tool_calls = getattr(response_message, "tool_calls", []) tool_calls = getattr(response_message, "tool_calls", [])
# Ensure callbacks get the full response object with usage info
if callbacks and len(callbacks) > 0:
for callback in callbacks:
if hasattr(callback, "log_success_event"):
usage_info = getattr(response, "usage", None)
if usage_info:
callback.log_success_event(
kwargs=params,
response_obj={"usage": usage_info},
start_time=0,
end_time=0,
)
# --- 2) If no tool calls, return the text response # --- 2) If no tool calls, return the text response
if not tool_calls or not available_functions: if not tool_calls or not available_functions:

View File

@@ -1,12 +1,17 @@
import json import json
import logging
import sqlite3 import sqlite3
from pathlib import Path
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from crewai.task import Task from crewai.task import Task
from crewai.utilities import Printer from crewai.utilities import Printer
from crewai.utilities.crew_json_encoder import CrewJSONEncoder from crewai.utilities.crew_json_encoder import CrewJSONEncoder
from crewai.utilities.errors import DatabaseError, DatabaseOperationError
from crewai.utilities.paths import db_storage_path from crewai.utilities.paths import db_storage_path
logger = logging.getLogger(__name__)
class KickoffTaskOutputsSQLiteStorage: class KickoffTaskOutputsSQLiteStorage:
""" """
@@ -14,15 +19,24 @@ class KickoffTaskOutputsSQLiteStorage:
""" """
def __init__( def __init__(
self, db_path: str = f"{db_storage_path()}/latest_kickoff_task_outputs.db" self, db_path: Optional[str] = None
) -> None: ) -> None:
if db_path is None:
# Get the parent directory of the default db path and create our db file there
db_path = str(Path(db_storage_path()).parent / "latest_kickoff_task_outputs.db")
self.db_path = db_path self.db_path = db_path
self._printer: Printer = Printer() self._printer: Printer = Printer()
self._initialize_db() self._initialize_db()
def _initialize_db(self): def _initialize_db(self) -> None:
""" """Initialize the SQLite database and create the latest_kickoff_task_outputs table.
Initializes the SQLite database and creates LTM table
This method sets up the database schema for storing task outputs. It creates
a table with columns for task_id, expected_output, output (as JSON),
task_index, inputs (as JSON), was_replayed flag, and timestamp.
Raises:
DatabaseOperationError: If database initialization fails due to SQLite errors.
""" """
try: try:
with sqlite3.connect(self.db_path) as conn: with sqlite3.connect(self.db_path) as conn:
@@ -43,10 +57,9 @@ class KickoffTaskOutputsSQLiteStorage:
conn.commit() conn.commit()
except sqlite3.Error as e: except sqlite3.Error as e:
self._printer.print( error_msg = DatabaseError.format_error(DatabaseError.INIT_ERROR, e)
content=f"SAVING KICKOFF TASK OUTPUTS ERROR: An error occurred during database initialization: {e}", logger.error(error_msg)
color="red", raise DatabaseOperationError(error_msg, e)
)
def add( def add(
self, self,
@@ -55,9 +68,22 @@ class KickoffTaskOutputsSQLiteStorage:
task_index: int, task_index: int,
was_replayed: bool = False, was_replayed: bool = False,
inputs: Dict[str, Any] = {}, inputs: Dict[str, Any] = {},
): ) -> None:
"""Add a new task output record to the database.
Args:
task: The Task object containing task details.
output: Dictionary containing the task's output data.
task_index: Integer index of the task in the sequence.
was_replayed: Boolean indicating if this was a replay execution.
inputs: Dictionary of input parameters used for the task.
Raises:
DatabaseOperationError: If saving the task output fails due to SQLite errors.
"""
try: try:
with sqlite3.connect(self.db_path) as conn: with sqlite3.connect(self.db_path) as conn:
conn.execute("BEGIN TRANSACTION")
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute( cursor.execute(
""" """
@@ -76,21 +102,31 @@ class KickoffTaskOutputsSQLiteStorage:
) )
conn.commit() conn.commit()
except sqlite3.Error as e: except sqlite3.Error as e:
self._printer.print( error_msg = DatabaseError.format_error(DatabaseError.SAVE_ERROR, e)
content=f"SAVING KICKOFF TASK OUTPUTS ERROR: An error occurred during database initialization: {e}", logger.error(error_msg)
color="red", raise DatabaseOperationError(error_msg, e)
)
def update( def update(
self, self,
task_index: int, task_index: int,
**kwargs, **kwargs: Any,
): ) -> None:
""" """Update an existing task output record in the database.
Updates an existing row in the latest_kickoff_task_outputs table based on task_index.
Updates fields of a task output record identified by task_index. The fields
to update are provided as keyword arguments.
Args:
task_index: Integer index of the task to update.
**kwargs: Arbitrary keyword arguments representing fields to update.
Values that are dictionaries will be JSON encoded.
Raises:
DatabaseOperationError: If updating the task output fails due to SQLite errors.
""" """
try: try:
with sqlite3.connect(self.db_path) as conn: with sqlite3.connect(self.db_path) as conn:
conn.execute("BEGIN TRANSACTION")
cursor = conn.cursor() cursor = conn.cursor()
fields = [] fields = []
@@ -110,14 +146,23 @@ class KickoffTaskOutputsSQLiteStorage:
conn.commit() conn.commit()
if cursor.rowcount == 0: if cursor.rowcount == 0:
self._printer.print( logger.warning(f"No row found with task_index {task_index}. No update performed.")
f"No row found with task_index {task_index}. No update performed.",
color="red",
)
except sqlite3.Error as e: except sqlite3.Error as e:
self._printer.print(f"UPDATE KICKOFF TASK OUTPUTS ERROR: {e}", color="red") error_msg = DatabaseError.format_error(DatabaseError.UPDATE_ERROR, e)
logger.error(error_msg)
raise DatabaseOperationError(error_msg, e)
def load(self) -> Optional[List[Dict[str, Any]]]: def load(self) -> List[Dict[str, Any]]:
"""Load all task output records from the database.
Returns:
List of dictionaries containing task output records, ordered by task_index.
Each dictionary contains: task_id, expected_output, output, task_index,
inputs, was_replayed, and timestamp.
Raises:
DatabaseOperationError: If loading task outputs fails due to SQLite errors.
"""
try: try:
with sqlite3.connect(self.db_path) as conn: with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor() cursor = conn.cursor()
@@ -144,23 +189,26 @@ class KickoffTaskOutputsSQLiteStorage:
return results return results
except sqlite3.Error as e: except sqlite3.Error as e:
self._printer.print( error_msg = DatabaseError.format_error(DatabaseError.LOAD_ERROR, e)
content=f"LOADING KICKOFF TASK OUTPUTS ERROR: An error occurred while querying kickoff task outputs: {e}", logger.error(error_msg)
color="red", raise DatabaseOperationError(error_msg, e)
)
return None
def delete_all(self): def delete_all(self) -> None:
""" """Delete all task output records from the database.
Deletes all rows from the latest_kickoff_task_outputs table.
This method removes all records from the latest_kickoff_task_outputs table.
Use with caution as this operation cannot be undone.
Raises:
DatabaseOperationError: If deleting task outputs fails due to SQLite errors.
""" """
try: try:
with sqlite3.connect(self.db_path) as conn: with sqlite3.connect(self.db_path) as conn:
conn.execute("BEGIN TRANSACTION")
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("DELETE FROM latest_kickoff_task_outputs") cursor.execute("DELETE FROM latest_kickoff_task_outputs")
conn.commit() conn.commit()
except sqlite3.Error as e: except sqlite3.Error as e:
self._printer.print( error_msg = DatabaseError.format_error(DatabaseError.DELETE_ERROR, e)
content=f"ERROR: Failed to delete all kickoff task outputs: {e}", logger.error(error_msg)
color="red", raise DatabaseOperationError(error_msg, e)
)

View File

@@ -1,5 +1,6 @@
import json import json
import sqlite3 import sqlite3
from pathlib import Path
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
from crewai.utilities import Printer from crewai.utilities import Printer
@@ -12,10 +13,15 @@ class LTMSQLiteStorage:
""" """
def __init__( def __init__(
self, db_path: str = f"{db_storage_path()}/long_term_memory_storage.db" self, db_path: Optional[str] = None
) -> None: ) -> None:
if db_path is None:
# Get the parent directory of the default db path and create our db file there
db_path = str(Path(db_storage_path()).parent / "long_term_memory_storage.db")
self.db_path = db_path self.db_path = db_path
self._printer: Printer = Printer() self._printer: Printer = Printer()
# Ensure parent directory exists
Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
self._initialize_db() self._initialize_db()
def _initialize_db(self): def _initialize_db(self):

View File

@@ -0,0 +1,39 @@
"""Error message definitions for CrewAI database operations."""
from typing import Optional
class DatabaseOperationError(Exception):
"""Base exception class for database operation errors."""
def __init__(self, message: str, original_error: Optional[Exception] = None):
"""Initialize the database operation error.
Args:
message: The error message to display
original_error: The original exception that caused this error, if any
"""
super().__init__(message)
self.original_error = original_error
class DatabaseError:
"""Standardized error message templates for database operations."""
INIT_ERROR: str = "Database initialization error: {}"
SAVE_ERROR: str = "Error saving task outputs: {}"
UPDATE_ERROR: str = "Error updating task outputs: {}"
LOAD_ERROR: str = "Error loading task outputs: {}"
DELETE_ERROR: str = "Error deleting task outputs: {}"
@classmethod
def format_error(cls, template: str, error: Exception) -> str:
"""Format an error message with the given template and error.
Args:
template: The error message template to use
error: The exception to format into the template
Returns:
The formatted error message
"""
return template.format(str(error))

View File

@@ -23,11 +23,15 @@ class TokenCalcHandler(CustomLogger):
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning) warnings.simplefilter("ignore", UserWarning)
usage: Usage = response_obj["usage"] if isinstance(response_obj, dict) and "usage" in response_obj:
self.token_cost_process.sum_successful_requests(1) usage: Usage = response_obj["usage"]
self.token_cost_process.sum_prompt_tokens(usage.prompt_tokens) if usage:
self.token_cost_process.sum_completion_tokens(usage.completion_tokens) self.token_cost_process.sum_successful_requests(1)
if usage.prompt_tokens_details: if hasattr(usage, "prompt_tokens"):
self.token_cost_process.sum_cached_prompt_tokens( self.token_cost_process.sum_prompt_tokens(usage.prompt_tokens)
usage.prompt_tokens_details.cached_tokens if hasattr(usage, "completion_tokens"):
) self.token_cost_process.sum_completion_tokens(usage.completion_tokens)
if hasattr(usage, "prompt_tokens_details") and usage.prompt_tokens_details:
self.token_cost_process.sum_cached_prompt_tokens(
usage.prompt_tokens_details.cached_tokens
)

View File

@@ -1,4 +1,37 @@
# conftest.py # conftest.py
import os
import tempfile
from pathlib import Path
import pytest
from dotenv import load_dotenv from dotenv import load_dotenv
load_result = load_dotenv(override=True) load_result = load_dotenv(override=True)
@pytest.fixture(autouse=True)
def setup_test_environment():
"""Set up test environment with a temporary directory for SQLite storage."""
with tempfile.TemporaryDirectory() as temp_dir:
# Create the directory with proper permissions
storage_dir = Path(temp_dir) / "crewai_test_storage"
storage_dir.mkdir(parents=True, exist_ok=True)
# Validate that the directory was created successfully
if not storage_dir.exists() or not storage_dir.is_dir():
raise RuntimeError(f"Failed to create test storage directory: {storage_dir}")
# Verify directory permissions
try:
# Try to create a test file to verify write permissions
test_file = storage_dir / ".permissions_test"
test_file.touch()
test_file.unlink()
except (OSError, IOError) as e:
raise RuntimeError(f"Test storage directory {storage_dir} is not writable: {e}")
# Set environment variable to point to the test storage directory
os.environ["CREWAI_STORAGE_DIR"] = str(storage_dir)
yield
# Cleanup is handled automatically when tempfile context exits