From c9b240a86c7dfa861ff5950dba53ff52d628a314 Mon Sep 17 00:00:00 2001
From: Greyson Lalonde
Date: Mon, 26 Jan 2026 08:42:06 -0500
Subject: [PATCH] feat: add response_format parameter to Azure and Gemini providers

---
 .../crewai/llms/providers/azure/completion.py | 25 ++++++++-----
 .../llms/providers/gemini/completion.py       | 36 +++++++++++--------
 2 files changed, 37 insertions(+), 24 deletions(-)

diff --git a/lib/crewai/src/crewai/llms/providers/azure/completion.py b/lib/crewai/src/crewai/llms/providers/azure/completion.py
index a3aed7f4b..1de18d984 100644
--- a/lib/crewai/src/crewai/llms/providers/azure/completion.py
+++ b/lib/crewai/src/crewai/llms/providers/azure/completion.py
@@ -92,6 +92,7 @@ class AzureCompletion(BaseLLM):
         stop: list[str] | None = None,
         stream: bool = False,
         interceptor: BaseInterceptor[Any, Any] | None = None,
+        response_format: type[BaseModel] | None = None,
         **kwargs: Any,
     ):
         """Initialize Azure AI Inference chat completion client.
@@ -111,6 +112,9 @@ class AzureCompletion(BaseLLM):
             stop: Stop sequences
             stream: Enable streaming responses
             interceptor: HTTP interceptor (not yet supported for Azure).
+            response_format: Pydantic model for structured output. Used as default when
+                response_model is not passed to call()/acall() methods.
+                Only works with OpenAI models deployed on Azure.
             **kwargs: Additional parameters
         """
         if interceptor is not None:
@@ -165,6 +169,7 @@ class AzureCompletion(BaseLLM):
         self.presence_penalty = presence_penalty
         self.max_tokens = max_tokens
         self.stream = stream
+        self.response_format = response_format
         self.is_openai_model = any(
             prefix in model.lower() for prefix in ["gpt-", "o1-", "text-"]
         )
@@ -298,6 +303,7 @@ class AzureCompletion(BaseLLM):
                 from_task=from_task,
                 from_agent=from_agent,
             )
+            effective_response_model = response_model or self.response_format
 
             # Format messages for Azure
             formatted_messages = self._format_messages_for_azure(messages)
@@ -307,7 +313,7 @@ class AzureCompletion(BaseLLM):
 
             # Prepare completion parameters
             completion_params = self._prepare_completion_params(
-                formatted_messages, tools, response_model
+                formatted_messages, tools, effective_response_model
             )
 
             # Handle streaming vs non-streaming
@@ -317,7 +323,7 @@
                 available_functions,
                 from_task,
                 from_agent,
-                response_model,
+                effective_response_model,
             )
 
             return self._handle_completion(
@@ -325,7 +331,7 @@
                 available_functions,
                 from_task,
                 from_agent,
-                response_model,
+                effective_response_model,
             )
 
         except Exception as e:
@@ -364,11 +370,12 @@ class AzureCompletion(BaseLLM):
                 from_task=from_task,
                 from_agent=from_agent,
             )
+            effective_response_model = response_model or self.response_format
 
             formatted_messages = self._format_messages_for_azure(messages)
 
             completion_params = self._prepare_completion_params(
-                formatted_messages, tools, response_model
+                formatted_messages, tools, effective_response_model
             )
 
             if self.stream:
@@ -377,7 +384,7 @@
                     available_functions,
                     from_task,
                     from_agent,
-                    response_model,
+                    effective_response_model,
                 )
 
             return await self._ahandle_completion(
@@ -385,7 +392,7 @@
                 available_functions,
                 from_task,
                 from_agent,
-                response_model,
+                effective_response_model,
             )
 
         except Exception as e:
@@ -726,7 +733,7 @@ class AzureCompletion(BaseLLM):
         """
         if update.choices:
             choice = update.choices[0]
-            response_id = update.id if hasattr(update,"id") else None
+            response_id = update.id if hasattr(update, "id") else None
             if choice.delta and choice.delta.content:
                 content_delta = choice.delta.content
                 full_response += content_delta
@@ -734,7 +741,7 @@ class AzureCompletion(BaseLLM):
                     chunk=content_delta,
                     from_task=from_task,
                     from_agent=from_agent,
-                    response_id=response_id
+                    response_id=response_id,
                 )
 
             if choice.delta and choice.delta.tool_calls:
@@ -769,7 +776,7 @@ class AzureCompletion(BaseLLM):
                         "index": idx,
                     },
                     call_type=LLMCallType.TOOL_CALL,
-                    response_id=response_id
+                    response_id=response_id,
                 )
 
         return full_response
diff --git a/lib/crewai/src/crewai/llms/providers/gemini/completion.py b/lib/crewai/src/crewai/llms/providers/gemini/completion.py
index 9687f3d4f..d101ad0be 100644
--- a/lib/crewai/src/crewai/llms/providers/gemini/completion.py
+++ b/lib/crewai/src/crewai/llms/providers/gemini/completion.py
@@ -56,6 +56,7 @@ class GeminiCompletion(BaseLLM):
         client_params: dict[str, Any] | None = None,
         interceptor: BaseInterceptor[Any, Any] | None = None,
         use_vertexai: bool | None = None,
+        response_format: type[BaseModel] | None = None,
         **kwargs: Any,
     ):
         """Initialize Google Gemini chat completion client.
