style: resolve type check issues

This commit is contained in:
Lucas Gomide
2026-03-31 12:06:21 -03:00
parent 8521aa1d8c
commit 1bb5a46493
4 changed files with 27 additions and 22 deletions

View File

@@ -1990,7 +1990,8 @@ class LLM(BaseLLM):
if isinstance(usage, dict):
return usage
if hasattr(usage, "model_dump"):
return usage.model_dump()
result: dict[str, Any] = usage.model_dump()
return result
if hasattr(usage, "__dict__"):
return {k: v for k, v in vars(usage).items() if not k.startswith("_")}
return None

View File

@@ -799,7 +799,7 @@ class AzureCompletion(BaseLLM):
self,
full_response: str,
tool_calls: dict[int, dict[str, Any]],
usage_data: dict[str, int],
usage_data: dict[str, Any] | None,
params: AzureCompletionParams,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
@@ -811,7 +811,7 @@ class AzureCompletion(BaseLLM):
Args:
full_response: The complete streamed response content
tool_calls: Dictionary of tool calls accumulated during streaming
usage_data: Token usage data from the stream
usage_data: Token usage data from the stream, or None if unavailable
params: Completion parameters containing messages
available_functions: Available functions for tool calling
from_task: Task that initiated the call
@@ -821,7 +821,8 @@ class AzureCompletion(BaseLLM):
Returns:
Final response content after processing, or structured output
"""
self._track_token_usage_internal(usage_data)
if usage_data:
self._track_token_usage_internal(usage_data)
# Handle structured output validation
if response_model and self.is_openai_model:
@@ -910,7 +911,7 @@ class AzureCompletion(BaseLLM):
full_response = ""
tool_calls: dict[int, dict[str, Any]] = {}
usage_data = {"total_tokens": 0}
usage_data: dict[str, Any] | None = None
for update in self._client.complete(**params):
if isinstance(update, StreamingChatCompletionsUpdate):
if update.usage:
@@ -976,7 +977,7 @@ class AzureCompletion(BaseLLM):
full_response = ""
tool_calls: dict[int, dict[str, Any]] = {}
usage_data = {"total_tokens": 0}
usage_data: dict[str, Any] | None = None
stream = await self._async_client.complete(**params)
async for update in stream:

View File

@@ -911,10 +911,10 @@ class GeminiCompletion(BaseLLM):
chunk: GenerateContentResponse,
full_response: str,
function_calls: dict[int, dict[str, Any]],
usage_data: dict[str, int],
usage_data: dict[str, int] | None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> tuple[str, dict[int, dict[str, Any]], dict[str, int]]:
) -> tuple[str, dict[int, dict[str, Any]], dict[str, int] | None]:
"""Process a single streaming chunk.
Args:
@@ -990,7 +990,7 @@ class GeminiCompletion(BaseLLM):
self,
full_response: str,
function_calls: dict[int, dict[str, Any]],
usage_data: dict[str, int],
usage_data: dict[str, int] | None,
contents: list[types.Content],
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
@@ -1002,7 +1002,7 @@ class GeminiCompletion(BaseLLM):
Args:
full_response: The complete streamed response content
function_calls: Dictionary of function calls accumulated during streaming
usage_data: Token usage data from the stream
usage_data: Token usage data from the stream, or None if unavailable
contents: Original contents for event conversion
available_functions: Available functions for function calling
from_task: Task that initiated the call
@@ -1012,7 +1012,8 @@ class GeminiCompletion(BaseLLM):
Returns:
Final response content after processing
"""
self._track_token_usage_internal(usage_data)
if usage_data:
self._track_token_usage_internal(usage_data)
if response_model and function_calls:
for call_data in function_calls.values():
@@ -1147,7 +1148,7 @@ class GeminiCompletion(BaseLLM):
"""Handle streaming content generation."""
full_response = ""
function_calls: dict[int, dict[str, Any]] = {}
usage_data = {"total_tokens": 0}
usage_data: dict[str, int] | None = None
# The API accepts list[Content] but mypy is overly strict about variance
contents_for_api: Any = contents
@@ -1226,7 +1227,7 @@ class GeminiCompletion(BaseLLM):
"""Handle async streaming content generation."""
full_response = ""
function_calls: dict[int, dict[str, Any]] = {}
usage_data = {"total_tokens": 0}
usage_data: dict[str, int] | None = None
# The API accepts list[Content] but mypy is overly strict about variance
contents_for_api: Any = contents

View File

@@ -1053,7 +1053,7 @@ class OpenAICompletion(BaseLLM):
full_response = ""
function_calls: list[dict[str, Any]] = []
final_response: Response | None = None
usage: dict[str, Any] = {"total_tokens": 0}
usage: dict[str, Any] | None = None
stream = self._client.responses.create(**params)
response_id_stream = None
@@ -1181,7 +1181,7 @@ class OpenAICompletion(BaseLLM):
full_response = ""
function_calls: list[dict[str, Any]] = []
final_response: Response | None = None
usage: dict[str, Any] = {"total_tokens": 0}
usage: dict[str, Any] | None = None
stream = await self._async_client.responses.create(**params)
response_id_stream = None
@@ -1713,7 +1713,7 @@ class OpenAICompletion(BaseLLM):
self,
full_response: str,
tool_calls: dict[int, dict[str, Any]],
usage_data: dict[str, int],
usage_data: dict[str, Any] | None,
params: dict[str, Any],
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
@@ -1724,7 +1724,7 @@ class OpenAICompletion(BaseLLM):
Args:
full_response: The accumulated text response from the stream.
tool_calls: Accumulated tool calls from the stream, keyed by index.
usage_data: Token usage data from the stream.
usage_data: Token usage data from the stream, or None if unavailable.
params: The completion parameters containing messages.
available_functions: Available functions for tool calling.
from_task: Task that initiated the call.
@@ -1735,7 +1735,8 @@ class OpenAICompletion(BaseLLM):
tool execution result when available_functions is provided,
or the text response string.
"""
self._track_token_usage_internal(usage_data)
if usage_data:
self._track_token_usage_internal(usage_data)
if tool_calls and not available_functions:
tool_calls_list = [
@@ -1864,7 +1865,7 @@ class OpenAICompletion(BaseLLM):
self._client.chat.completions.create(**params)
)
usage_data = {"total_tokens": 0}
usage_data: dict[str, Any] | None = None
for completion_chunk in completion_stream:
response_id_stream = (
@@ -2106,7 +2107,7 @@ class OpenAICompletion(BaseLLM):
] = await self._async_client.chat.completions.create(**params)
accumulated_content = ""
usage_data = {"total_tokens": 0}
usage_data: dict[str, Any] | None = None
async for chunk in completion_stream:
response_id_stream = chunk.id if hasattr(chunk, "id") else None
@@ -2129,7 +2130,8 @@ class OpenAICompletion(BaseLLM):
response_id=response_id_stream,
)
self._track_token_usage_internal(usage_data)
if usage_data:
self._track_token_usage_internal(usage_data)
try:
parsed_object = response_model.model_validate_json(accumulated_content)
@@ -2160,7 +2162,7 @@ class OpenAICompletion(BaseLLM):
ChatCompletionChunk
] = await self._async_client.chat.completions.create(**params)
usage_data = {"total_tokens": 0}
usage_data = None
async for chunk in stream:
response_id_stream = chunk.id if hasattr(chunk, "id") else None