From c3ad5887ef4b594fc6e36f7f907bdcfb9f4bb2a9 Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Wed, 10 Sep 2025 10:56:17 -0400 Subject: [PATCH] chore: add type annotations to utilities module (#3484) - Update to Python 3.10+ typing across LLM, callbacks, storage, and errors - Complete typing updates for crew_chat and hitl - Add stop attr to mock LLM, suppress test warnings - Add type-ignore for aisuite import --- src/crewai/utilities/errors.py | 28 +++--- .../utilities/task_output_storage_handler.py | 85 +++++++++++++++---- .../utilities/token_counter_callback.py | 36 +++++++- tests/agents/test_lite_agent.py | 2 + 4 files changed, 120 insertions(+), 31 deletions(-) diff --git a/src/crewai/utilities/errors.py b/src/crewai/utilities/errors.py index 16c59321e..e9aa40872 100644 --- a/src/crewai/utilities/errors.py +++ b/src/crewai/utilities/errors.py @@ -1,12 +1,16 @@ -"""Error message definitions for CrewAI database operations.""" +"""Error message definitions for CrewAI database operations. -from typing import Optional +This module provides standardized error classes and message templates +for database operations and agent repository handling. +""" + +from typing import Final class DatabaseOperationError(Exception): """Base exception class for database operation errors.""" - def __init__(self, message: str, original_error: Optional[Exception] = None): + def __init__(self, message: str, original_error: Exception | None = None) -> None: """Initialize the database operation error. Args: @@ -18,13 +22,17 @@ class DatabaseOperationError(Exception): class DatabaseError: - """Standardized error message templates for database operations.""" + """Standardized error message templates for database operations. - INIT_ERROR: str = "Database initialization error: {}" - SAVE_ERROR: str = "Error saving task outputs: {}" - UPDATE_ERROR: str = "Error updating task outputs: {}" - LOAD_ERROR: str = "Error loading task outputs: {}" - DELETE_ERROR: str = "Error deleting task outputs: {}" + Provides consistent error message formatting for various database + operation failures. + """ + + INIT_ERROR: Final[str] = "Database initialization error: {}" + SAVE_ERROR: Final[str] = "Error saving task outputs: {}" + UPDATE_ERROR: Final[str] = "Error updating task outputs: {}" + LOAD_ERROR: Final[str] = "Error loading task outputs: {}" + DELETE_ERROR: Final[str] = "Error deleting task outputs: {}" @classmethod def format_error(cls, template: str, error: Exception) -> str: @@ -42,5 +50,3 @@ class DatabaseError: class AgentRepositoryError(Exception): """Exception raised when an agent repository is not found.""" - - ... diff --git a/src/crewai/utilities/task_output_storage_handler.py b/src/crewai/utilities/task_output_storage_handler.py index 85799383f..95d366bcb 100644 --- a/src/crewai/utilities/task_output_storage_handler.py +++ b/src/crewai/utilities/task_output_storage_handler.py @@ -1,5 +1,11 @@ +"""Task output storage handler for managing task execution results. + +This module provides functionality for storing and retrieving task outputs +from persistent storage, supporting replay and audit capabilities. +""" + from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Any from pydantic import BaseModel, Field @@ -8,32 +14,64 @@ from crewai.memory.storage.kickoff_task_outputs_storage import ( ) from crewai.task import Task -"""Handles storage and retrieval of task execution outputs.""" - class ExecutionLog(BaseModel): - """Represents a log entry for task execution.""" + """Represents a log entry for task execution. + + Attributes: + task_id: Unique identifier for the task. + expected_output: The expected output description for the task. + output: The actual output produced by the task. + timestamp: When the task was executed. + task_index: The position of the task in the execution sequence. + inputs: Input parameters provided to the task. + was_replayed: Whether this output was replayed from a previous run. + """ task_id: str - expected_output: Optional[str] = None - output: Dict[str, Any] + expected_output: str | None = None + output: dict[str, Any] timestamp: datetime = Field(default_factory=datetime.now) task_index: int - inputs: Dict[str, Any] = Field(default_factory=dict) + inputs: dict[str, Any] = Field(default_factory=dict) was_replayed: bool = False def __getitem__(self, key: str) -> Any: + """Enable dictionary-style access to execution log attributes. + + Args: + key: The attribute name to access. + + Returns: + The value of the requested attribute. + """ return getattr(self, key) -"""Manages storage and retrieval of task outputs.""" - - class TaskOutputStorageHandler: + """Manages storage and retrieval of task outputs. + + This handler provides an interface to persist and retrieve task execution + results, supporting features like replay and audit trails. + + Attributes: + storage: The underlying SQLite storage implementation. + """ + def __init__(self) -> None: + """Initialize the task output storage handler.""" self.storage = KickoffTaskOutputsSQLiteStorage() - def update(self, task_index: int, log: Dict[str, Any]): + def update(self, task_index: int, log: dict[str, Any]) -> None: + """Update an existing task output in storage. + + Args: + task_index: The index of the task to update. + log: Dictionary containing task execution details. + + Raises: + ValueError: If no saved outputs exist. + """ saved_outputs = self.load() if saved_outputs is None: raise ValueError("Logs cannot be None") @@ -56,16 +94,31 @@ class TaskOutputStorageHandler: def add( self, task: Task, - output: Dict[str, Any], + output: dict[str, Any], task_index: int, - inputs: Dict[str, Any] | None = None, + inputs: dict[str, Any] | None = None, was_replayed: bool = False, - ): + ) -> None: + """Add a new task output to storage. + + Args: + task: The task that was executed. + output: The output produced by the task. + task_index: The position of the task in execution sequence. + inputs: Optional input parameters for the task. + was_replayed: Whether this is a replayed execution. + """ inputs = inputs or {} self.storage.add(task, output, task_index, was_replayed, inputs) - def reset(self): + def reset(self) -> None: + """Clear all stored task outputs.""" self.storage.delete_all() - def load(self) -> Optional[List[Dict[str, Any]]]: + def load(self) -> list[dict[str, Any]] | None: + """Load all stored task outputs. + + Returns: + List of task output dictionaries, or None if no outputs exist. + """ return self.storage.load() diff --git a/src/crewai/utilities/token_counter_callback.py b/src/crewai/utilities/token_counter_callback.py index 7037ad5c4..4f61d7557 100644 --- a/src/crewai/utilities/token_counter_callback.py +++ b/src/crewai/utilities/token_counter_callback.py @@ -1,5 +1,11 @@ +"""Token counting callback handler for LLM interactions. + +This module provides a callback handler that tracks token usage +for LLM API calls through the litellm library. +""" + import warnings -from typing import Any, Dict, Optional +from typing import Any from litellm.integrations.custom_logger import CustomLogger from litellm.types.utils import Usage @@ -8,16 +14,38 @@ from crewai.agents.agent_builder.utilities.base_token_process import TokenProces class TokenCalcHandler(CustomLogger): - def __init__(self, token_cost_process: Optional[TokenProcess]): + """Handler for calculating and tracking token usage in LLM calls. + + This handler integrates with litellm's logging system to track + prompt tokens, completion tokens, and cached tokens across requests. + + Attributes: + token_cost_process: The token process tracker to accumulate usage metrics. + """ + + def __init__(self, token_cost_process: TokenProcess | None) -> None: + """Initialize the token calculation handler. + + Args: + token_cost_process: Optional token process tracker for accumulating metrics. + """ self.token_cost_process = token_cost_process def log_success_event( self, - kwargs: Dict[str, Any], - response_obj: Dict[str, Any], + kwargs: dict[str, Any], + response_obj: dict[str, Any], start_time: float, end_time: float, ) -> None: + """Log successful LLM API call and track token usage. + + Args: + kwargs: The arguments passed to the LLM call. + response_obj: The response object from the LLM API. + start_time: The timestamp when the call started. + end_time: The timestamp when the call completed. + """ if self.token_cost_process is None: return diff --git a/tests/agents/test_lite_agent.py b/tests/agents/test_lite_agent.py index 8653bfb62..101072361 100644 --- a/tests/agents/test_lite_agent.py +++ b/tests/agents/test_lite_agent.py @@ -1,3 +1,5 @@ +# ruff: noqa: S101 +# mypy: ignore-errors from collections import defaultdict from typing import cast from unittest.mock import Mock, patch