Compare commits

...

3 Commits

Author SHA1 Message Date
Devin AI
bfb578d506 fix: Add proper null checks for logger calls and improve type safety in LLM class
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-01-01 21:54:49 +00:00
Devin AI
5d3c34b3ea fix: Improve type annotations across multiple files
- Replace Optional[set[str]] with Union[set[str], None] in json methods
- Fix add_nodes_to_network call parameters in flow_visualizer
- Add __base__=BaseModel to create_model call in structured_tool
- Clean up imports in provider.py

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-01-01 21:29:15 +00:00
Devin AI
8ec2eb7d72 fix(agent): improve token tracking and logging functionality
- Add proper debug, info, warning, and error methods to Logger class
- Ensure warnings and errors are always shown regardless of verbose mode
- Fix token process initialization and tracking in Agent class
- Update TokenProcess import to use correct class from agent_builder utilities

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-01-01 20:59:39 +00:00
9 changed files with 490 additions and 117 deletions

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import os
import shutil
import subprocess
@@ -21,11 +23,11 @@ from crewai.tools.base_tool import Tool
from crewai.utilities import Converter, Prompts
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
from crewai.utilities.converter import generate_model_description
from crewai.utilities.token_counter_callback import TokenCalcHandler
from crewai.utilities.training_handler import CrewTrainingHandler
from crewai.utilities.logger import Logger
from crewai.utilities.rpm_controller import RPMController
from crewai.utilities.token_process import TokenProcess
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.utilities.token_counter_callback import TokenCalcHandler
from crewai.utilities.training_handler import CrewTrainingHandler
agentops = None
@@ -132,7 +134,6 @@ class Agent(BaseAgent):
"verbose": verbose,
"max_rpm": max_rpm,
"tools": processed_tools,
"llm": processed_llm,
"max_iter": max_iter if max_iter is not None else 25,
"function_calling_llm": function_calling_llm,
"step_callback": step_callback,
@@ -148,11 +149,14 @@ class Agent(BaseAgent):
self._original_goal = goal
self._original_backstory = backstory
# Set LLM after base initialization to ensure proper model handling
self.llm = processed_llm
# Initialize private attributes
self._logger = Logger(verbose=self.verbose)
if self.max_rpm:
self._rpm_controller = RPMController(max_rpm=self.max_rpm, logger=self._logger)
self._token_process = TokenProcess() # type: ignore # Known type mismatch between utilities and agent_builder
self._token_process = TokenProcess()
_times_executed: int = PrivateAttr(default=0)
max_execution_time: Optional[int] = Field(
@@ -530,6 +534,32 @@ class Agent(BaseAgent):
self.response_template.split("{{ .Response }}")[1].strip()
)
# Ensure LLM is initialized with proper error handling
try:
if not self.llm:
self.llm = LLM(model="gpt-4")
if hasattr(self, '_logger'):
self._logger.debug("Initialized default LLM with gpt-4 model")
except Exception as e:
if hasattr(self, '_logger'):
self._logger.error(f"Failed to initialize LLM: {str(e)}")
raise
# Create token callback with proper error handling
try:
token_callback = None
if hasattr(self, '_token_process'):
token_callback = TokenCalcHandler(self._token_process)
except Exception as e:
if hasattr(self, '_logger'):
self._logger.warning(f"Failed to create token callback: {str(e)}")
token_callback = None
# Initialize callbacks list
executor_callbacks = []
if token_callback:
executor_callbacks.append(token_callback)
self.agent_executor = CrewAgentExecutor(
llm=self.llm,
task=task,
@@ -547,9 +577,9 @@ class Agent(BaseAgent):
function_calling_llm=self.function_calling_llm,
respect_context_window=self.respect_context_window,
request_within_rpm_limit=(
self._rpm_controller.check_or_wait if self._rpm_controller else None
self._rpm_controller.check_or_wait if (hasattr(self, '_rpm_controller') and self._rpm_controller is not None) else None
),
callbacks=[TokenCalcHandler(self._token_process)],
callbacks=executor_callbacks,
)
def get_delegation_tools(self, agents: List[BaseAgent]):

