Mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-05-01 07:13:00 +00:00)
fix: improve structured output handling across providers and agents
- Add Gemini 2.0 schema support using response_json_schema with propertyOrdering, while retaining backward compatibility for earlier models
- Refactor LLM completions to return validated Pydantic models when a response_model is provided, updating hooks, types, and tests for consistent structured outputs
- Extend AgentFinish and executors to support BaseModel outputs, improve Anthropic structured parsing, and clean up schema utilities, tests, and original_json handling
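In practical terms, the new contract looks like this; a minimal sketch, assuming the public LLM class and an illustrative response model (neither taken verbatim from this diff):

from pydantic import BaseModel

from crewai import LLM


class Greeting(BaseModel):
    greeting: str
    language: str


# Any supported provider (OpenAI, Anthropic, Azure, Gemini) follows the same contract.
llm = LLM(model="gpt-4o")
result = llm.call("Say hello in French", response_model=Greeting)

# Previously the call returned a JSON string to parse manually;
# now it returns an already-validated Pydantic instance.
assert isinstance(result, Greeting)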
@@ -348,18 +348,36 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         # breakpoint()
         if self.response_model is not None:
             try:
-                self.response_model.model_validate_json(answer)
-                formatted_answer = AgentFinish(
-                    thought="",
-                    output=answer,
-                    text=answer,
-                )
+                if isinstance(answer, BaseModel):
+                    output_json = answer.model_dump_json()
+                    formatted_answer = AgentFinish(
+                        thought="",
+                        output=answer,
+                        text=output_json,
+                    )
+                else:
+                    self.response_model.model_validate_json(answer)
+                    formatted_answer = AgentFinish(
+                        thought="",
+                        output=answer,
+                        text=answer,
+                    )
             except ValidationError:
+                # If validation fails, convert BaseModel to JSON string for parsing
+                answer_str = (
+                    answer.model_dump_json()
+                    if isinstance(answer, BaseModel)
+                    else str(answer)
+                )
                 formatted_answer = process_llm_response(
-                    answer, self.use_stop_words
+                    answer_str, self.use_stop_words
                 )  # type: ignore[assignment]
         else:
-            formatted_answer = process_llm_response(answer, self.use_stop_words)  # type: ignore[assignment]
+            # When no response_model, answer should be a string
+            answer_str = str(answer) if not isinstance(answer, str) else answer
+            formatted_answer = process_llm_response(
+                answer_str, self.use_stop_words
+            )  # type: ignore[assignment]
 
         if isinstance(formatted_answer, AgentAction):
             # Extract agent fingerprint if available
@@ -520,6 +538,18 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
             self._show_logs(formatted_answer)
             return formatted_answer
 
+        if isinstance(answer, BaseModel):
+            output_json = answer.model_dump_json()
+            formatted_answer = AgentFinish(
+                thought="",
+                output=answer,
+                text=output_json,
+            )
+            self._invoke_step_callback(formatted_answer)
+            self._append_message(output_json)
+            self._show_logs(formatted_answer)
+            return formatted_answer
+
         # Unexpected response type, treat as final answer
         formatted_answer = AgentFinish(
             thought="",
@@ -570,11 +600,20 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
                 verbose=self.agent.verbose,
             )
 
-        formatted_answer = AgentFinish(
-            thought="",
-            output=str(answer),
-            text=str(answer),
-        )
+        if isinstance(answer, BaseModel):
+            output_json = answer.model_dump_json()
+            formatted_answer = AgentFinish(
+                thought="",
+                output=answer,
+                text=output_json,
+            )
+        else:
+            answer_str = answer if isinstance(answer, str) else str(answer)
+            formatted_answer = AgentFinish(
+                thought="",
+                output=answer_str,
+                text=answer_str,
+            )
         self._show_logs(formatted_answer)
         return formatted_answer
 
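The pattern these executor hunks repeat: when the LLM returns a BaseModel, AgentFinish keeps the typed instance in output while text carries its JSON serialization, so string-only consumers keep working. A minimal sketch of that invariant, with an illustrative model (AgentFinish shown as a plain dict to stay self-contained):

from pydantic import BaseModel


class Price(BaseModel):
    amount: float
    currency: str


answer = Price(amount=9.99, currency="EUR")
output_json = answer.model_dump_json()  # '{"amount":9.99,"currency":"EUR"}'

# output stays typed; text stays a plain string.
finish = {"thought": "", "output": answer, "text": output_json}
assert isinstance(finish["output"], Price)
assert isinstance(finish["text"], str)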
@@ -1031,18 +1070,36 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
 
         if self.response_model is not None:
             try:
-                self.response_model.model_validate_json(answer)
-                formatted_answer = AgentFinish(
-                    thought="",
-                    output=answer,
-                    text=answer,
-                )
+                if isinstance(answer, BaseModel):
+                    output_json = answer.model_dump_json()
+                    formatted_answer = AgentFinish(
+                        thought="",
+                        output=answer,
+                        text=output_json,
+                    )
+                else:
+                    self.response_model.model_validate_json(answer)
+                    formatted_answer = AgentFinish(
+                        thought="",
+                        output=answer,
+                        text=answer,
+                    )
             except ValidationError:
+                # If validation fails, convert BaseModel to JSON string for parsing
+                answer_str = (
+                    answer.model_dump_json()
+                    if isinstance(answer, BaseModel)
+                    else str(answer)
+                )
                 formatted_answer = process_llm_response(
-                    answer, self.use_stop_words
+                    answer_str, self.use_stop_words
                 )  # type: ignore[assignment]
         else:
-            formatted_answer = process_llm_response(answer, self.use_stop_words)  # type: ignore[assignment]
+            # When no response_model, answer should be a string
+            answer_str = str(answer) if not isinstance(answer, str) else answer
+            formatted_answer = process_llm_response(
+                answer_str, self.use_stop_words
+            )  # type: ignore[assignment]
 
         if isinstance(formatted_answer, AgentAction):
             fingerprint_context = {}
@@ -1194,6 +1251,18 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
             self._show_logs(formatted_answer)
             return formatted_answer
 
+        if isinstance(answer, BaseModel):
+            output_json = answer.model_dump_json()
+            formatted_answer = AgentFinish(
+                thought="",
+                output=answer,
+                text=output_json,
+            )
+            self._invoke_step_callback(formatted_answer)
+            self._append_message(output_json)
+            self._show_logs(formatted_answer)
+            return formatted_answer
+
         # Unexpected response type, treat as final answer
         formatted_answer = AgentFinish(
             thought="",
@@ -1244,11 +1313,20 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
                 verbose=self.agent.verbose,
             )
 
-        formatted_answer = AgentFinish(
-            thought="",
-            output=str(answer),
-            text=str(answer),
-        )
+        if isinstance(answer, BaseModel):
+            output_json = answer.model_dump_json()
+            formatted_answer = AgentFinish(
+                thought="",
+                output=answer,
+                text=output_json,
+            )
+        else:
+            answer_str = answer if isinstance(answer, str) else str(answer)
+            formatted_answer = AgentFinish(
+                thought="",
+                output=answer_str,
+                text=answer_str,
+            )
         self._show_logs(formatted_answer)
         return formatted_answer
 
@@ -1421,7 +1499,12 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         Returns:
             Final answer after feedback.
         """
-        human_feedback = self._ask_human_input(formatted_answer.output)
+        output_str = (
+            formatted_answer.output
+            if isinstance(formatted_answer.output, str)
+            else formatted_answer.output.model_dump_json()
+        )
+        human_feedback = self._ask_human_input(output_str)
 
         if self._is_training_mode():
             return self._handle_training_feedback(formatted_answer, human_feedback)
@@ -1480,7 +1563,12 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
                 self.ask_for_human_input = False
             else:
                 answer = self._process_feedback_iteration(feedback)
-                feedback = self._ask_human_input(answer.output)
+                output_str = (
+                    answer.output
+                    if isinstance(answer.output, str)
+                    else answer.output.model_dump_json()
+                )
+                feedback = self._ask_human_input(output_str)
 
         return answer
 
@@ -8,6 +8,7 @@ AgentAction or AgentFinish objects.
 from dataclasses import dataclass
 
 from json_repair import repair_json  # type: ignore[import-untyped]
+from pydantic import BaseModel
 
 from crewai.agents.constants import (
     ACTION_INPUT_ONLY_REGEX,
@@ -40,7 +41,7 @@ class AgentFinish:
     """Represents the final answer from an agent."""
 
     thought: str
-    output: str
+    output: str | BaseModel
     text: str
 
 
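Widening output to str | BaseModel means every consumer of AgentFinish.output needs the guard seen in the human-feedback hunks above. The guard in isolation, as a sketch:

from pydantic import BaseModel


def output_as_str(output: str | BaseModel) -> str:
    # Mirrors the output_str expression used before _ask_human_input.
    return output if isinstance(output, str) else output.model_dump_json()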
@@ -23,7 +23,7 @@ if TYPE_CHECKING:
 try:
     from anthropic import Anthropic, AsyncAnthropic, transform_schema
     from anthropic.types import Message, TextBlock, ThinkingBlock, ToolUseBlock
-    from anthropic.types.beta import BetaMessage
+    from anthropic.types.beta import BetaMessage, BetaTextBlock
     import httpx
 except ImportError:
     raise ImportError(
@@ -337,6 +337,7 @@ class AnthropicCompletion(BaseLLM):
             available_functions: Available functions for tool calling
             from_task: Task that initiated the call
             from_agent: Agent that initiated the call
+            response_model: Optional response model.
 
         Returns:
             Chat completion response or tool call result
@@ -677,31 +678,31 @@ class AnthropicCompletion(BaseLLM):
         if _is_pydantic_model_class(response_model) and response.content:
             if use_native_structured_output:
                 for block in response.content:
-                    if isinstance(block, TextBlock):
-                        structured_json = block.text
+                    if isinstance(block, (TextBlock, BetaTextBlock)):
+                        structured_data = response_model.model_validate_json(block.text)
                         self._emit_call_completed_event(
-                            response=structured_json,
+                            response=structured_data.model_dump_json(),
                             call_type=LLMCallType.LLM_CALL,
                             from_task=from_task,
                             from_agent=from_agent,
                             messages=params["messages"],
                         )
-                        return structured_json
+                        return structured_data
             else:
                 for block in response.content:
                     if (
                         isinstance(block, ToolUseBlock)
                         and block.name == "structured_output"
                     ):
-                        structured_json = json.dumps(block.input)
+                        structured_data = response_model.model_validate(block.input)
                         self._emit_call_completed_event(
-                            response=structured_json,
+                            response=structured_data.model_dump_json(),
                             call_type=LLMCallType.LLM_CALL,
                             from_task=from_task,
                             from_agent=from_agent,
                             messages=params["messages"],
                         )
-                        return structured_json
+                        return structured_data
 
         # Check if Claude wants to use tools
         if response.content:
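The two Anthropic paths consume different shapes, which is why the validation calls differ: native structured output arrives as JSON text in a text block (hence model_validate_json), while the tool-based fallback arrives as an already-decoded dict in block.input (hence model_validate). A toy model illustrating the distinction:

from pydantic import BaseModel


class Out(BaseModel):
    answer: str


Out.model_validate_json('{"answer": "hi"}')  # native path: JSON string
Out.model_validate({"answer": "hi"})         # tool-use path: decoded dict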
@@ -897,28 +898,29 @@ class AnthropicCompletion(BaseLLM):
 
         if _is_pydantic_model_class(response_model):
             if use_native_structured_output:
+                structured_data = response_model.model_validate_json(full_response)
                 self._emit_call_completed_event(
-                    response=full_response,
+                    response=structured_data.model_dump_json(),
                     call_type=LLMCallType.LLM_CALL,
                     from_task=from_task,
                     from_agent=from_agent,
                     messages=params["messages"],
                 )
-                return full_response
+                return structured_data
             for block in final_message.content:
                 if (
                     isinstance(block, ToolUseBlock)
                     and block.name == "structured_output"
                 ):
-                    structured_json = json.dumps(block.input)
+                    structured_data = response_model.model_validate(block.input)
                     self._emit_call_completed_event(
-                        response=structured_json,
+                        response=structured_data.model_dump_json(),
                         call_type=LLMCallType.LLM_CALL,
                         from_task=from_task,
                         from_agent=from_agent,
                         messages=params["messages"],
                     )
-                    return structured_json
+                    return structured_data
 
         if final_message.content:
             tool_uses = [
@@ -1166,31 +1168,31 @@ class AnthropicCompletion(BaseLLM):
         if _is_pydantic_model_class(response_model) and response.content:
             if use_native_structured_output:
                 for block in response.content:
-                    if isinstance(block, TextBlock):
-                        structured_json = block.text
+                    if isinstance(block, (TextBlock, BetaTextBlock)):
+                        structured_data = response_model.model_validate_json(block.text)
                         self._emit_call_completed_event(
-                            response=structured_json,
+                            response=structured_data.model_dump_json(),
                             call_type=LLMCallType.LLM_CALL,
                             from_task=from_task,
                             from_agent=from_agent,
                             messages=params["messages"],
                         )
-                        return structured_json
+                        return structured_data
             else:
                 for block in response.content:
                     if (
                         isinstance(block, ToolUseBlock)
                         and block.name == "structured_output"
                     ):
-                        structured_json = json.dumps(block.input)
+                        structured_data = response_model.model_validate(block.input)
                         self._emit_call_completed_event(
-                            response=structured_json,
+                            response=structured_data.model_dump_json(),
                             call_type=LLMCallType.LLM_CALL,
                             from_task=from_task,
                             from_agent=from_agent,
                             messages=params["messages"],
                         )
-                        return structured_json
+                        return structured_data
 
         if response.content:
             tool_uses = [
@@ -1362,28 +1364,29 @@ class AnthropicCompletion(BaseLLM):
 
         if _is_pydantic_model_class(response_model):
             if use_native_structured_output:
+                structured_data = response_model.model_validate_json(full_response)
                 self._emit_call_completed_event(
-                    response=full_response,
+                    response=structured_data.model_dump_json(),
                     call_type=LLMCallType.LLM_CALL,
                     from_task=from_task,
                     from_agent=from_agent,
                     messages=params["messages"],
                 )
-                return full_response
+                return structured_data
             for block in final_message.content:
                 if (
                     isinstance(block, ToolUseBlock)
                     and block.name == "structured_output"
                 ):
-                    structured_json = json.dumps(block.input)
+                    structured_data = response_model.model_validate(block.input)
                     self._emit_call_completed_event(
-                        response=structured_json,
+                        response=structured_data.model_dump_json(),
                         call_type=LLMCallType.LLM_CALL,
                         from_task=from_task,
                         from_agent=from_agent,
                         messages=params["messages"],
                     )
-                    return structured_json
+                    return structured_data
 
         if final_message.content:
             tool_uses = [
@@ -557,7 +557,7 @@ class AzureCompletion(BaseLLM):
         params: AzureCompletionParams,
         from_task: Any | None = None,
         from_agent: Any | None = None,
-    ) -> str:
+    ) -> BaseModel:
         """Validate content against response model and emit completion event.
 
         Args:
@@ -568,24 +568,23 @@ class AzureCompletion(BaseLLM):
             from_agent: Agent that initiated the call
 
         Returns:
-            Validated and serialized JSON string
+            Validated Pydantic model instance
 
         Raises:
             ValueError: If validation fails
         """
         try:
             structured_data = response_model.model_validate_json(content)
-            structured_json = structured_data.model_dump_json()
 
             self._emit_call_completed_event(
-                response=structured_json,
+                response=structured_data.model_dump_json(),
                 call_type=LLMCallType.LLM_CALL,
                 from_task=from_task,
                 from_agent=from_agent,
                 messages=params["messages"],
             )
 
-            return structured_json
+            return structured_data
         except Exception as e:
             error_msg = f"Failed to validate structured output with model {response_model.__name__}: {e}"
             logging.error(error_msg)
@@ -132,6 +132,9 @@ class GeminiCompletion(BaseLLM):
         self.supports_tools = bool(
             version_match and float(version_match.group(1)) >= 1.5
         )
+        self.is_gemini_2_0 = bool(
+            version_match and float(version_match.group(1)) >= 2.0
+        )
 
     @property
     def stop(self) -> list[str]:
@@ -439,6 +442,11 @@ class GeminiCompletion(BaseLLM):
 
         Returns:
             GenerateContentConfig object for Gemini API
+
+        Note:
+            Structured output support varies by model version:
+            - Gemini 1.5 and earlier: Uses response_schema (Pydantic model)
+            - Gemini 2.0+: Uses response_json_schema (JSON Schema) with propertyOrdering
         """
         self.tools = tools
         config_params: dict[str, Any] = {}
@@ -466,9 +474,13 @@ class GeminiCompletion(BaseLLM):
         if response_model:
             config_params["response_mime_type"] = "application/json"
             schema_output = generate_model_description(response_model)
-            config_params["response_schema"] = schema_output.get("json_schema", {}).get(
-                "schema", {}
-            )
+            schema = schema_output.get("json_schema", {}).get("schema", {})
+
+            if self.is_gemini_2_0:
+                schema = self._add_property_ordering(schema)
+                config_params["response_json_schema"] = schema
+            else:
+                config_params["response_schema"] = response_model
 
         # Handle tools for supported models
         if tools and self.supports_tools:
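Roughly what the two branches hand to the SDK; the field names here are illustrative:

# Gemini 2.0+: a plain JSON Schema dict, passed as response_json_schema
# after propertyOrdering is injected (see the helper added below).
schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
}
# After _add_property_ordering: schema["propertyOrdering"] == ["name", "age"]

# Pre-2.0: the Pydantic model class itself is passed as response_schema;
# the SDK handles the conversion internally.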
@@ -632,7 +644,7 @@ class GeminiCompletion(BaseLLM):
         messages_for_event: list[LLMMessage],
         from_task: Any | None = None,
         from_agent: Any | None = None,
-    ) -> str:
+    ) -> BaseModel:
         """Validate content against response model and emit completion event.
 
         Args:
@@ -643,24 +655,23 @@ class GeminiCompletion(BaseLLM):
             from_agent: Agent that initiated the call
 
         Returns:
-            Validated and serialized JSON string
+            Validated Pydantic model instance
 
         Raises:
             ValueError: If validation fails
         """
         try:
             structured_data = response_model.model_validate_json(content)
-            structured_json = structured_data.model_dump_json()
 
             self._emit_call_completed_event(
-                response=structured_json,
+                response=structured_data.model_dump_json(),
                 call_type=LLMCallType.LLM_CALL,
                 from_task=from_task,
                 from_agent=from_agent,
                 messages=messages_for_event,
             )
 
-            return structured_json
+            return structured_data
         except Exception as e:
             error_msg = f"Failed to validate structured output with model {response_model.__name__}: {e}"
             logging.error(error_msg)
@@ -673,7 +684,7 @@ class GeminiCompletion(BaseLLM):
         response_model: type[BaseModel] | None = None,
         from_task: Any | None = None,
         from_agent: Any | None = None,
-    ) -> str:
+    ) -> str | BaseModel:
         """Finalize completion response with validation and event emission.
 
         Args:
@@ -684,7 +695,7 @@ class GeminiCompletion(BaseLLM):
             from_agent: Agent that initiated the call
 
         Returns:
-            Final response content after processing
+            Final response content after processing (str or Pydantic model if response_model provided)
         """
         messages_for_event = self._convert_contents_to_dict(contents)
 
@@ -870,7 +881,7 @@ class GeminiCompletion(BaseLLM):
         from_task: Any | None = None,
         from_agent: Any | None = None,
         response_model: type[BaseModel] | None = None,
-    ) -> str | list[dict[str, Any]]:
+    ) -> str | BaseModel | list[dict[str, Any]]:
         """Finalize streaming response with usage tracking, function execution, and events.
 
         Args:
@@ -990,7 +1001,7 @@ class GeminiCompletion(BaseLLM):
         from_task: Any | None = None,
         from_agent: Any | None = None,
         response_model: type[BaseModel] | None = None,
-    ) -> str | Any:
+    ) -> str | BaseModel | list[dict[str, Any]] | Any:
         """Handle streaming content generation."""
         full_response = ""
         function_calls: dict[int, dict[str, Any]] = {}
@@ -1190,6 +1201,36 @@ class GeminiCompletion(BaseLLM):
 
         return "".join(text_parts)
 
+    @staticmethod
+    def _add_property_ordering(schema: dict[str, Any]) -> dict[str, Any]:
+        """Add propertyOrdering to JSON schema for Gemini 2.0 compatibility.
+
+        Gemini 2.0 models require an explicit propertyOrdering list to define
+        the preferred structure of JSON objects. This recursively adds
+        propertyOrdering to all objects in the schema.
+
+        Args:
+            schema: JSON schema dictionary.
+
+        Returns:
+            Modified schema with propertyOrdering added to all objects.
+        """
+        if isinstance(schema, dict):
+            if schema.get("type") == "object" and "properties" in schema:
+                properties = schema["properties"]
+                if properties and "propertyOrdering" not in schema:
+                    schema["propertyOrdering"] = list(properties.keys())
+
+            for value in schema.values():
+                if isinstance(value, dict):
+                    GeminiCompletion._add_property_ordering(value)
+                elif isinstance(value, list):
+                    for item in value:
+                        if isinstance(item, dict):
+                            GeminiCompletion._add_property_ordering(item)
+
+        return schema
+
     @staticmethod
     def _convert_contents_to_dict(
         contents: list[types.Content],
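A quick usage sketch of the new helper. It mutates nested objects in place and returns the top-level schema; the resulting ordering follows the model's field declaration order, since Python dicts preserve insertion order:

schema = {
    "type": "object",
    "properties": {
        "user": {"type": "object", "properties": {"id": {"type": "integer"}}},
    },
}
out = GeminiCompletion._add_property_ordering(schema)
assert out["propertyOrdering"] == ["user"]
assert out["properties"]["user"]["propertyOrdering"] == ["id"]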
@@ -1570,15 +1570,14 @@ class OpenAICompletion(BaseLLM):
 
             parsed_object = parsed_response.choices[0].message.parsed
             if parsed_object:
-                structured_json = parsed_object.model_dump_json()
                 self._emit_call_completed_event(
-                    response=structured_json,
+                    response=parsed_object.model_dump_json(),
                     call_type=LLMCallType.LLM_CALL,
                     from_task=from_task,
                     from_agent=from_agent,
                     messages=params["messages"],
                 )
-                return structured_json
+                return parsed_object
 
         response: ChatCompletion = self.client.chat.completions.create(**params)
 
@@ -1692,7 +1691,7 @@ class OpenAICompletion(BaseLLM):
         from_task: Any | None = None,
         from_agent: Any | None = None,
         response_model: type[BaseModel] | None = None,
-    ) -> str:
+    ) -> str | BaseModel:
         """Handle streaming chat completion."""
         full_response = ""
         tool_calls: dict[int, dict[str, Any]] = {}
@@ -1728,15 +1727,14 @@ class OpenAICompletion(BaseLLM):
             if final_completion.choices:
                 parsed_result = final_completion.choices[0].message.parsed
                 if parsed_result:
-                    structured_json = parsed_result.model_dump_json()
                     self._emit_call_completed_event(
-                        response=structured_json,
+                        response=parsed_result.model_dump_json(),
                         call_type=LLMCallType.LLM_CALL,
                         from_task=from_task,
                         from_agent=from_agent,
                         messages=params["messages"],
                     )
-                    return structured_json
+                    return parsed_result
 
             logging.error("Failed to get parsed result from stream")
             return ""
@@ -1887,15 +1885,14 @@ class OpenAICompletion(BaseLLM):
 
             parsed_object = parsed_response.choices[0].message.parsed
             if parsed_object:
-                structured_json = parsed_object.model_dump_json()
                 self._emit_call_completed_event(
-                    response=structured_json,
+                    response=parsed_object.model_dump_json(),
                     call_type=LLMCallType.LLM_CALL,
                     from_task=from_task,
                     from_agent=from_agent,
                     messages=params["messages"],
                 )
-                return structured_json
+                return parsed_object
 
         response: ChatCompletion = await self.async_client.chat.completions.create(
             **params
@@ -2006,7 +2003,7 @@ class OpenAICompletion(BaseLLM):
         from_task: Any | None = None,
         from_agent: Any | None = None,
         response_model: type[BaseModel] | None = None,
-    ) -> str:
+    ) -> str | BaseModel:
         """Handle async streaming chat completion."""
         full_response = ""
         tool_calls: dict[int, dict[str, Any]] = {}
@@ -2044,17 +2041,16 @@ class OpenAICompletion(BaseLLM):
 
             try:
                 parsed_object = response_model.model_validate_json(accumulated_content)
-                structured_json = parsed_object.model_dump_json()
 
                 self._emit_call_completed_event(
-                    response=structured_json,
+                    response=parsed_object.model_dump_json(),
                     call_type=LLMCallType.LLM_CALL,
                     from_task=from_task,
                     from_agent=from_agent,
                     messages=params["messages"],
                 )
 
-                return structured_json
+                return parsed_object
             except Exception as e:
                 logging.error(f"Failed to parse structured output from stream: {e}")
                 self._emit_call_completed_event(
@@ -327,7 +327,7 @@ def get_llm_response(
     response_model: type[BaseModel] | None = None,
     executor_context: CrewAgentExecutor | AgentExecutor | LiteAgent | None = None,
     verbose: bool = True,
-) -> str | Any:
+) -> str | BaseModel | Any:
     """Call the LLM and return the response, handling any invalid responses.
 
     Args:
@@ -341,10 +341,11 @@ def get_llm_response(
         from_agent: Optional agent context for the LLM call.
         response_model: Optional Pydantic model for structured outputs.
         executor_context: Optional executor context for hook invocation.
+        verbose: Whether to print output.
 
     Returns:
-        The response from the LLM as a string, or tool call results if
-        native function calling is used.
+        The response from the LLM as a string, Pydantic model (when response_model is provided),
+        or tool call results if native function calling is used.
 
     Raises:
         Exception: If an error occurs.
@@ -393,7 +394,7 @@ async def aget_llm_response(
     response_model: type[BaseModel] | None = None,
     executor_context: CrewAgentExecutor | AgentExecutor | None = None,
     verbose: bool = True,
-) -> str | Any:
+) -> str | BaseModel | Any:
     """Call the LLM asynchronously and return the response.
 
     Args:
@@ -409,8 +410,8 @@ async def aget_llm_response(
         executor_context: Optional executor context for hook invocation.
 
     Returns:
-        The response from the LLM as a string, or tool call results if
-        native function calling is used.
+        The response from the LLM as a string, Pydantic model (when response_model is provided),
+        or tool call results if native function calling is used.
 
     Raises:
         Exception: If an error occurs.
@@ -986,32 +987,41 @@ def _setup_before_llm_call_hooks(
 
 def _setup_after_llm_call_hooks(
     executor_context: CrewAgentExecutor | AgentExecutor | LiteAgent | None,
-    answer: str,
+    answer: str | BaseModel,
     printer: Printer,
     verbose: bool = True,
-) -> str:
+) -> str | BaseModel:
     """Setup and invoke after_llm_call hooks for the executor context.
 
     Args:
         executor_context: The executor context to setup the hooks for.
-        answer: The LLM response string.
+        answer: The LLM response (string or Pydantic model).
         printer: Printer instance for error logging.
         verbose: Whether to print output.
 
     Returns:
-        The potentially modified response string.
+        The potentially modified response (string or Pydantic model).
     """
     if executor_context and executor_context.after_llm_call_hooks:
        from crewai.hooks.llm_hooks import LLMCallHookContext
 
        original_messages = executor_context.messages
 
-        hook_context = LLMCallHookContext(executor_context, response=answer)
+        # For Pydantic models, serialize to JSON for hooks
+        if isinstance(answer, BaseModel):
+            pydantic_answer = answer
+            hook_response: str = pydantic_answer.model_dump_json()
+            original_json: str = hook_response
+        else:
+            pydantic_answer = None
+            hook_response = str(answer)
+
+        hook_context = LLMCallHookContext(executor_context, response=hook_response)
        try:
            for hook in executor_context.after_llm_call_hooks:
                modified_response = hook(hook_context)
                if modified_response is not None and isinstance(modified_response, str):
-                    answer = modified_response
+                    hook_response = modified_response
 
        except Exception as e:
            if verbose:
@@ -1035,4 +1045,21 @@ def _setup_after_llm_call_hooks(
         else:
             executor_context.messages = []
 
+        # If hooks modified the response, update answer accordingly
+        if pydantic_answer is not None:
+            # For Pydantic models, reparse the JSON if it was modified
+            if hook_response != original_json:
+                try:
+                    model_class: type[BaseModel] = type(pydantic_answer)
+                    answer = model_class.model_validate_json(hook_response)
+                except Exception as e:
+                    if verbose:
+                        printer.print(
+                            content=f"Warning: Hook modified response but failed to reparse as {type(pydantic_answer).__name__}: {e}. Using original model.",
+                            color="yellow",
+                        )
+        else:
+            # For string responses, use the hook-modified response
+            answer = hook_response
+
     return answer
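Net effect for hook authors: after_llm_call hooks always receive a string, even for structured outputs, and if they rewrite the JSON of a Pydantic answer it is reparsed into the same model class afterwards. A sketch of a hook compatible with that round-trip, assuming the context exposes the response it was constructed with and an illustrative "greeting" field:

import json


def redact_greeting(ctx):
    # ctx.response is a JSON string here, even when the answer was a BaseModel.
    data = json.loads(ctx.response)
    data["greeting"] = "[redacted]"
    return json.dumps(data)  # reparsed via model_validate_json upstream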
@@ -62,7 +62,10 @@ class Converter(OutputConverter):
                 ],
                 response_model=self.model,
             )
-            result = self.model.model_validate_json(response)
+            if isinstance(response, BaseModel):
+                result = response
+            else:
+                result = self.model.model_validate_json(response)
         else:
             response = self.llm.call(
                 [
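The Converter now short-circuits when the LLM already returned a model instance; the isinstance branch keeps compatibility with providers that still return JSON strings. The equivalent logic as a standalone sketch:

from pydantic import BaseModel


def coerce(model_cls: type[BaseModel], response: str | BaseModel) -> BaseModel:
    # New providers return a validated instance; the legacy path returns JSON text.
    if isinstance(response, BaseModel):
        return response
    return model_cls.model_validate_json(response)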
@@ -157,10 +157,10 @@ async def test_anthropic_async_with_response_model():
         "Say hello in French",
         response_model=GreetingResponse
     )
-    model = GreetingResponse.model_validate_json(result)
-    assert isinstance(model, GreetingResponse)
-    assert isinstance(model.greeting, str)
-    assert isinstance(model.language, str)
+    # When response_model is provided, the result is already a parsed Pydantic model instance
+    assert isinstance(result, GreetingResponse)
+    assert isinstance(result.greeting, str)
+    assert isinstance(result.language, str)
 
 
 @pytest.mark.vcr()
@@ -799,3 +799,131 @@ def test_google_express_mode_works() -> None:
     assert result.token_usage.prompt_tokens > 0
     assert result.token_usage.completion_tokens > 0
     assert result.token_usage.successful_requests >= 1
+
+
+def test_gemini_2_0_model_detection():
+    """Test that Gemini 2.0 models are properly detected."""
+    # Test Gemini 2.0 models
+    llm_2_0 = LLM(model="google/gemini-2.0-flash-001")
+    from crewai.llms.providers.gemini.completion import GeminiCompletion
+    assert isinstance(llm_2_0, GeminiCompletion)
+    assert llm_2_0.is_gemini_2_0 is True
+
+    llm_2_5 = LLM(model="google/gemini-2.5-flash")
+    assert isinstance(llm_2_5, GeminiCompletion)
+    assert llm_2_5.is_gemini_2_0 is True
+
+    # Test non-2.0 models
+    llm_1_5 = LLM(model="google/gemini-1.5-pro")
+    assert isinstance(llm_1_5, GeminiCompletion)
+    assert llm_1_5.is_gemini_2_0 is False
+
+
+def test_add_property_ordering_to_schema():
+    """Test that _add_property_ordering correctly adds propertyOrdering to schemas."""
+    from crewai.llms.providers.gemini.completion import GeminiCompletion
+
+    # Test simple object schema
+    simple_schema = {
+        "type": "object",
+        "properties": {
+            "name": {"type": "string"},
+            "age": {"type": "integer"},
+            "email": {"type": "string"}
+        }
+    }
+
+    result = GeminiCompletion._add_property_ordering(simple_schema)
+
+    assert "propertyOrdering" in result
+    assert result["propertyOrdering"] == ["name", "age", "email"]
+
+    # Test nested object schema
+    nested_schema = {
+        "type": "object",
+        "properties": {
+            "user": {
+                "type": "object",
+                "properties": {
+                    "name": {"type": "string"},
+                    "contact": {
+                        "type": "object",
+                        "properties": {
+                            "email": {"type": "string"},
+                            "phone": {"type": "string"}
+                        }
+                    }
+                }
+            },
+            "id": {"type": "integer"}
+        }
+    }
+
+    result = GeminiCompletion._add_property_ordering(nested_schema)
+
+    assert "propertyOrdering" in result
+    assert result["propertyOrdering"] == ["user", "id"]
+    assert "propertyOrdering" in result["properties"]["user"]
+    assert result["properties"]["user"]["propertyOrdering"] == ["name", "contact"]
+    assert "propertyOrdering" in result["properties"]["user"]["properties"]["contact"]
+    assert result["properties"]["user"]["properties"]["contact"]["propertyOrdering"] == ["email", "phone"]
+
+
+def test_gemini_2_0_response_model_with_property_ordering():
+    """Test that Gemini 2.0 models include propertyOrdering in response schemas."""
+    from pydantic import BaseModel, Field
+
+    class TestResponse(BaseModel):
+        """Test response model."""
+        name: str = Field(..., description="The name")
+        age: int = Field(..., description="The age")
+        email: str = Field(..., description="The email")
+
+    llm = LLM(model="google/gemini-2.0-flash-001")
+
+    # Prepare generation config with response model
+    config = llm._prepare_generation_config(response_model=TestResponse)
+
+    # Verify that the config has response_json_schema
+    assert hasattr(config, 'response_json_schema') or 'response_json_schema' in config.__dict__
+
+    # Get the schema
+    if hasattr(config, 'response_json_schema'):
+        schema = config.response_json_schema
+    else:
+        schema = config.__dict__.get('response_json_schema', {})
+
+    # Verify propertyOrdering is present for Gemini 2.0
+    assert "propertyOrdering" in schema
+    assert "name" in schema["propertyOrdering"]
+    assert "age" in schema["propertyOrdering"]
+    assert "email" in schema["propertyOrdering"]
+
+
+def test_gemini_1_5_response_model_uses_response_schema():
+    """Test that Gemini 1.5 models use response_schema parameter (not response_json_schema)."""
+    from pydantic import BaseModel, Field
+
+    class TestResponse(BaseModel):
+        """Test response model."""
+        name: str = Field(..., description="The name")
+        age: int = Field(..., description="The age")
+
+    llm = LLM(model="google/gemini-1.5-pro")
+
+    # Prepare generation config with response model
+    config = llm._prepare_generation_config(response_model=TestResponse)
+
+    # Verify that the config uses response_schema (not response_json_schema)
+    assert hasattr(config, 'response_schema') or 'response_schema' in config.__dict__
+    assert not (hasattr(config, 'response_json_schema') and config.response_json_schema is not None)
+
+    # Get the schema
+    if hasattr(config, 'response_schema'):
+        schema = config.response_schema
+    else:
+        schema = config.__dict__.get('response_schema')
+
+    # For Gemini 1.5, response_schema should be the Pydantic model itself
+    # The SDK handles conversion internally
+    assert schema is TestResponse or isinstance(schema, type)
@@ -540,7 +540,9 @@ def test_openai_streaming_with_response_model():
     result = llm.call("Test question", response_model=TestResponse)
 
     assert result is not None
-    assert isinstance(result, str)
+    assert isinstance(result, TestResponse)
+    assert result.answer == "test"
+    assert result.confidence == 0.95
 
     assert mock_stream.called
     call_kwargs = mock_stream.call_args[1]