Mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-01-28 01:28:14 +00:00)
feat: add response_format parameter to Azure and Gemini providers
Azure provider (AzureCompletion):

@@ -92,6 +92,7 @@ class AzureCompletion(BaseLLM):
         stop: list[str] | None = None,
         stream: bool = False,
         interceptor: BaseInterceptor[Any, Any] | None = None,
+        response_format: type[BaseModel] | None = None,
         **kwargs: Any,
     ):
         """Initialize Azure AI Inference chat completion client.

@@ -111,6 +112,9 @@ class AzureCompletion(BaseLLM):
             stop: Stop sequences
             stream: Enable streaming responses
             interceptor: HTTP interceptor (not yet supported for Azure).
+            response_format: Pydantic model for structured output. Used as default when
+                response_model is not passed to call()/acall() methods.
+                Only works with OpenAI models deployed on Azure.
             **kwargs: Additional parameters
         """
         if interceptor is not None:
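The docstring above describes the intended behavior: the constructor-level response_format acts as a default schema whenever a call does not pass its own response_model. A minimal usage sketch, assuming only the constructor keyword shown in this diff; the import path, model name, and simplified call() signature are illustrative assumptions, not taken from this commit:

from pydantic import BaseModel

# Hypothetical import path; the diff only shows the AzureCompletion class body.
from crewai.llms.providers.azure.completion import AzureCompletion


class Review(BaseModel):
    sentiment: str
    summary: str


# The constructor default applies to every call that omits response_model.
# Per the docstring, this only works with OpenAI models deployed on Azure.
llm = AzureCompletion(model="gpt-4o", response_format=Review)

result = llm.call(messages=[{"role": "user", "content": "Review: great product"}])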
@@ -165,6 +169,7 @@ class AzureCompletion(BaseLLM):
         self.presence_penalty = presence_penalty
         self.max_tokens = max_tokens
         self.stream = stream
+        self.response_format = response_format

         self.is_openai_model = any(
             prefix in model.lower() for prefix in ["gpt-", "o1-", "text-"]
@@ -298,6 +303,7 @@ class AzureCompletion(BaseLLM):
                 from_task=from_task,
                 from_agent=from_agent,
             )
+            effective_response_model = response_model or self.response_format

             # Format messages for Azure
             formatted_messages = self._format_messages_for_azure(messages)
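The line introduced here encodes a simple precedence rule: an explicit response_model passed to call() wins, and the constructor-level response_format is only a fallback. A standalone sketch of that rule, using throwaway names to make the two sources explicit:

from pydantic import BaseModel


class Schema(BaseModel):
    answer: str


constructor_default: type[BaseModel] | None = Schema  # plays the role of self.response_format
call_site_override: type[BaseModel] | None = None     # plays the role of the response_model argument

# Same expression as in the diff: the call-site value wins when it is not None.
effective_response_model = call_site_override or constructor_default
assert effective_response_model is Schema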
@@ -307,7 +313,7 @@ class AzureCompletion(BaseLLM):

             # Prepare completion parameters
             completion_params = self._prepare_completion_params(
-                formatted_messages, tools, response_model
+                formatted_messages, tools, effective_response_model
             )

             # Handle streaming vs non-streaming
@@ -317,7 +323,7 @@ class AzureCompletion(BaseLLM):
                     available_functions,
                     from_task,
                     from_agent,
-                    response_model,
+                    effective_response_model,
                 )

             return self._handle_completion(
@@ -325,7 +331,7 @@ class AzureCompletion(BaseLLM):
                 available_functions,
                 from_task,
                 from_agent,
-                response_model,
+                effective_response_model,
             )

         except Exception as e:
@@ -364,11 +370,12 @@ class AzureCompletion(BaseLLM):
                 from_task=from_task,
                 from_agent=from_agent,
             )
+            effective_response_model = response_model or self.response_format

             formatted_messages = self._format_messages_for_azure(messages)

             completion_params = self._prepare_completion_params(
-                formatted_messages, tools, response_model
+                formatted_messages, tools, effective_response_model
             )

             if self.stream:
@@ -377,7 +384,7 @@ class AzureCompletion(BaseLLM):
                     available_functions,
                     from_task,
                     from_agent,
-                    response_model,
+                    effective_response_model,
                 )

             return await self._ahandle_completion(
@@ -385,7 +392,7 @@ class AzureCompletion(BaseLLM):
                 available_functions,
                 from_task,
                 from_agent,
-                response_model,
+                effective_response_model,
             )

         except Exception as e:
@@ -726,7 +733,7 @@ class AzureCompletion(BaseLLM):
         """
         if update.choices:
             choice = update.choices[0]
-            response_id = update.id if hasattr(update,"id") else None
+            response_id = update.id if hasattr(update, "id") else None
             if choice.delta and choice.delta.content:
                 content_delta = choice.delta.content
                 full_response += content_delta
@@ -734,7 +741,7 @@ class AzureCompletion(BaseLLM):
                     chunk=content_delta,
                     from_task=from_task,
                     from_agent=from_agent,
-                    response_id=response_id
+                    response_id=response_id,
                 )

             if choice.delta and choice.delta.tool_calls:
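The response_id extraction in these hunks is a defensive attribute lookup on the streaming update object. For reference, the hasattr guard is equivalent to the more compact getattr form (not what this commit uses, just the standard-library alternative):

class _Update:  # stand-in for the Azure streaming update object
    id = "resp_123"


update = _Update()

# Equivalent to: update.id if hasattr(update, "id") else None
response_id = getattr(update, "id", None)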
@@ -769,7 +776,7 @@ class AzureCompletion(BaseLLM):
                             "index": idx,
                         },
                         call_type=LLMCallType.TOOL_CALL,
-                        response_id=response_id
+                        response_id=response_id,
                     )

         return full_response
Gemini provider (GeminiCompletion):

@@ -56,6 +56,7 @@ class GeminiCompletion(BaseLLM):
         client_params: dict[str, Any] | None = None,
         interceptor: BaseInterceptor[Any, Any] | None = None,
         use_vertexai: bool | None = None,
+        response_format: type[BaseModel] | None = None,
         **kwargs: Any,
     ):
         """Initialize Google Gemini chat completion client.

@@ -86,6 +87,8 @@ class GeminiCompletion(BaseLLM):
                 - None (default): Check GOOGLE_GENAI_USE_VERTEXAI env var
                 When using Vertex AI with API key (Express mode), http_options with
                 api_version="v1" is automatically configured.
+            response_format: Pydantic model for structured output. Used as default when
+                response_model is not passed to call()/acall() methods.
             **kwargs: Additional parameters
         """
         if interceptor is not None:
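As on the Azure side, response_format becomes a per-instance default, and a response_model passed directly to call()/acall() still takes precedence. A sketch of the Gemini side; the import path, model name, and simplified call signature are assumptions for illustration:

from pydantic import BaseModel

# Hypothetical import path; the diff only shows the GeminiCompletion class body.
from crewai.llms.providers.gemini.completion import GeminiCompletion


class Extraction(BaseModel):
    title: str
    keywords: list[str]


class Summary(BaseModel):
    text: str


llm = GeminiCompletion(model="gemini-2.0-flash", response_format=Extraction)

# Uses the constructor default (Extraction).
default_result = llm.call(messages=[{"role": "user", "content": "Extract metadata from this text."}])

# A call-site response_model overrides the default (Summary wins here).
override_result = llm.call(
    messages=[{"role": "user", "content": "Summarize this article."}],
    response_model=Summary,
)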
@@ -121,6 +124,7 @@ class GeminiCompletion(BaseLLM):
         self.safety_settings = safety_settings or {}
         self.stop_sequences = stop_sequences or []
         self.tools: list[dict[str, Any]] | None = None
+        self.response_format = response_format

         # Model-specific settings
         version_match = re.search(r"gemini-(\d+(?:\.\d+)?)", model.lower())
@@ -292,6 +296,7 @@ class GeminiCompletion(BaseLLM):
                 from_agent=from_agent,
             )
             self.tools = tools
+            effective_response_model = response_model or self.response_format

             formatted_content, system_instruction = self._format_messages_for_gemini(
                 messages
@@ -303,7 +308,7 @@ class GeminiCompletion(BaseLLM):
                 raise ValueError("LLM call blocked by before_llm_call hook")

             config = self._prepare_generation_config(
-                system_instruction, tools, response_model
+                system_instruction, tools, effective_response_model
             )

             if self.stream:
@@ -313,7 +318,7 @@ class GeminiCompletion(BaseLLM):
                     available_functions,
                     from_task,
                     from_agent,
-                    response_model,
+                    effective_response_model,
                 )

             return self._handle_completion(
@@ -322,7 +327,7 @@ class GeminiCompletion(BaseLLM):
                 available_functions,
                 from_task,
                 from_agent,
-                response_model,
+                effective_response_model,
             )

         except APIError as e:
@@ -374,13 +379,14 @@ class GeminiCompletion(BaseLLM):
                 from_agent=from_agent,
             )
             self.tools = tools
+            effective_response_model = response_model or self.response_format

             formatted_content, system_instruction = self._format_messages_for_gemini(
                 messages
             )

             config = self._prepare_generation_config(
-                system_instruction, tools, response_model
+                system_instruction, tools, effective_response_model
             )

             if self.stream:
@@ -390,7 +396,7 @@ class GeminiCompletion(BaseLLM):
                     available_functions,
                     from_task,
                     from_agent,
-                    response_model,
+                    effective_response_model,
                 )

             return await self._ahandle_completion(
@@ -399,7 +405,7 @@ class GeminiCompletion(BaseLLM):
                 available_functions,
                 from_task,
                 from_agent,
-                response_model,
+                effective_response_model,
             )

         except APIError as e:
@@ -570,10 +576,10 @@ class GeminiCompletion(BaseLLM):
                     types.Content(role="user", parts=[function_response_part])
                 )
             elif role == "assistant" and message.get("tool_calls"):
-                parts: list[types.Part] = []
+                tool_parts: list[types.Part] = []

                 if text_content:
-                    parts.append(types.Part.from_text(text=text_content))
+                    tool_parts.append(types.Part.from_text(text=text_content))

                 tool_calls: list[dict[str, Any]] = message.get("tool_calls") or []
                 for tool_call in tool_calls:
@@ -592,11 +598,11 @@ class GeminiCompletion(BaseLLM):
                     else:
                         func_args = func_args_raw

-                    parts.append(
+                    tool_parts.append(
                         types.Part.from_function_call(name=func_name, args=func_args)
                     )

-                contents.append(types.Content(role="model", parts=parts))
+                contents.append(types.Content(role="model", parts=tool_parts))
             else:
                 # Convert role for Gemini (assistant -> model)
                 gemini_role = "model" if role == "assistant" else "user"
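The rename from parts to tool_parts avoids reusing the parts name elsewhere in this method while keeping the behavior: one model-role Content is built per assistant message that carries tool calls. A condensed sketch of that mapping, using the same google-genai calls that appear in the hunks above; the exact shape of the tool_call dict (OpenAI-style "function" entry with a JSON "arguments" string) is an assumption here, since the surrounding parsing code is not part of this diff:

import json

from google.genai import types

# Assumed OpenAI-style assistant message; the real method iterates a full message list.
message = {
    "role": "assistant",
    "content": "Let me check the weather first.",
    "tool_calls": [
        {"function": {"name": "get_weather", "arguments": '{"city": "Paris"}'}}
    ],
}

tool_parts: list[types.Part] = []

# Optional text portion of the assistant turn.
if message.get("content"):
    tool_parts.append(types.Part.from_text(text=message["content"]))

# One function-call part per tool call, with JSON arguments decoded to a dict.
for tool_call in message["tool_calls"]:
    func = tool_call["function"]
    tool_parts.append(
        types.Part.from_function_call(name=func["name"], args=json.loads(func["arguments"]))
    )

# Assistant turns map to role "model" on the Gemini side.
contents = [types.Content(role="model", parts=tool_parts)]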
@@ -790,7 +796,7 @@ class GeminiCompletion(BaseLLM):
         Returns:
             Tuple of (updated full_response, updated function_calls, updated usage_data)
         """
-        response_id=chunk.response_id if hasattr(chunk,"response_id") else None
+        response_id = chunk.response_id if hasattr(chunk, "response_id") else None
         if chunk.usage_metadata:
             usage_data = self._extract_token_usage(chunk)

@@ -800,7 +806,7 @@ class GeminiCompletion(BaseLLM):
                 chunk=chunk.text,
                 from_task=from_task,
                 from_agent=from_agent,
-                response_id=response_id
+                response_id=response_id,
             )

         if chunk.candidates:
@@ -837,7 +843,7 @@ class GeminiCompletion(BaseLLM):
                             "index": call_index,
                         },
                         call_type=LLMCallType.TOOL_CALL,
-                        response_id=response_id
+                        response_id=response_id,
                     )

         return full_response, function_calls, usage_data
@@ -972,7 +978,7 @@ class GeminiCompletion(BaseLLM):
         from_task: Any | None = None,
         from_agent: Any | None = None,
         response_model: type[BaseModel] | None = None,
-    ) -> str:
+    ) -> str | Any:
         """Handle streaming content generation."""
         full_response = ""
         function_calls: dict[int, dict[str, Any]] = {}
@@ -1050,7 +1056,7 @@ class GeminiCompletion(BaseLLM):
         from_task: Any | None = None,
         from_agent: Any | None = None,
         response_model: type[BaseModel] | None = None,
-    ) -> str:
+    ) -> str | Any:
         """Handle async streaming content generation."""
         full_response = ""
         function_calls: dict[int, dict[str, Any]] = {}
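The widened return annotations (str | Any instead of str) on both streaming handlers reflect that, with an effective response model in play, the handler can hand back a parsed structured-output object instead of the accumulated text. A hedged sketch of what a caller therefore has to handle; the isinstance branch is illustrative, not code from this commit:

from __future__ import annotations

from typing import Any

from pydantic import BaseModel


def consume(result: str | Any) -> dict[str, Any]:
    """Handle both shapes the streaming handlers can now return."""
    if isinstance(result, BaseModel):
        # Structured-output path: a parsed Pydantic instance.
        return result.model_dump()
    # Plain-text path: the accumulated streamed response.
    return {"raw_text": result}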