View File

@@ -5,6 +5,7 @@ from pathlib import Path
import click
import requests
from typing import Any
from crewai.cli.constants import JSON_URL, MODELS, PROVIDERS
@@ -192,7 +193,7 @@ def download_data(response):
data_chunks = []
with click.progressbar(
length=total_size, label="Downloading", show_pos=True
) as progress_bar:
) as progress_bar: # type: Any
for chunk in response.iter_content(block_size):
if chunk:
data_chunks.append(chunk)

View File

@@ -1,5 +1,5 @@
import json
from typing import Any, Dict, Optional
from typing import Any, Callable, Dict, Optional, Union
from pydantic import BaseModel, Field
@@ -23,14 +23,25 @@ class CrewOutput(BaseModel):
)
token_usage: UsageMetrics = Field(description="Processed token summary", default={})
@property
def json(self) -> Optional[str]:
def json(
self,
*,
include: Union[set[str], None] = None,
exclude: Union[set[str], None] = None,
by_alias: bool = False,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
encoder: Optional[Callable[[Any], Any]] = None,
models_as_dict: bool = True,
**dumps_kwargs: Any,
) -> str:
if self.tasks_output[-1].output_format != OutputFormat.JSON:
raise ValueError(
"No JSON output found in the final task. Please make sure to set the output_json property in the final task in your crew."
)
return json.dumps(self.json_dict)
return json.dumps(self.json_dict, default=encoder, **dumps_kwargs)
def to_dict(self) -> Dict[str, Any]:
"""Convert json_output and pydantic_output to a dictionary."""

View File

@@ -106,7 +106,12 @@ class FlowPlot:
# Add nodes to the network
try:
add_nodes_to_network(net, self.flow, node_positions, self.node_styles)
add_nodes_to_network(
net,
flow=self.flow,
pos=node_positions,
node_styles=self.node_styles
)
except Exception as e:
raise RuntimeError(f"Failed to add nodes to network: {str(e)}")

View File

