mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 16:48:30 +00:00
chore: add type annotations to utilities module (#3484)
- Update to Python 3.10+ typing across LLM, callbacks, storage, and errors - Complete typing updates for crew_chat and hitl - Add stop attr to mock LLM, suppress test warnings - Add type-ignore for aisuite import
This commit is contained in:
@@ -1,12 +1,16 @@
|
||||
"""Error message definitions for CrewAI database operations."""
|
||||
"""Error message definitions for CrewAI database operations.
|
||||
|
||||
from typing import Optional
|
||||
This module provides standardized error classes and message templates
|
||||
for database operations and agent repository handling.
|
||||
"""
|
||||
|
||||
from typing import Final
|
||||
|
||||
|
||||
class DatabaseOperationError(Exception):
|
||||
"""Base exception class for database operation errors."""
|
||||
|
||||
def __init__(self, message: str, original_error: Optional[Exception] = None):
|
||||
def __init__(self, message: str, original_error: Exception | None = None) -> None:
|
||||
"""Initialize the database operation error.
|
||||
|
||||
Args:
|
||||
@@ -18,13 +22,17 @@ class DatabaseOperationError(Exception):
|
||||
|
||||
|
||||
class DatabaseError:
|
||||
"""Standardized error message templates for database operations."""
|
||||
"""Standardized error message templates for database operations.
|
||||
|
||||
INIT_ERROR: str = "Database initialization error: {}"
|
||||
SAVE_ERROR: str = "Error saving task outputs: {}"
|
||||
UPDATE_ERROR: str = "Error updating task outputs: {}"
|
||||
LOAD_ERROR: str = "Error loading task outputs: {}"
|
||||
DELETE_ERROR: str = "Error deleting task outputs: {}"
|
||||
Provides consistent error message formatting for various database
|
||||
operation failures.
|
||||
"""
|
||||
|
||||
INIT_ERROR: Final[str] = "Database initialization error: {}"
|
||||
SAVE_ERROR: Final[str] = "Error saving task outputs: {}"
|
||||
UPDATE_ERROR: Final[str] = "Error updating task outputs: {}"
|
||||
LOAD_ERROR: Final[str] = "Error loading task outputs: {}"
|
||||
DELETE_ERROR: Final[str] = "Error deleting task outputs: {}"
|
||||
|
||||
@classmethod
|
||||
def format_error(cls, template: str, error: Exception) -> str:
|
||||
@@ -42,5 +50,3 @@ class DatabaseError:
|
||||
|
||||
class AgentRepositoryError(Exception):
|
||||
"""Exception raised when an agent repository is not found."""
|
||||
|
||||
...
|
||||
|
||||
@@ -1,5 +1,11 @@
|
||||
"""Task output storage handler for managing task execution results.
|
||||
|
||||
This module provides functionality for storing and retrieving task outputs
|
||||
from persistent storage, supporting replay and audit capabilities.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -8,32 +14,64 @@ from crewai.memory.storage.kickoff_task_outputs_storage import (
|
||||
)
|
||||
from crewai.task import Task
|
||||
|
||||
"""Handles storage and retrieval of task execution outputs."""
|
||||
|
||||
|
||||
class ExecutionLog(BaseModel):
|
||||
"""Represents a log entry for task execution."""
|
||||
"""Represents a log entry for task execution.
|
||||
|
||||
Attributes:
|
||||
task_id: Unique identifier for the task.
|
||||
expected_output: The expected output description for the task.
|
||||
output: The actual output produced by the task.
|
||||
timestamp: When the task was executed.
|
||||
task_index: The position of the task in the execution sequence.
|
||||
inputs: Input parameters provided to the task.
|
||||
was_replayed: Whether this output was replayed from a previous run.
|
||||
"""
|
||||
|
||||
task_id: str
|
||||
expected_output: Optional[str] = None
|
||||
output: Dict[str, Any]
|
||||
expected_output: str | None = None
|
||||
output: dict[str, Any]
|
||||
timestamp: datetime = Field(default_factory=datetime.now)
|
||||
task_index: int
|
||||
inputs: Dict[str, Any] = Field(default_factory=dict)
|
||||
inputs: dict[str, Any] = Field(default_factory=dict)
|
||||
was_replayed: bool = False
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
"""Enable dictionary-style access to execution log attributes.
|
||||
|
||||
Args:
|
||||
key: The attribute name to access.
|
||||
|
||||
Returns:
|
||||
The value of the requested attribute.
|
||||
"""
|
||||
return getattr(self, key)
|
||||
|
||||
|
||||
"""Manages storage and retrieval of task outputs."""
|
||||
|
||||
|
||||
class TaskOutputStorageHandler:
|
||||
"""Manages storage and retrieval of task outputs.
|
||||
|
||||
This handler provides an interface to persist and retrieve task execution
|
||||
results, supporting features like replay and audit trails.
|
||||
|
||||
Attributes:
|
||||
storage: The underlying SQLite storage implementation.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the task output storage handler."""
|
||||
self.storage = KickoffTaskOutputsSQLiteStorage()
|
||||
|
||||
def update(self, task_index: int, log: Dict[str, Any]):
|
||||
def update(self, task_index: int, log: dict[str, Any]) -> None:
|
||||
"""Update an existing task output in storage.
|
||||
|
||||
Args:
|
||||
task_index: The index of the task to update.
|
||||
log: Dictionary containing task execution details.
|
||||
|
||||
Raises:
|
||||
ValueError: If no saved outputs exist.
|
||||
"""
|
||||
saved_outputs = self.load()
|
||||
if saved_outputs is None:
|
||||
raise ValueError("Logs cannot be None")
|
||||
@@ -56,16 +94,31 @@ class TaskOutputStorageHandler:
|
||||
def add(
|
||||
self,
|
||||
task: Task,
|
||||
output: Dict[str, Any],
|
||||
output: dict[str, Any],
|
||||
task_index: int,
|
||||
inputs: Dict[str, Any] | None = None,
|
||||
inputs: dict[str, Any] | None = None,
|
||||
was_replayed: bool = False,
|
||||
):
|
||||
) -> None:
|
||||
"""Add a new task output to storage.
|
||||
|
||||
Args:
|
||||
task: The task that was executed.
|
||||
output: The output produced by the task.
|
||||
task_index: The position of the task in execution sequence.
|
||||
inputs: Optional input parameters for the task.
|
||||
was_replayed: Whether this is a replayed execution.
|
||||
"""
|
||||
inputs = inputs or {}
|
||||
self.storage.add(task, output, task_index, was_replayed, inputs)
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
"""Clear all stored task outputs."""
|
||||
self.storage.delete_all()
|
||||
|
||||
def load(self) -> Optional[List[Dict[str, Any]]]:
|
||||
def load(self) -> list[dict[str, Any]] | None:
|
||||
"""Load all stored task outputs.
|
||||
|
||||
Returns:
|
||||
List of task output dictionaries, or None if no outputs exist.
|
||||
"""
|
||||
return self.storage.load()
|
||||
|
||||
@@ -1,5 +1,11 @@
|
||||
"""Token counting callback handler for LLM interactions.
|
||||
|
||||
This module provides a callback handler that tracks token usage
|
||||
for LLM API calls through the litellm library.
|
||||
"""
|
||||
|
||||
import warnings
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.types.utils import Usage
|
||||
@@ -8,16 +14,38 @@ from crewai.agents.agent_builder.utilities.base_token_process import TokenProces
|
||||
|
||||
|
||||
class TokenCalcHandler(CustomLogger):
|
||||
def __init__(self, token_cost_process: Optional[TokenProcess]):
|
||||
"""Handler for calculating and tracking token usage in LLM calls.
|
||||
|
||||
This handler integrates with litellm's logging system to track
|
||||
prompt tokens, completion tokens, and cached tokens across requests.
|
||||
|
||||
Attributes:
|
||||
token_cost_process: The token process tracker to accumulate usage metrics.
|
||||
"""
|
||||
|
||||
def __init__(self, token_cost_process: TokenProcess | None) -> None:
|
||||
"""Initialize the token calculation handler.
|
||||
|
||||
Args:
|
||||
token_cost_process: Optional token process tracker for accumulating metrics.
|
||||
"""
|
||||
self.token_cost_process = token_cost_process
|
||||
|
||||
def log_success_event(
|
||||
self,
|
||||
kwargs: Dict[str, Any],
|
||||
response_obj: Dict[str, Any],
|
||||
kwargs: dict[str, Any],
|
||||
response_obj: dict[str, Any],
|
||||
start_time: float,
|
||||
end_time: float,
|
||||
) -> None:
|
||||
"""Log successful LLM API call and track token usage.
|
||||
|
||||
Args:
|
||||
kwargs: The arguments passed to the LLM call.
|
||||
response_obj: The response object from the LLM API.
|
||||
start_time: The timestamp when the call started.
|
||||
end_time: The timestamp when the call completed.
|
||||
"""
|
||||
if self.token_cost_process is None:
|
||||
return
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
# ruff: noqa: S101
|
||||
# mypy: ignore-errors
|
||||
from collections import defaultdict
|
||||
from typing import cast
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
Reference in New Issue
Block a user