feat: add response_format parameter to Azure and Gemini providers

Author: Greyson Lalonde
Date: 2026-01-26 08:42:06 -05:00
parent a32de6bdac
commit c9b240a86c
2 changed files with 37 additions and 24 deletions
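
Note on the change: both providers gain a response_format constructor argument, a Pydantic
model class that is stored on the client and used as the default structured-output schema
whenever call()/acall() is invoked without an explicit response_model. A minimal usage
sketch follows; the import path and the constructor and call() arguments other than
response_format are hypothetical simplifications, not taken from this diff:

    from pydantic import BaseModel

    # Hypothetical import path for illustration; use the project's actual provider module.
    from providers.azure import AzureCompletion

    class Answer(BaseModel):
        """Schema the provider should coerce responses into."""
        text: str
        confidence: float

    llm = AzureCompletion(
        model="gpt-4o",          # per the docstring, Azure structured output is OpenAI models only
        response_format=Answer,  # default schema when call()/acall() gets no response_model
    )

    # No response_model passed, so the call falls back to Answer via self.response_format.
    result = llm.call([{"role": "user", "content": "How tall is Everest?"}])

    # Passing response_model=SomeOtherSchema to call() would override the default.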

@@ -92,6 +92,7 @@ class AzureCompletion(BaseLLM):
stop: list[str] | None = None,
stream: bool = False,
interceptor: BaseInterceptor[Any, Any] | None = None,
response_format: type[BaseModel] | None = None,
**kwargs: Any,
):
"""Initialize Azure AI Inference chat completion client.
@@ -111,6 +112,9 @@ class AzureCompletion(BaseLLM):
stop: Stop sequences
stream: Enable streaming responses
interceptor: HTTP interceptor (not yet supported for Azure).
response_format: Pydantic model for structured output. Used as default when
response_model is not passed to call()/acall() methods.
Only works with OpenAI models deployed on Azure.
**kwargs: Additional parameters
"""
if interceptor is not None:
@@ -165,6 +169,7 @@ class AzureCompletion(BaseLLM):
self.presence_penalty = presence_penalty
self.max_tokens = max_tokens
self.stream = stream
self.response_format = response_format
self.is_openai_model = any(
prefix in model.lower() for prefix in ["gpt-", "o1-", "text-"]
@@ -298,6 +303,7 @@ class AzureCompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
)
effective_response_model = response_model or self.response_format
# Format messages for Azure
formatted_messages = self._format_messages_for_azure(messages)
@@ -307,7 +313,7 @@ class AzureCompletion(BaseLLM):
# Prepare completion parameters
completion_params = self._prepare_completion_params(
formatted_messages, tools, response_model
formatted_messages, tools, effective_response_model
)
# Handle streaming vs non-streaming
@@ -317,7 +323,7 @@ class AzureCompletion(BaseLLM):
available_functions,
from_task,
from_agent,
response_model,
effective_response_model,
)
return self._handle_completion(
@@ -325,7 +331,7 @@ class AzureCompletion(BaseLLM):
available_functions,
from_task,
from_agent,
response_model,
effective_response_model,
)
except Exception as e:
@@ -364,11 +370,12 @@ class AzureCompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
)
effective_response_model = response_model or self.response_format
formatted_messages = self._format_messages_for_azure(messages)
completion_params = self._prepare_completion_params(
formatted_messages, tools, response_model
formatted_messages, tools, effective_response_model
)
if self.stream:
@@ -377,7 +384,7 @@ class AzureCompletion(BaseLLM):
available_functions,
from_task,
from_agent,
response_model,
effective_response_model,
)
return await self._ahandle_completion(
@@ -385,7 +392,7 @@ class AzureCompletion(BaseLLM):
available_functions,
from_task,
from_agent,
response_model,
effective_response_model,
)
except Exception as e:
@@ -726,7 +733,7 @@ class AzureCompletion(BaseLLM):
"""
if update.choices:
choice = update.choices[0]
response_id = update.id if hasattr(update,"id") else None
response_id = update.id if hasattr(update, "id") else None
if choice.delta and choice.delta.content:
content_delta = choice.delta.content
full_response += content_delta
@@ -734,7 +741,7 @@ class AzureCompletion(BaseLLM):
chunk=content_delta,
from_task=from_task,
from_agent=from_agent,
response_id=response_id
response_id=response_id,
)
if choice.delta and choice.delta.tool_calls:
@@ -769,7 +776,7 @@ class AzureCompletion(BaseLLM):
"index": idx,
},
call_type=LLMCallType.TOOL_CALL,
response_id=response_id
response_id=response_id,
)
return full_response
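
Note on the Azure hunks: the behavioral core is the fallback added at the top of
call()/acall(), where an explicit per-call response_model always wins and the
constructor-level response_format is used only when none is passed; per the new
docstring, the default only takes effect for OpenAI models deployed on Azure, which
the provider already detects by model-name prefix. A standalone sketch of both pieces,
not the provider's actual code:

    from pydantic import BaseModel

    def effective_response_model(
        per_call: type[BaseModel] | None,
        constructor_default: type[BaseModel] | None,
    ) -> type[BaseModel] | None:
        # Mirrors the diff's "response_model or self.response_format": explicit argument
        # first, constructor default second, otherwise None (plain text response).
        return per_call or constructor_default

    def is_openai_model(model: str) -> bool:
        # Same prefix check the provider uses; response_format is only honored for
        # these OpenAI-family deployments.
        return any(prefix in model.lower() for prefix in ("gpt-", "o1-", "text-"))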

@@ -56,6 +56,7 @@ class GeminiCompletion(BaseLLM):
client_params: dict[str, Any] | None = None,
interceptor: BaseInterceptor[Any, Any] | None = None,
use_vertexai: bool | None = None,
response_format: type[BaseModel] | None = None,
**kwargs: Any,
):
"""Initialize Google Gemini chat completion client.
@@ -86,6 +87,8 @@ class GeminiCompletion(BaseLLM):
- None (default): Check GOOGLE_GENAI_USE_VERTEXAI env var
When using Vertex AI with API key (Express mode), http_options with
api_version="v1" is automatically configured.
response_format: Pydantic model for structured output. Used as default when
response_model is not passed to call()/acall() methods.
**kwargs: Additional parameters
"""
if interceptor is not None:
@@ -121,6 +124,7 @@ class GeminiCompletion(BaseLLM):
self.safety_settings = safety_settings or {}
self.stop_sequences = stop_sequences or []
self.tools: list[dict[str, Any]] | None = None
self.response_format = response_format
# Model-specific settings
version_match = re.search(r"gemini-(\d+(?:\.\d+)?)", model.lower())
@@ -292,6 +296,7 @@ class GeminiCompletion(BaseLLM):
from_agent=from_agent,
)
self.tools = tools
effective_response_model = response_model or self.response_format
formatted_content, system_instruction = self._format_messages_for_gemini(
messages
@@ -303,7 +308,7 @@ class GeminiCompletion(BaseLLM):
raise ValueError("LLM call blocked by before_llm_call hook")
config = self._prepare_generation_config(
system_instruction, tools, response_model
system_instruction, tools, effective_response_model
)
if self.stream:
@@ -313,7 +318,7 @@ class GeminiCompletion(BaseLLM):
available_functions,
from_task,
from_agent,
response_model,
effective_response_model,
)
return self._handle_completion(
@@ -322,7 +327,7 @@ class GeminiCompletion(BaseLLM):
available_functions,
from_task,
from_agent,
response_model,
effective_response_model,
)
except APIError as e:
@@ -374,13 +379,14 @@ class GeminiCompletion(BaseLLM):
from_agent=from_agent,
)
self.tools = tools
effective_response_model = response_model or self.response_format
formatted_content, system_instruction = self._format_messages_for_gemini(
messages
)
config = self._prepare_generation_config(
system_instruction, tools, response_model
system_instruction, tools, effective_response_model
)
if self.stream:
@@ -390,7 +396,7 @@ class GeminiCompletion(BaseLLM):
available_functions,
from_task,
from_agent,
response_model,
effective_response_model,
)
return await self._ahandle_completion(
@@ -399,7 +405,7 @@ class GeminiCompletion(BaseLLM):
available_functions,
from_task,
from_agent,
response_model,
effective_response_model,
)
except APIError as e:
@@ -570,10 +576,10 @@ class GeminiCompletion(BaseLLM):
types.Content(role="user", parts=[function_response_part])
)
elif role == "assistant" and message.get("tool_calls"):
parts: list[types.Part] = []
tool_parts: list[types.Part] = []
if text_content:
parts.append(types.Part.from_text(text=text_content))
tool_parts.append(types.Part.from_text(text=text_content))
tool_calls: list[dict[str, Any]] = message.get("tool_calls") or []
for tool_call in tool_calls:
@@ -592,11 +598,11 @@ class GeminiCompletion(BaseLLM):
else:
func_args = func_args_raw
parts.append(
tool_parts.append(
types.Part.from_function_call(name=func_name, args=func_args)
)
contents.append(types.Content(role="model", parts=parts))
contents.append(types.Content(role="model", parts=tool_parts))
else:
# Convert role for Gemini (assistant -> model)
gemini_role = "model" if role == "assistant" else "user"
@@ -790,7 +796,7 @@ class GeminiCompletion(BaseLLM):
Returns:
Tuple of (updated full_response, updated function_calls, updated usage_data)
"""
response_id=chunk.response_id if hasattr(chunk,"response_id") else None
response_id = chunk.response_id if hasattr(chunk, "response_id") else None
if chunk.usage_metadata:
usage_data = self._extract_token_usage(chunk)
@@ -800,7 +806,7 @@ class GeminiCompletion(BaseLLM):
chunk=chunk.text,
from_task=from_task,
from_agent=from_agent,
response_id=response_id
response_id=response_id,
)
if chunk.candidates:
@@ -837,7 +843,7 @@ class GeminiCompletion(BaseLLM):
"index": call_index,
},
call_type=LLMCallType.TOOL_CALL,
response_id=response_id
response_id=response_id,
)
return full_response, function_calls, usage_data
@@ -972,7 +978,7 @@ class GeminiCompletion(BaseLLM):
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str:
) -> str | Any:
"""Handle streaming content generation."""
full_response = ""
function_calls: dict[int, dict[str, Any]] = {}
@@ -1050,7 +1056,7 @@ class GeminiCompletion(BaseLLM):
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str:
) -> str | Any:
"""Handle async streaming content generation."""
full_response = ""
function_calls: dict[int, dict[str, Any]] = {}
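
Note on the Gemini hunks: besides the same response_format fallback, the streaming
handlers widen their return annotation from str to str | Any, reflecting that a call
made with a response model can hand back a parsed Pydantic instance instead of raw
text. A caller-side sketch of what that means, assuming the result is either a string
or an instance of the configured Pydantic schema:

    from typing import Any

    from pydantic import BaseModel

    def unpack_result(result: str | Any) -> dict[str, Any]:
        # With response_format/response_model set, the provider may return a parsed
        # model instance; otherwise it returns the plain response text.
        if isinstance(result, BaseModel):
            return result.model_dump()  # Pydantic v2; use .dict() on v1
        return {"text": result}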