@@ -86,6 +87,8 @@ class GeminiCompletion(BaseLLM):
                 - None (default): Check GOOGLE_GENAI_USE_VERTEXAI env var
                 When using Vertex AI with API key (Express mode), http_options
                 with api_version="v1" is automatically configured.
+            response_format: Pydantic model for structured output. Used as default when
+                response_model is not passed to call()/acall() methods.
             **kwargs: Additional parameters
         """
         if interceptor is not None:
@@ -121,6 +124,7 @@ class GeminiCompletion(BaseLLM):
         self.safety_settings = safety_settings or {}
         self.stop_sequences = stop_sequences or []
         self.tools: list[dict[str, Any]] | None = None
+        self.response_format = response_format
 
         # Model-specific settings
         version_match = re.search(r"gemini-(\d+(?:\.\d+)?)", model.lower())
@@ -292,6 +296,7 @@ class GeminiCompletion(BaseLLM):
                 from_agent=from_agent,
             )
             self.tools = tools
+            effective_response_model = response_model or self.response_format
 
             formatted_content, system_instruction = self._format_messages_for_gemini(
                 messages
@@ -303,7 +308,7 @@
                 raise ValueError("LLM call blocked by before_llm_call hook")
 
             config = self._prepare_generation_config(
-                system_instruction, tools, response_model
+                system_instruction, tools, effective_response_model
             )
 
             if self.stream:
@@ -313,7 +318,7 @@
                 available_functions,
                 from_task,
                 from_agent,
-                response_model,
+                effective_response_model,
             )
 
             return self._handle_completion(
@@ -322,7 +327,7 @@
                 available_functions,
                 from_task,
                 from_agent,
-                response_model,
+                effective_response_model,
             )
 
         except APIError as e:
@@ -374,13 +379,14 @@
                 from_agent=from_agent,
             )
             self.tools = tools
+            effective_response_model = response_model or self.response_format
 
             formatted_content, system_instruction = self._format_messages_for_gemini(
                 messages
             )
 
             config = self._prepare_generation_config(
-                system_instruction, tools, response_model
+                system_instruction, tools, effective_response_model
             )
 
             if self.stream:
@@ -390,7 +396,7 @@
                 available_functions,
                 from_task,
                 from_agent,
-                response_model,
+                effective_response_model,
             )
 
             return await self._ahandle_completion(
@@ -399,7 +405,7 @@
                 available_functions,
                 from_task,
                 from_agent,
-                response_model,
+                effective_response_model,
             )
 
         except APIError as e:
@@ -570,10 +576,10 @@
                     types.Content(role="user", parts=[function_response_part])
                 )
             elif role == "assistant" and message.get("tool_calls"):
-                parts: list[types.Part] = []
+                tool_parts: list[types.Part] = []
 
                 if text_content:
-                    parts.append(types.Part.from_text(text=text_content))
+                    tool_parts.append(types.Part.from_text(text=text_content))
 
                 tool_calls: list[dict[str, Any]] = message.get("tool_calls") or []
                 for tool_call in tool_calls:
@@ -592,11 +598,11 @@
                     else:
                         func_args = func_args_raw
 
-                    parts.append(
+                    tool_parts.append(
                         types.Part.from_function_call(name=func_name, args=func_args)
                     )
 
-                contents.append(types.Content(role="model", parts=parts))
+                contents.append(types.Content(role="model", parts=tool_parts))
             else:
                 # Convert role for Gemini (assistant -> model)
                 gemini_role = "model" if role == "assistant" else "user"
@@ -790,7 +796,7 @@
         Returns:
             Tuple of (updated full_response, updated function_calls, updated usage_data)
         """
-        response_id=chunk.response_id if hasattr(chunk,"response_id") else None
+        response_id = chunk.response_id if hasattr(chunk, "response_id") else None
 
         if chunk.usage_metadata:
             usage_data = self._extract_token_usage(chunk)
@@ -800,7 +806,7 @@
                 chunk=chunk.text,
                 from_task=from_task,
                 from_agent=from_agent,
-                response_id=response_id
+                response_id=response_id,
             )
 
         if chunk.candidates:
@@ -837,7 +843,7 @@
                         "index": call_index,
                     },
                     call_type=LLMCallType.TOOL_CALL,
-                    response_id=response_id
+                    response_id=response_id,
                 )
 
         return full_response, function_calls, usage_data
@@ -972,7 +978,7 @@
         from_task: Any | None = None,
         from_agent: Any | None = None,
         response_model: type[BaseModel] | None = None,
-    ) -> str:
+    ) -> str | Any:
         """Handle streaming content generation."""
         full_response = ""
         function_calls: dict[int, dict[str, Any]] = {}
@@ -1050,7 +1056,7 @@
         from_task: Any | None = None,
         from_agent: Any | None = None,
         response_model: type[BaseModel] | None = None,
-    ) -> str:
+    ) -> str | Any:
         """Handle async streaming content generation."""
        full_response = ""
        function_calls: dict[int, dict[str, Any]] = {}