Mirror of https://github.com/crewAIInc/crewAI.git
feat: add response_format parameter to Azure and Gemini providers
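The new parameter lets a structured-output default be fixed once when the provider is constructed, instead of passing response_model on every call()/acall(). A minimal usage sketch; the import paths and model names below are illustrative assumptions, and only the response_format parameter plus the per-call response_model override come from this commit:

from pydantic import BaseModel

# Hypothetical import paths; the diff only shows the class names.
from crewai.llms.providers.azure.completion import AzureCompletion
from crewai.llms.providers.gemini.completion import GeminiCompletion


class Summary(BaseModel):
    title: str
    bullet_points: list[str]


class ActionItems(BaseModel):
    items: list[str]


# Constructor-level default: calls without an explicit response_model
# fall back to Summary.
azure_llm = AzureCompletion(model="gpt-4o", response_format=Summary)
gemini_llm = GeminiCompletion(model="gemini-2.0-flash", response_format=Summary)

result = azure_llm.call("Summarize the latest sprint review.")  # uses Summary

result = gemini_llm.call(
    "List the action items from the sprint review.",
    response_model=ActionItems,  # an explicit per-call model still wins
)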
@@ -92,6 +92,7 @@ class AzureCompletion(BaseLLM):
         stop: list[str] | None = None,
         stream: bool = False,
         interceptor: BaseInterceptor[Any, Any] | None = None,
+        response_format: type[BaseModel] | None = None,
         **kwargs: Any,
     ):
         """Initialize Azure AI Inference chat completion client.
@@ -111,6 +112,9 @@ class AzureCompletion(BaseLLM):
             stop: Stop sequences
             stream: Enable streaming responses
             interceptor: HTTP interceptor (not yet supported for Azure).
+            response_format: Pydantic model for structured output. Used as default when
+                response_model is not passed to call()/acall() methods.
+                Only works with OpenAI models deployed on Azure.
             **kwargs: Additional parameters
         """
         if interceptor is not None:
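The docstring above describes a simple precedence rule that the later hunks implement with an or-fallback: a response_model passed to call()/acall() takes priority, otherwise the constructor's response_format is used, and with neither set the provider keeps returning plain text. A self-contained sketch of that resolution, mirroring the expression used in the diff:

from pydantic import BaseModel


class Report(BaseModel):
    headline: str


class Metrics(BaseModel):
    score: float


def resolve_response_model(
    response_model: type[BaseModel] | None,
    response_format: type[BaseModel] | None,
) -> type[BaseModel] | None:
    # Mirrors `response_model or self.response_format` from the hunks below:
    # the per-call argument wins, the constructor default is the fallback.
    return response_model or response_format


assert resolve_response_model(None, Report) is Report      # constructor default applies
assert resolve_response_model(Metrics, Report) is Metrics  # explicit call argument wins
assert resolve_response_model(None, None) is None          # plain-text behaviour unchanged

Because the fallback is a plain or, passing response_model=None explicitly behaves exactly like omitting the argument.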
@@ -165,6 +169,7 @@ class AzureCompletion(BaseLLM):
         self.presence_penalty = presence_penalty
         self.max_tokens = max_tokens
         self.stream = stream
+        self.response_format = response_format

         self.is_openai_model = any(
             prefix in model.lower() for prefix in ["gpt-", "o1-", "text-"]
@@ -298,6 +303,7 @@ class AzureCompletion(BaseLLM):
                 from_task=from_task,
                 from_agent=from_agent,
             )
+            effective_response_model = response_model or self.response_format

             # Format messages for Azure
             formatted_messages = self._format_messages_for_azure(messages)
@@ -307,7 +313,7 @@ class AzureCompletion(BaseLLM):

             # Prepare completion parameters
             completion_params = self._prepare_completion_params(
-                formatted_messages, tools, response_model
+                formatted_messages, tools, effective_response_model
             )

             # Handle streaming vs non-streaming
@@ -317,7 +323,7 @@ class AzureCompletion(BaseLLM):
                     available_functions,
                     from_task,
                     from_agent,
-                    response_model,
+                    effective_response_model,
                 )

             return self._handle_completion(
@@ -325,7 +331,7 @@ class AzureCompletion(BaseLLM):
                 available_functions,
                 from_task,
                 from_agent,
-                response_model,
+                effective_response_model,
             )

         except Exception as e:
@@ -364,11 +370,12 @@ class AzureCompletion(BaseLLM):
                 from_task=from_task,
                 from_agent=from_agent,
             )
+            effective_response_model = response_model or self.response_format

             formatted_messages = self._format_messages_for_azure(messages)

             completion_params = self._prepare_completion_params(
-                formatted_messages, tools, response_model
+                formatted_messages, tools, effective_response_model
             )

             if self.stream:
@@ -377,7 +384,7 @@ class AzureCompletion(BaseLLM):
                     available_functions,
                     from_task,
                     from_agent,
-                    response_model,
+                    effective_response_model,
                 )

             return await self._ahandle_completion(
@@ -385,7 +392,7 @@ class AzureCompletion(BaseLLM):
                 available_functions,
                 from_task,
                 from_agent,
-                response_model,
+                effective_response_model,
             )

         except Exception as e:
@@ -726,7 +733,7 @@ class AzureCompletion(BaseLLM):
        """
        if update.choices:
            choice = update.choices[0]
-            response_id = update.id if hasattr(update,"id") else None
+            response_id = update.id if hasattr(update, "id") else None
            if choice.delta and choice.delta.content:
                content_delta = choice.delta.content
                full_response += content_delta
@@ -734,7 +741,7 @@ class AzureCompletion(BaseLLM):
                     chunk=content_delta,
                     from_task=from_task,
                     from_agent=from_agent,
-                    response_id=response_id
+                    response_id=response_id,
                 )

            if choice.delta and choice.delta.tool_calls:
@@ -769,7 +776,7 @@ class AzureCompletion(BaseLLM):
                         "index": idx,
                     },
                     call_type=LLMCallType.TOOL_CALL,
-                    response_id=response_id
+                    response_id=response_id,
                 )

        return full_response
@@ -56,6 +56,7 @@ class GeminiCompletion(BaseLLM):
         client_params: dict[str, Any] | None = None,
         interceptor: BaseInterceptor[Any, Any] | None = None,
         use_vertexai: bool | None = None,
+        response_format: type[BaseModel] | None = None,
         **kwargs: Any,
     ):
         """Initialize Google Gemini chat completion client.
@@ -86,6 +87,8 @@ class GeminiCompletion(BaseLLM):
                 - None (default): Check GOOGLE_GENAI_USE_VERTEXAI env var
                 When using Vertex AI with API key (Express mode), http_options with
                 api_version="v1" is automatically configured.
+            response_format: Pydantic model for structured output. Used as default when
+                response_model is not passed to call()/acall() methods.
             **kwargs: Additional parameters
         """
         if interceptor is not None:
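For Gemini the docstring carries no OpenAI-only caveat; the default model is simply forwarded into the generation config (the hunks below pass effective_response_model to _prepare_generation_config, whose internals are outside this diff). A hedged sketch of how structured output is typically requested from the google-genai SDK directly, which is presumably the shape that config ends up carrying:

from google import genai
from google.genai import types
from pydantic import BaseModel


class Summary(BaseModel):
    title: str
    bullet_points: list[str]


client = genai.Client()  # assumes GEMINI_API_KEY / GOOGLE_API_KEY is set

response = client.models.generate_content(
    model="gemini-2.0-flash",
    contents="Summarize the latest sprint review.",
    config=types.GenerateContentConfig(
        response_mime_type="application/json",
        response_schema=Summary,  # a Pydantic model class, like response_format here
    ),
)

print(response.parsed)  # parsed Summary instance built from the JSON response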
@@ -121,6 +124,7 @@ class GeminiCompletion(BaseLLM):
         self.safety_settings = safety_settings or {}
         self.stop_sequences = stop_sequences or []
         self.tools: list[dict[str, Any]] | None = None
+        self.response_format = response_format

         # Model-specific settings
         version_match = re.search(r"gemini-(\d+(?:\.\d+)?)", model.lower())
@@ -292,6 +296,7 @@ class GeminiCompletion(BaseLLM):
                 from_agent=from_agent,
             )
             self.tools = tools
+            effective_response_model = response_model or self.response_format

             formatted_content, system_instruction = self._format_messages_for_gemini(
                 messages
@@ -303,7 +308,7 @@ class GeminiCompletion(BaseLLM):
                 raise ValueError("LLM call blocked by before_llm_call hook")

             config = self._prepare_generation_config(
-                system_instruction, tools, response_model
+                system_instruction, tools, effective_response_model
             )

             if self.stream:
@@ -313,7 +318,7 @@ class GeminiCompletion(BaseLLM):
                     available_functions,
                     from_task,
                     from_agent,
-                    response_model,
+                    effective_response_model,
                 )

             return self._handle_completion(
@@ -322,7 +327,7 @@ class GeminiCompletion(BaseLLM):
                 available_functions,
                 from_task,
                 from_agent,
-                response_model,
+                effective_response_model,
             )

         except APIError as e:
@@ -374,13 +379,14 @@ class GeminiCompletion(BaseLLM):
                 from_agent=from_agent,
             )
             self.tools = tools
+            effective_response_model = response_model or self.response_format

             formatted_content, system_instruction = self._format_messages_for_gemini(
                 messages
             )

             config = self._prepare_generation_config(
-                system_instruction, tools, response_model
+                system_instruction, tools, effective_response_model
             )

             if self.stream:
@@ -390,7 +396,7 @@ class GeminiCompletion(BaseLLM):
                     available_functions,
                     from_task,
                     from_agent,
-                    response_model,
+                    effective_response_model,
                 )

             return await self._ahandle_completion(
@@ -399,7 +405,7 @@ class GeminiCompletion(BaseLLM):
                 available_functions,
                 from_task,
                 from_agent,
-                response_model,
+                effective_response_model,
             )

         except APIError as e:
@@ -570,10 +576,10 @@ class GeminiCompletion(BaseLLM):
                     types.Content(role="user", parts=[function_response_part])
                 )
             elif role == "assistant" and message.get("tool_calls"):
-                parts: list[types.Part] = []
+                tool_parts: list[types.Part] = []

                 if text_content:
-                    parts.append(types.Part.from_text(text=text_content))
+                    tool_parts.append(types.Part.from_text(text=text_content))

                 tool_calls: list[dict[str, Any]] = message.get("tool_calls") or []
                 for tool_call in tool_calls:
@@ -592,11 +598,11 @@ class GeminiCompletion(BaseLLM):
                     else:
                         func_args = func_args_raw

-                    parts.append(
+                    tool_parts.append(
                         types.Part.from_function_call(name=func_name, args=func_args)
                     )

-                contents.append(types.Content(role="model", parts=parts))
+                contents.append(types.Content(role="model", parts=tool_parts))
             else:
                 # Convert role for Gemini (assistant -> model)
                 gemini_role = "model" if role == "assistant" else "user"
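The two hunks above only rename the local parts list to tool_parts; behaviour is unchanged. For context, this is roughly the model turn they rebuild when an assistant message carried tool calls, sketched with the same google-genai types the diff already calls (the OpenAI-style message dict is illustrative):

import json

from google.genai import types

# Illustrative assistant message in OpenAI-style dict form.
message = {
    "role": "assistant",
    "content": "Let me check the weather.",
    "tool_calls": [
        {"function": {"name": "get_weather", "arguments": '{"city": "Paris"}'}}
    ],
}

tool_parts: list[types.Part] = []
if message.get("content"):
    tool_parts.append(types.Part.from_text(text=message["content"]))

for tool_call in message["tool_calls"]:
    func = tool_call["function"]
    raw_args = func["arguments"]
    func_args = json.loads(raw_args) if isinstance(raw_args, str) else raw_args
    tool_parts.append(
        types.Part.from_function_call(name=func["name"], args=func_args)
    )

# The assistant turn becomes a "model"-role Content holding text and function-call parts.
content = types.Content(role="model", parts=tool_parts)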
@@ -790,7 +796,7 @@ class GeminiCompletion(BaseLLM):
        Returns:
            Tuple of (updated full_response, updated function_calls, updated usage_data)
        """
-        response_id=chunk.response_id if hasattr(chunk,"response_id") else None
+        response_id = chunk.response_id if hasattr(chunk, "response_id") else None
        if chunk.usage_metadata:
            usage_data = self._extract_token_usage(chunk)

@@ -800,7 +806,7 @@ class GeminiCompletion(BaseLLM):
                 chunk=chunk.text,
                 from_task=from_task,
                 from_agent=from_agent,
-                response_id=response_id
+                response_id=response_id,
             )

         if chunk.candidates:
@@ -837,7 +843,7 @@ class GeminiCompletion(BaseLLM):
                             "index": call_index,
                         },
                         call_type=LLMCallType.TOOL_CALL,
-                        response_id=response_id
+                        response_id=response_id,
                     )

         return full_response, function_calls, usage_data
@@ -972,7 +978,7 @@ class GeminiCompletion(BaseLLM):
         from_task: Any | None = None,
         from_agent: Any | None = None,
         response_model: type[BaseModel] | None = None,
-    ) -> str:
+    ) -> str | Any:
         """Handle streaming content generation."""
         full_response = ""
         function_calls: dict[int, dict[str, Any]] = {}
@@ -1050,7 +1056,7 @@ class GeminiCompletion(BaseLLM):
         from_task: Any | None = None,
         from_agent: Any | None = None,
         response_model: type[BaseModel] | None = None,
-    ) -> str:
+    ) -> str | Any:
         """Handle async streaming content generation."""
         full_response = ""
         function_calls: dict[int, dict[str, Any]] = {}
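The widened -> str | Any annotations on the two streaming handlers above reflect that, once a response model is in play, a handler may return something other than plain text, presumably the parsed Pydantic instance (the parsing path itself is not part of this diff). A defensive sketch for callers that may see either shape; the isinstance branch is an assumption based only on that type hint:

from pydantic import BaseModel


def as_text(result: str | BaseModel) -> str:
    # With response_format/response_model set, the result may already be a parsed
    # model instance; without it, it stays a plain string (assumption, see above).
    if isinstance(result, BaseModel):
        return result.model_dump_json()
    return result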