@@ -6,6 +6,8 @@ import warnings
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Field
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
import litellm
@@ -93,10 +95,33 @@ def suppress_warnings():
sys.stderr = old_stderr
class LLM:
class LLM(BaseModel):
model: str = "gpt-4" # Set default model
timeout: Optional[Union[float, int]] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
n: Optional[int] = None
stop: Optional[Union[str, List[str]]] = None
max_completion_tokens: Optional[int] = None
max_tokens: Optional[int] = None
presence_penalty: Optional[float] = None
frequency_penalty: Optional[float] = None
logit_bias: Optional[Dict[int, float]] = None
response_format: Optional[Dict[str, Any]] = None
seed: Optional[int] = None
logprobs: Optional[bool] = None
top_logprobs: Optional[int] = None
base_url: Optional[str] = None
api_version: Optional[str] = None
api_key: Optional[str] = None
callbacks: Optional[List[Any]] = None
context_window_size: Optional[int] = None
kwargs: Dict[str, Any] = Field(default_factory=dict)
logger: Optional[logging.Logger] = Field(default_factory=lambda: logging.getLogger(__name__))
def __init__(
self,
model: Union[str, 'LLM'],
model: Optional[Union[str, 'LLM']] = "gpt-4",
timeout: Optional[Union[float, int]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
@@ -114,12 +139,103 @@ class LLM:
base_url: Optional[str] = None,
api_version: Optional[str] = None,
api_key: Optional[str] = None,
callbacks: List[Any] = [],
**kwargs,
):
callbacks: Optional[List[Any]] = None,
context_window_size: Optional[int] = None,
**kwargs: Any,
) -> None:
# Initialize with default values
init_dict = {
"model": model if isinstance(model, str) else "gpt-4",
"timeout": timeout,
"temperature": temperature,
"top_p": top_p,
"n": n,
"stop": stop,
"max_completion_tokens": max_completion_tokens,
"max_tokens": max_tokens,
"presence_penalty": presence_penalty,
"frequency_penalty": frequency_penalty,
"logit_bias": logit_bias,
"response_format": response_format,
"seed": seed,
"logprobs": logprobs,
"top_logprobs": top_logprobs,
"base_url": base_url,
"api_version": api_version,
"api_key": api_key,
"callbacks": callbacks,
"context_window_size": context_window_size,
"kwargs": kwargs,
}
super().__init__(**init_dict)
# Initialize model with default value
self.model = "gpt-4" # Default fallback
# Extract and validate model name
if isinstance(model, LLM):
# Extract and validate model name from LLM instance
if hasattr(model, 'model'):
if isinstance(model.model, str):
self.model = model.model
else:
# Try to extract string model name from nested LLM
if isinstance(model.model, LLM):
self.model = str(model.model.model) if hasattr(model.model, 'model') else "gpt-4"
else:
self.model = "gpt-4"
if self.logger:
self.logger.warning("Nested LLM model is not a string, using default: gpt-4")
else:
self.model = "gpt-4"
if self.logger:
self.logger.warning("LLM instance has no model attribute, using default: gpt-4")
else:
# Extract and validate model name for non-LLM instances
if not isinstance(model, str):
if self.logger:
self.logger.debug(f"Model is not a string, attempting to extract name. Type: {type(model)}")
if model is not None:
if hasattr(model, 'model_name'):
model_name = getattr(model, 'model_name', None)
self.model = str(model_name) if model_name is not None else "gpt-4"
elif hasattr(model, 'model'):
model_attr = getattr(model, 'model', None)
self.model = str(model_attr) if model_attr is not None else "gpt-4"
elif hasattr(model, '_model_name'):
model_name = getattr(model, '_model_name', None)
self.model = str(model_name) if model_name is not None else "gpt-4"
else:
self.model = "gpt-4" # Default fallback
if self.logger:
self.logger.warning(f"Could not extract model name from {type(model)}, using default: {self.model}")
else:
self.model = "gpt-4" # Default fallback for None
if self.logger:
self.logger.warning("Model is None, using default: gpt-4")
else:
self.model = str(model) # Ensure it's a string
# If model is an LLM instance, copy its configuration
if isinstance(model, LLM):
self.model = model.model
# Extract and validate model name first
if hasattr(model, 'model'):
if isinstance(model.model, str):
self.model = model.model
else:
# Try to extract string model name from nested LLM
if isinstance(model.model, LLM):
self.model = str(model.model.model) if hasattr(model.model, 'model') else "gpt-4"
else:
self.model = "gpt-4"
if self.logger:
self.logger.warning("Nested LLM model is not a string, using default: gpt-4")
else:
self.model = "gpt-4"
if self.logger:
self.logger.warning("LLM instance has no model attribute, using default: gpt-4")
# Copy other configuration
self.timeout = model.timeout
self.temperature = model.temperature
self.top_p = model.top_p
@@ -140,8 +256,44 @@ class LLM:
self.callbacks = model.callbacks
self.context_window_size = model.context_window_size
self.kwargs = model.kwargs
# Final validation of model name
if not isinstance(self.model, str):
self.model = "gpt-4"
if self.logger:
self.logger.warning("Model name is still not a string after LLM copy, using default: gpt-4")
else:
self.model = model
# Extract and validate model name for non-LLM instances
if not isinstance(model, str):
if self.logger:
self.logger.debug(f"Model is not a string, attempting to extract name. Type: {type(model)}")
if model is not None:
if hasattr(model, 'model_name'):
model_name = getattr(model, 'model_name', None)
self.model = str(model_name) if model_name is not None else "gpt-4"
elif hasattr(model, 'model'):
model_attr = getattr(model, 'model', None)
self.model = str(model_attr) if model_attr is not None else "gpt-4"
elif hasattr(model, '_model_name'):
model_name = getattr(model, '_model_name', None)
self.model = str(model_name) if model_name is not None else "gpt-4"
else:
self.model = "gpt-4" # Default fallback
if self.logger:
self.logger.warning(f"Could not extract model name from {type(model)}, using default: {self.model}")
else:
self.model = "gpt-4" # Default fallback for None
if self.logger:
self.logger.warning("Model is None, using default: gpt-4")
else:
self.model = str(model) # Ensure it's a string
# Final validation
if not isinstance(self.model, str):
self.model = "gpt-4"
if self.logger:
self.logger.warning("Model name is still not a string after extraction, using default: gpt-4")
self.timeout = timeout
self.temperature = temperature
self.top_p = top_p
@@ -163,69 +315,155 @@ class LLM:
self.context_window_size = 0
self.kwargs = kwargs
# Ensure model is a string after initialization
if not isinstance(self.model, str):
self.model = "gpt-4"
self.logger.warning(f"Model is still not a string after initialization, using default: {self.model}")
litellm.drop_params = True
self.set_callbacks(callbacks)
self.set_env_callbacks()
def call(self, messages: List[Dict[str, str]], callbacks: List[Any] = []) -> str:
def call(
self,
messages: List[Dict[str, str]],
callbacks: Optional[List[Any]] = None
) -> str:
with suppress_warnings():
if callbacks and len(callbacks) > 0:
self.set_callbacks(callbacks)
try:
# Ensure model is a string and set default
model_name = "gpt-4" # Default model
# Extract model name from self.model
current = self.model
while current is not None:
if isinstance(current, str):
model_name = current
break
elif isinstance(current, LLM):
current = current.model
elif hasattr(current, "model"):
current = getattr(current, "model")
else:
break
# Store original model to restore later
original_model = self.model
try:
# Ensure model is a string before making the call
if not isinstance(self.model, str):
if self.logger:
self.logger.warning(f"Model is not a string in call method: {type(self.model)}. Attempting to convert...")
if isinstance(self.model, LLM):
self.model = self.model.model if isinstance(self.model.model, str) else "gpt-4"
elif hasattr(self.model, 'model_name'):
self.model = str(self.model.model_name)
elif hasattr(self.model, 'model'):
if isinstance(self.model.model, str):
self.model = str(self.model.model)
elif hasattr(self.model.model, 'model_name'):
self.model = str(self.model.model.model_name)
else:
self.model = "gpt-4"
if self.logger:
self.logger.warning("Could not extract model name from nested model object, using default: gpt-4")
else:
self.model = "gpt-4"
if self.logger:
self.logger.warning("Could not extract model name, using default: gpt-4")
if self.logger:
self.logger.debug(f"Using model: {self.model} (type: {type(self.model)}) for LiteLLM call")
# Create base params with validated model name
# Extract model name string
model_name = None
if isinstance(self.model, str):
model_name = self.model
elif hasattr(self.model, 'model_name'):
model_name = str(self.model.model_name)
elif hasattr(self.model, 'model'):
if isinstance(self.model.model, str):
model_name = str(self.model.model)
elif hasattr(self.model.model, 'model_name'):
model_name = str(self.model.model.model_name)
if not model_name:
model_name = "gpt-4"
if self.logger:
self.logger.warning("Could not extract model name, using default: gpt-4")
# Set parameters for litellm
# Build base params dict with required fields
params = {
"model": model_name,
"custom_llm_provider": "openai",
"messages": messages,
"stream": False # Always set stream to False
"stream": False,
"api_key": self.api_key or os.getenv("OPENAI_API_KEY"),
"api_base": self.base_url,
"api_version": self.api_version,
}
# Add API configuration
if self.logger:
self.logger.debug(f"Using model parameters: {params}")
# Add API configuration if available
api_key = self.api_key or os.getenv("OPENAI_API_KEY")
if api_key:
params["api_key"] = api_key
# Define optional parameters
optional_params = {
"timeout": self.timeout,
"temperature": self.temperature,
"top_p": self.top_p,
"n": self.n,
"stop": self.stop,
"max_tokens": self.max_tokens or self.max_completion_tokens,
"presence_penalty": self.presence_penalty,
"frequency_penalty": self.frequency_penalty,
"logit_bias": self.logit_bias,
"response_format": self.response_format,
"seed": self.seed,
"logprobs": self.logprobs,
"top_logprobs": self.top_logprobs,
}
# Try to get supported parameters for the model
try:
supported_params = get_supported_openai_params(self.model)
optional_params = {}
if supported_params:
param_mapping = {
"timeout": self.timeout,
"temperature": self.temperature,
"top_p": self.top_p,
"n": self.n,
"stop": self.stop,
"max_tokens": self.max_tokens or self.max_completion_tokens,
"presence_penalty": self.presence_penalty,
"frequency_penalty": self.frequency_penalty,
"logit_bias": self.logit_bias,
"response_format": self.response_format,
"seed": self.seed,
"logprobs": self.logprobs,
"top_logprobs": self.top_logprobs
}
# Only add parameters that are supported and not None
optional_params = {
k: v for k, v in param_mapping.items()
if k in supported_params and v is not None
}
if "logprobs" in supported_params and self.logprobs is not None:
optional_params["logprobs"] = self.logprobs
if "top_logprobs" in supported_params and self.top_logprobs is not None:
optional_params["top_logprobs"] = self.top_logprobs
except Exception as e:
if self.logger:
self.logger.error(f"Failed to get supported params for model {self.model}: {str(e)}")
# If we can't get supported params, just add non-None parameters
param_mapping = {
"timeout": self.timeout,
"temperature": self.temperature,
"top_p": self.top_p,
"n": self.n,
"stop": self.stop,
"max_tokens": self.max_tokens or self.max_completion_tokens,
"presence_penalty": self.presence_penalty,
"frequency_penalty": self.frequency_penalty,
"logit_bias": self.logit_bias,
"response_format": self.response_format,
"seed": self.seed,
"logprobs": self.logprobs,
"top_logprobs": self.top_logprobs
}
optional_params = {k: v for k, v in param_mapping.items() if v is not None}
# Update params with optional parameters
params.update(optional_params)
# Add API endpoint configuration if available
if self.base_url:
optional_params["api_base"] = self.base_url
params["api_base"] = self.base_url
if self.api_version:
optional_params["api_version"] = self.api_version
params["api_version"] = self.api_version
# Final validation of model parameter
if not isinstance(params["model"], str):
if self.logger:
self.logger.error(f"Model is still not a string after all conversions: {type(params['model'])}")
params["model"] = "gpt-4"
# Update params with non-None optional parameters
params.update({k: v for k, v in optional_params.items() if v is not None})
@@ -238,21 +476,38 @@ class LLM:
params = {k: v for k, v in params.items() if v is not None}
response = litellm.completion(**params)
return response["choices"][0]["message"]["content"]
content = response["choices"][0]["message"]["content"]
# Extract usage metrics
usage = response.get("usage", {})
if callbacks:
for callback in callbacks:
if hasattr(callback, "update_token_usage"):
callback.update_token_usage(usage)
return content
except Exception as e:
if not LLMContextLengthExceededException(
str(e)
)._is_context_limit_error(str(e)):
logging.error(f"LiteLLM call failed: {str(e)}")
raise # Re-raise the exception after logging
finally:
# Always restore the original model object
self.model = original_model
def supports_function_calling(self) -> bool:
"""Check if the LLM supports function calling.
Returns:
bool: True if the model supports function calling, False otherwise
"""
try:
params = get_supported_openai_params(model=self.model)
return "response_format" in params
except Exception as e:
logging.error(f"Failed to get supported params: {str(e)}")
if self.logger:
self.logger.error(f"Failed to get supported params: {str(e)}")
return False
def supports_stop_words(self) -> bool:
@@ -264,33 +519,47 @@ class LLM:
params = get_supported_openai_params(model=self.model)
return "stop" in params
except Exception as e:
logging.error(f"Failed to get supported params: {str(e)}")
if self.logger:
self.logger.error(f"Failed to get supported params: {str(e)}")
return False
def get_context_window_size(self) -> int:
"""Get the context window size for the current model.
Returns:
int: The context window size in tokens
"""
# Only using 75% of the context window size to avoid cutting the message in the middle
if self.context_window_size != 0:
return self.context_window_size
if self.context_window_size is not None and self.context_window_size != 0:
return int(self.context_window_size)
self.context_window_size = int(
DEFAULT_CONTEXT_WINDOW_SIZE * CONTEXT_WINDOW_USAGE_RATIO
)
for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
if self.model.startswith(key):
self.context_window_size = int(value * CONTEXT_WINDOW_USAGE_RATIO)
window_size = DEFAULT_CONTEXT_WINDOW_SIZE
if isinstance(self.model, str):
for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
if self.model.startswith(key):
window_size = value
break
self.context_window_size = int(window_size * CONTEXT_WINDOW_USAGE_RATIO)
return self.context_window_size
def set_callbacks(self, callbacks: List[Any]):
callback_types = [type(callback) for callback in callbacks]
for callback in litellm.success_callback[:]:
if type(callback) in callback_types:
litellm.success_callback.remove(callback)
def set_callbacks(self, callbacks: Optional[List[Any]] = None) -> None:
"""Set callbacks for the LLM.
Args:
callbacks: Optional list of callback functions. If None, no callbacks will be set.
"""
if callbacks is not None:
callback_types = [type(callback) for callback in callbacks]
for callback in litellm.success_callback[:]:
if type(callback) in callback_types:
litellm.success_callback.remove(callback)
for callback in litellm._async_success_callback[:]:
if type(callback) in callback_types:
litellm._async_success_callback.remove(callback)
for callback in litellm._async_success_callback[:]:
if type(callback) in callback_types:
litellm._async_success_callback.remove(callback)
litellm.callbacks = callbacks
litellm.callbacks = callbacks
def set_env_callbacks(self):
"""

View File

@@ -1,5 +1,5 @@
import json
from typing import Any, Dict, Optional
from typing import Any, Callable, Dict, Optional, Union
from pydantic import BaseModel, Field, model_validator
@@ -34,8 +34,19 @@ class TaskOutput(BaseModel):
self.summary = f"{excerpt}..."
return self
@property
def json(self) -> Optional[str]:
def json(
self,
*,
include: Union[set[str], None] = None,
exclude: Union[set[str], None] = None,
by_alias: bool = False,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
encoder: Optional[Callable[[Any], Any]] = None,
models_as_dict: bool = True,
**dumps_kwargs: Any,
) -> str:
if self.output_format != OutputFormat.JSON:
raise ValueError(
"""
@@ -45,7 +56,7 @@ class TaskOutput(BaseModel):
"""
)
return json.dumps(self.json_dict)
return json.dumps(self.json_dict, default=encoder, **dumps_kwargs)
def to_dict(self) -> Dict[str, Any]:
"""Convert json_output and pydantic_output to a dictionary."""

View File

@@ -142,7 +142,12 @@ class CrewStructuredTool:
# Create model
schema_name = f"{name.title()}Schema"
return create_model(schema_name, **fields)
return create_model(
schema_name,
__base__=BaseModel,
__config__=None,
**{k: v for k, v in fields.items()}
)
def _validate_function_signature(self) -> None:
"""Validate that the function signature matches the args schema."""
@@ -170,7 +175,7 @@ class CrewStructuredTool:
f"not found in args_schema"
)
def _parse_args(self, raw_args: Union[str, dict]) -> dict:
def _parse_args(self, raw_args: Union[str, dict[str, Any]]) -> dict[str, Any]:
"""Parse and validate the input arguments against the schema.
Args:
@@ -178,6 +183,9 @@ class CrewStructuredTool:
Returns:
The validated arguments as a dictionary
Raises:
ValueError: If the arguments cannot be parsed or fail validation
"""
if isinstance(raw_args, str):
try:
@@ -195,8 +203,8 @@ class CrewStructuredTool:
async def ainvoke(
self,
input: Union[str, dict],
config: Optional[dict] = None,
input: Union[str, dict[str, Any]],
config: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Asynchronously invoke the tool.
@@ -229,7 +237,10 @@ class CrewStructuredTool:
return self.invoke(input_dict)
def invoke(
self, input: Union[str, dict], config: Optional[dict] = None, **kwargs: Any
self,
input: Union[str, dict[str, Any]],
config: Optional[dict[str, Any]] = None,
**kwargs: Any
) -> Any:
"""Main method for tool execution."""
parsed_args = self._parse_args(input)

View File

@@ -10,8 +10,24 @@ class Logger(BaseModel):
_printer: Printer = PrivateAttr(default_factory=Printer)
def log(self, level, message, color="bold_yellow"):
if self.verbose:
if self.verbose or level.upper() in ["WARNING", "ERROR"]:
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
self._printer.print(
f"\n[{timestamp}][{level.upper()}]: {message}", color=color
)
def debug(self, message: str) -> None:
"""Log a debug message if verbose is enabled."""
self.log("debug", message, color="bold_blue")
def info(self, message: str) -> None:
"""Log an info message if verbose is enabled."""
self.log("info", message, color="bold_green")
def warning(self, message: str) -> None:
"""Log a warning message."""
self.log("warning", message, color="bold_yellow")
def error(self, message: str) -> None:
"""Log an error message."""
self.log("error", message, color="bold_red")

View File

@@ -1,44 +1,63 @@
"""Token processing utility for tracking and managing token usage."""
from __future__ import annotations
from typing import Any, Dict, List, Optional, Union
from crewai.types.usage_metrics import UsageMetrics
class TokenProcess:
"""Handles token processing and tracking for agents."""
def __init__(self):
"""Initialize the token processor."""
self._token_count = 0
self._last_tokens = 0
self._total_tokens = 0
self._prompt_tokens = 0
self._completion_tokens = 0
self._cached_prompt_tokens = 0
self._successful_requests = 0
def update_token_count(self, count: int) -> None:
"""Update the token count.
def sum_prompt_tokens(self, count: int) -> None:
"""Add to prompt token count.
Args:
count (int): Number of tokens to add to the count
count (int): Number of prompt tokens to add
"""
self._token_count += count
self._last_tokens = count
self._prompt_tokens += count
self._total_tokens += count
def get_token_count(self) -> int:
"""Get the total token count.
def sum_completion_tokens(self, count: int) -> None:
"""Add to completion token count.
Returns:
int: Total number of tokens processed
Args:
count (int): Number of completion tokens to add
"""
return self._token_count
self._completion_tokens += count
self._total_tokens += count
def get_last_tokens(self) -> int:
"""Get the number of tokens from the last update.
def sum_cached_prompt_tokens(self, count: int) -> None:
"""Add to cached prompt token count.
Returns:
int: Number of tokens from last update
Args:
count (int): Number of cached prompt tokens to add
"""
return self._last_tokens
self._cached_prompt_tokens += count
def sum_successful_requests(self, count: int) -> None:
"""Add to successful requests count.
Args:
count (int): Number of successful requests to add
"""
self._successful_requests += count
def reset(self) -> None:
"""Reset the token counts to zero."""
self._token_count = 0
self._last_tokens = 0
"""Reset all token counts to zero."""
self._total_tokens = 0
self._prompt_tokens = 0
self._completion_tokens = 0
self._cached_prompt_tokens = 0
self._successful_requests = 0
def get_summary(self) -> UsageMetrics:
"""Get a summary of token usage.
@@ -47,9 +66,9 @@ class TokenProcess:
UsageMetrics: Object containing token usage metrics
"""
return UsageMetrics(
total_tokens=self._token_count,
prompt_tokens=0, # These will be set by the LLM handler
cached_prompt_tokens=0,
completion_tokens=self._last_tokens,
successful_requests=1 if self._token_count > 0 else 0
total_tokens=self._total_tokens,
prompt_tokens=self._prompt_tokens,
cached_prompt_tokens=self._cached_prompt_tokens,
completion_tokens=self._completion_tokens,
successful_requests=self._successful_requests
)