fix: improve structured output handling across providers and agents

- add Gemini 2.0 schema support using response_json_schema with propertyOrdering while retaining backward compatibility for earlier models
- refactor LLM completions to return validated Pydantic models when a response_model is provided, updating hooks, types, and tests for consistent structured outputs (a usage sketch follows this list)
- extend AgentFinish and executors to support BaseModel outputs, improve Anthropic structured parsing, and clean up schema utilities, tests, and original_json handling
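A minimal sketch of the new contract (the Movie model, prompt, and model name are illustrative assumptions, and the import path is assumed; the response_model keyword and the parsed return value are what this commit introduces):

from pydantic import BaseModel
from crewai import LLM  # import path assumed

class Movie(BaseModel):
    title: str
    year: int

llm = LLM(model="gpt-4o")
# Before this commit, call() returned a JSON string for the caller to
# re-validate; it now returns a validated Movie instance directly.
result = llm.call("Name one classic sci-fi film.", response_model=Movie)
assert isinstance(result, Movie)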
Greyson LaLonde
2026-01-28 16:59:55 -05:00
committed by GitHub
parent 1e27cf3f0f
commit a731efac8d
11 changed files with 391 additions and 103 deletions

View File

@@ -348,18 +348,36 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
# breakpoint()
if self.response_model is not None:
try:
-self.response_model.model_validate_json(answer)
-formatted_answer = AgentFinish(
-thought="",
-output=answer,
-text=answer,
-)
+if isinstance(answer, BaseModel):
+output_json = answer.model_dump_json()
+formatted_answer = AgentFinish(
+thought="",
+output=answer,
+text=output_json,
+)
+else:
+self.response_model.model_validate_json(answer)
+formatted_answer = AgentFinish(
+thought="",
+output=answer,
+text=answer,
+)
except ValidationError:
+# If validation fails, convert BaseModel to JSON string for parsing
+answer_str = (
+answer.model_dump_json()
+if isinstance(answer, BaseModel)
+else str(answer)
+)
formatted_answer = process_llm_response(
-answer, self.use_stop_words
+answer_str, self.use_stop_words
) # type: ignore[assignment]
else:
-formatted_answer = process_llm_response(answer, self.use_stop_words) # type: ignore[assignment]
+# When no response_model, answer should be a string
+answer_str = str(answer) if not isinstance(answer, str) else answer
+formatted_answer = process_llm_response(
+answer_str, self.use_stop_words
+) # type: ignore[assignment]
if isinstance(formatted_answer, AgentAction):
# Extract agent fingerprint if available
@@ -520,6 +538,18 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
self._show_logs(formatted_answer)
return formatted_answer
+if isinstance(answer, BaseModel):
+output_json = answer.model_dump_json()
+formatted_answer = AgentFinish(
+thought="",
+output=answer,
+text=output_json,
+)
+self._invoke_step_callback(formatted_answer)
+self._append_message(output_json)
+self._show_logs(formatted_answer)
+return formatted_answer
# Unexpected response type, treat as final answer
formatted_answer = AgentFinish(
thought="",
@@ -570,11 +600,20 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
verbose=self.agent.verbose,
)
-formatted_answer = AgentFinish(
-thought="",
-output=str(answer),
-text=str(answer),
-)
+if isinstance(answer, BaseModel):
+output_json = answer.model_dump_json()
+formatted_answer = AgentFinish(
+thought="",
+output=answer,
+text=output_json,
+)
+else:
+answer_str = answer if isinstance(answer, str) else str(answer)
+formatted_answer = AgentFinish(
+thought="",
+output=answer_str,
+text=answer_str,
+)
self._show_logs(formatted_answer)
return formatted_answer
@@ -1031,18 +1070,36 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
if self.response_model is not None:
try:
-self.response_model.model_validate_json(answer)
-formatted_answer = AgentFinish(
-thought="",
-output=answer,
-text=answer,
-)
+if isinstance(answer, BaseModel):
+output_json = answer.model_dump_json()
+formatted_answer = AgentFinish(
+thought="",
+output=answer,
+text=output_json,
+)
+else:
+self.response_model.model_validate_json(answer)
+formatted_answer = AgentFinish(
+thought="",
+output=answer,
+text=answer,
+)
except ValidationError:
+# If validation fails, convert BaseModel to JSON string for parsing
+answer_str = (
+answer.model_dump_json()
+if isinstance(answer, BaseModel)
+else str(answer)
+)
formatted_answer = process_llm_response(
-answer, self.use_stop_words
+answer_str, self.use_stop_words
) # type: ignore[assignment]
else:
-formatted_answer = process_llm_response(answer, self.use_stop_words) # type: ignore[assignment]
+# When no response_model, answer should be a string
+answer_str = str(answer) if not isinstance(answer, str) else answer
+formatted_answer = process_llm_response(
+answer_str, self.use_stop_words
+) # type: ignore[assignment]
if isinstance(formatted_answer, AgentAction):
fingerprint_context = {}
@@ -1194,6 +1251,18 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
self._show_logs(formatted_answer)
return formatted_answer
+if isinstance(answer, BaseModel):
+output_json = answer.model_dump_json()
+formatted_answer = AgentFinish(
+thought="",
+output=answer,
+text=output_json,
+)
+self._invoke_step_callback(formatted_answer)
+self._append_message(output_json)
+self._show_logs(formatted_answer)
+return formatted_answer
# Unexpected response type, treat as final answer
formatted_answer = AgentFinish(
thought="",
@@ -1244,11 +1313,20 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
verbose=self.agent.verbose,
)
-formatted_answer = AgentFinish(
-thought="",
-output=str(answer),
-text=str(answer),
-)
+if isinstance(answer, BaseModel):
+output_json = answer.model_dump_json()
+formatted_answer = AgentFinish(
+thought="",
+output=answer,
+text=output_json,
+)
+else:
+answer_str = answer if isinstance(answer, str) else str(answer)
+formatted_answer = AgentFinish(
+thought="",
+output=answer_str,
+text=answer_str,
+)
self._show_logs(formatted_answer)
return formatted_answer
@@ -1421,7 +1499,12 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
Returns:
Final answer after feedback.
"""
-human_feedback = self._ask_human_input(formatted_answer.output)
+output_str = (
+formatted_answer.output
+if isinstance(formatted_answer.output, str)
+else formatted_answer.output.model_dump_json()
+)
+human_feedback = self._ask_human_input(output_str)
if self._is_training_mode():
return self._handle_training_feedback(formatted_answer, human_feedback)
@@ -1480,7 +1563,12 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
self.ask_for_human_input = False
else:
answer = self._process_feedback_iteration(feedback)
-feedback = self._ask_human_input(answer.output)
+output_str = (
+answer.output
+if isinstance(answer.output, str)
+else answer.output.model_dump_json()
+)
+feedback = self._ask_human_input(output_str)
return answer

View File

@@ -8,6 +8,7 @@ AgentAction or AgentFinish objects.
from dataclasses import dataclass
from json_repair import repair_json # type: ignore[import-untyped]
+from pydantic import BaseModel
from crewai.agents.constants import (
ACTION_INPUT_ONLY_REGEX,
@@ -40,7 +41,7 @@ class AgentFinish:
"""Represents the final answer from an agent."""
thought: str
-output: str
+output: str | BaseModel
text: str
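With output widened to str | BaseModel, a finish step can carry the parsed model and its JSON form side by side. A small illustration (the Report model and field values are made up; the import path is assumed):

from pydantic import BaseModel
from crewai.agents.parser import AgentFinish  # import path assumed

class Report(BaseModel):
    summary: str

parsed = Report(summary="done")
finish = AgentFinish(
    thought="",
    output=parsed,  # parsed model for downstream consumers
    text=parsed.model_dump_json(),  # JSON string for messages and logs
)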

View File

@@ -23,7 +23,7 @@ if TYPE_CHECKING:
try:
from anthropic import Anthropic, AsyncAnthropic, transform_schema
from anthropic.types import Message, TextBlock, ThinkingBlock, ToolUseBlock
-from anthropic.types.beta import BetaMessage
+from anthropic.types.beta import BetaMessage, BetaTextBlock
import httpx
except ImportError:
raise ImportError(
@@ -337,6 +337,7 @@ class AnthropicCompletion(BaseLLM):
available_functions: Available functions for tool calling
from_task: Task that initiated the call
from_agent: Agent that initiated the call
+response_model: Optional response model.
Returns:
Chat completion response or tool call result
@@ -677,31 +678,31 @@ class AnthropicCompletion(BaseLLM):
if _is_pydantic_model_class(response_model) and response.content:
if use_native_structured_output:
for block in response.content:
-if isinstance(block, TextBlock):
-structured_json = block.text
+if isinstance(block, (TextBlock, BetaTextBlock)):
+structured_data = response_model.model_validate_json(block.text)
self._emit_call_completed_event(
-response=structured_json,
+response=structured_data.model_dump_json(),
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
-return structured_json
+return structured_data
else:
for block in response.content:
if (
isinstance(block, ToolUseBlock)
and block.name == "structured_output"
):
-structured_json = json.dumps(block.input)
+structured_data = response_model.model_validate(block.input)
self._emit_call_completed_event(
-response=structured_json,
+response=structured_data.model_dump_json(),
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
-return structured_json
+return structured_data
# Check if Claude wants to use tools
if response.content:
@@ -897,28 +898,29 @@ class AnthropicCompletion(BaseLLM):
if _is_pydantic_model_class(response_model):
if use_native_structured_output:
+structured_data = response_model.model_validate_json(full_response)
self._emit_call_completed_event(
-response=full_response,
+response=structured_data.model_dump_json(),
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
-return full_response
+return structured_data
for block in final_message.content:
if (
isinstance(block, ToolUseBlock)
and block.name == "structured_output"
):
-structured_json = json.dumps(block.input)
+structured_data = response_model.model_validate(block.input)
self._emit_call_completed_event(
-response=structured_json,
+response=structured_data.model_dump_json(),
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
-return structured_json
+return structured_data
if final_message.content:
tool_uses = [
@@ -1166,31 +1168,31 @@ class AnthropicCompletion(BaseLLM):
if _is_pydantic_model_class(response_model) and response.content:
if use_native_structured_output:
for block in response.content:
-if isinstance(block, TextBlock):
-structured_json = block.text
+if isinstance(block, (TextBlock, BetaTextBlock)):
+structured_data = response_model.model_validate_json(block.text)
self._emit_call_completed_event(
-response=structured_json,
+response=structured_data.model_dump_json(),
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
-return structured_json
+return structured_data
else:
for block in response.content:
if (
isinstance(block, ToolUseBlock)
and block.name == "structured_output"
):
-structured_json = json.dumps(block.input)
+structured_data = response_model.model_validate(block.input)
self._emit_call_completed_event(
-response=structured_json,
+response=structured_data.model_dump_json(),
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
-return structured_json
+return structured_data
if response.content:
tool_uses = [
@@ -1362,28 +1364,29 @@ class AnthropicCompletion(BaseLLM):
if _is_pydantic_model_class(response_model):
if use_native_structured_output:
+structured_data = response_model.model_validate_json(full_response)
self._emit_call_completed_event(
-response=full_response,
+response=structured_data.model_dump_json(),
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
-return full_response
+return structured_data
for block in final_message.content:
if (
isinstance(block, ToolUseBlock)
and block.name == "structured_output"
):
-structured_json = json.dumps(block.input)
+structured_data = response_model.model_validate(block.input)
self._emit_call_completed_event(
-response=structured_json,
+response=structured_data.model_dump_json(),
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
-return structured_json
+return structured_data
if final_message.content:
tool_uses = [

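Both Anthropic paths now converge on the same rule: validate first, then emit and return the model instance. A condensed sketch of that rule (parse_structured_block is a hypothetical helper, not part of this diff; the block types come from the anthropic SDK):

from pydantic import BaseModel

def parse_structured_block(block: object, response_model: type[BaseModel]) -> BaseModel:
    """Validate an Anthropic content block against a Pydantic model."""
    if hasattr(block, "text"):
        # Native structured output: TextBlock / BetaTextBlock carries JSON text.
        return response_model.model_validate_json(block.text)
    # Tool-based structured output: ToolUseBlock carries a parsed input dict.
    return response_model.model_validate(block.input)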
View File

@@ -557,7 +557,7 @@ class AzureCompletion(BaseLLM):
params: AzureCompletionParams,
from_task: Any | None = None,
from_agent: Any | None = None,
-) -> str:
+) -> BaseModel:
"""Validate content against response model and emit completion event.
Args:
@@ -568,24 +568,23 @@ class AzureCompletion(BaseLLM):
from_agent: Agent that initiated the call
Returns:
-Validated and serialized JSON string
+Validated Pydantic model instance
Raises:
ValueError: If validation fails
"""
try:
structured_data = response_model.model_validate_json(content)
-structured_json = structured_data.model_dump_json()
self._emit_call_completed_event(
-response=structured_json,
+response=structured_data.model_dump_json(),
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
-return structured_json
+return structured_data
except Exception as e:
error_msg = f"Failed to validate structured output with model {response_model.__name__}: {e}"
logging.error(error_msg)

View File

@@ -132,6 +132,9 @@ class GeminiCompletion(BaseLLM):
self.supports_tools = bool(
version_match and float(version_match.group(1)) >= 1.5
)
+self.is_gemini_2_0 = bool(
+version_match and float(version_match.group(1)) >= 2.0
+)
@property
def stop(self) -> list[str]:
@@ -439,6 +442,11 @@ class GeminiCompletion(BaseLLM):
Returns:
GenerateContentConfig object for Gemini API
+Note:
+Structured output support varies by model version:
+- Gemini 1.5 and earlier: Uses response_schema (Pydantic model)
+- Gemini 2.0+: Uses response_json_schema (JSON Schema) with propertyOrdering
"""
self.tools = tools
config_params: dict[str, Any] = {}
@@ -466,9 +474,13 @@ class GeminiCompletion(BaseLLM):
if response_model:
config_params["response_mime_type"] = "application/json"
schema_output = generate_model_description(response_model)
config_params["response_schema"] = schema_output.get("json_schema", {}).get(
"schema", {}
)
schema = schema_output.get("json_schema", {}).get("schema", {})
if self.is_gemini_2_0:
schema = self._add_property_ordering(schema)
config_params["response_json_schema"] = schema
else:
config_params["response_schema"] = response_model
# Handle tools for supported models
if tools and self.supports_tools:
@@ -632,7 +644,7 @@ class GeminiCompletion(BaseLLM):
messages_for_event: list[LLMMessage],
from_task: Any | None = None,
from_agent: Any | None = None,
-) -> str:
+) -> BaseModel:
"""Validate content against response model and emit completion event.
Args:
@@ -643,24 +655,23 @@ class GeminiCompletion(BaseLLM):
from_agent: Agent that initiated the call
Returns:
-Validated and serialized JSON string
+Validated Pydantic model instance
Raises:
ValueError: If validation fails
"""
try:
structured_data = response_model.model_validate_json(content)
-structured_json = structured_data.model_dump_json()
self._emit_call_completed_event(
-response=structured_json,
+response=structured_data.model_dump_json(),
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=messages_for_event,
)
-return structured_json
+return structured_data
except Exception as e:
error_msg = f"Failed to validate structured output with model {response_model.__name__}: {e}"
logging.error(error_msg)
@@ -673,7 +684,7 @@ class GeminiCompletion(BaseLLM):
response_model: type[BaseModel] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
-) -> str:
+) -> str | BaseModel:
"""Finalize completion response with validation and event emission.
Args:
@@ -684,7 +695,7 @@ class GeminiCompletion(BaseLLM):
from_agent: Agent that initiated the call
Returns:
-Final response content after processing
+Final response content after processing (str or Pydantic model if response_model provided)
"""
messages_for_event = self._convert_contents_to_dict(contents)
@@ -870,7 +881,7 @@ class GeminiCompletion(BaseLLM):
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
-) -> str | list[dict[str, Any]]:
+) -> str | BaseModel | list[dict[str, Any]]:
"""Finalize streaming response with usage tracking, function execution, and events.
Args:
@@ -990,7 +1001,7 @@ class GeminiCompletion(BaseLLM):
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
-) -> str | Any:
+) -> str | BaseModel | list[dict[str, Any]] | Any:
"""Handle streaming content generation."""
full_response = ""
function_calls: dict[int, dict[str, Any]] = {}
@@ -1190,6 +1201,36 @@ class GeminiCompletion(BaseLLM):
return "".join(text_parts)
+@staticmethod
+def _add_property_ordering(schema: dict[str, Any]) -> dict[str, Any]:
+"""Add propertyOrdering to JSON schema for Gemini 2.0 compatibility.
+Gemini 2.0 models require an explicit propertyOrdering list to define
+the preferred structure of JSON objects. This recursively adds
+propertyOrdering to all objects in the schema.
+Args:
+schema: JSON schema dictionary.
+Returns:
+Modified schema with propertyOrdering added to all objects.
+"""
+if isinstance(schema, dict):
+if schema.get("type") == "object" and "properties" in schema:
+properties = schema["properties"]
+if properties and "propertyOrdering" not in schema:
+schema["propertyOrdering"] = list(properties.keys())
+for value in schema.values():
+if isinstance(value, dict):
+GeminiCompletion._add_property_ordering(value)
+elif isinstance(value, list):
+for item in value:
+if isinstance(item, dict):
+GeminiCompletion._add_property_ordering(item)
+return schema
@staticmethod
def _convert_contents_to_dict(
contents: list[types.Content],

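A worked example of what the recursive pass produces (the input schema is illustrative; the behavior mirrors _add_property_ordering above):

from crewai.llms.providers.gemini.completion import GeminiCompletion

schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "age": {"type": "integer"},
    },
}
ordered = GeminiCompletion._add_property_ordering(schema)
# ordered now contains, in addition to the original keys:
#   "propertyOrdering": ["name", "age"]
# and the same treatment is applied recursively to nested object schemas.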
View File

@@ -1570,15 +1570,14 @@ class OpenAICompletion(BaseLLM):
parsed_object = parsed_response.choices[0].message.parsed
if parsed_object:
-structured_json = parsed_object.model_dump_json()
self._emit_call_completed_event(
-response=structured_json,
+response=parsed_object.model_dump_json(),
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
-return structured_json
+return parsed_object
response: ChatCompletion = self.client.chat.completions.create(**params)
@@ -1692,7 +1691,7 @@ class OpenAICompletion(BaseLLM):
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
-) -> str:
+) -> str | BaseModel:
"""Handle streaming chat completion."""
full_response = ""
tool_calls: dict[int, dict[str, Any]] = {}
@@ -1728,15 +1727,14 @@ class OpenAICompletion(BaseLLM):
if final_completion.choices:
parsed_result = final_completion.choices[0].message.parsed
if parsed_result:
-structured_json = parsed_result.model_dump_json()
self._emit_call_completed_event(
-response=structured_json,
+response=parsed_result.model_dump_json(),
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
-return structured_json
+return parsed_result
logging.error("Failed to get parsed result from stream")
return ""
@@ -1887,15 +1885,14 @@ class OpenAICompletion(BaseLLM):
parsed_object = parsed_response.choices[0].message.parsed
if parsed_object:
-structured_json = parsed_object.model_dump_json()
self._emit_call_completed_event(
-response=structured_json,
+response=parsed_object.model_dump_json(),
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
-return structured_json
+return parsed_object
response: ChatCompletion = await self.async_client.chat.completions.create(
**params
@@ -2006,7 +2003,7 @@ class OpenAICompletion(BaseLLM):
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
-) -> str:
+) -> str | BaseModel:
"""Handle async streaming chat completion."""
full_response = ""
tool_calls: dict[int, dict[str, Any]] = {}
@@ -2044,17 +2041,16 @@ class OpenAICompletion(BaseLLM):
try:
parsed_object = response_model.model_validate_json(accumulated_content)
-structured_json = parsed_object.model_dump_json()
self._emit_call_completed_event(
-response=structured_json,
+response=parsed_object.model_dump_json(),
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
-return structured_json
+return parsed_object
except Exception as e:
logging.error(f"Failed to parse structured output from stream: {e}")
self._emit_call_completed_event(

View File

@@ -327,7 +327,7 @@ def get_llm_response(
response_model: type[BaseModel] | None = None,
executor_context: CrewAgentExecutor | AgentExecutor | LiteAgent | None = None,
verbose: bool = True,
-) -> str | Any:
+) -> str | BaseModel | Any:
"""Call the LLM and return the response, handling any invalid responses.
Args:
@@ -341,10 +341,11 @@ def get_llm_response(
from_agent: Optional agent context for the LLM call.
response_model: Optional Pydantic model for structured outputs.
executor_context: Optional executor context for hook invocation.
+verbose: Whether to print output.
Returns:
-The response from the LLM as a string, or tool call results if
-native function calling is used.
+The response from the LLM as a string, Pydantic model (when response_model is provided),
+or tool call results if native function calling is used.
Raises:
Exception: If an error occurs.
@@ -393,7 +394,7 @@ async def aget_llm_response(
response_model: type[BaseModel] | None = None,
executor_context: CrewAgentExecutor | AgentExecutor | None = None,
verbose: bool = True,
-) -> str | Any:
+) -> str | BaseModel | Any:
"""Call the LLM asynchronously and return the response.
Args:
@@ -409,8 +410,8 @@ async def aget_llm_response(
executor_context: Optional executor context for hook invocation.
Returns:
-The response from the LLM as a string, or tool call results if
-native function calling is used.
+The response from the LLM as a string, Pydantic model (when response_model is provided),
+or tool call results if native function calling is used.
Raises:
Exception: If an error occurs.
@@ -986,32 +987,41 @@ def _setup_before_llm_call_hooks(
def _setup_after_llm_call_hooks(
executor_context: CrewAgentExecutor | AgentExecutor | LiteAgent | None,
-answer: str,
+answer: str | BaseModel,
printer: Printer,
verbose: bool = True,
-) -> str:
+) -> str | BaseModel:
"""Setup and invoke after_llm_call hooks for the executor context.
Args:
executor_context: The executor context to setup the hooks for.
-answer: The LLM response string.
+answer: The LLM response (string or Pydantic model).
printer: Printer instance for error logging.
verbose: Whether to print output.
Returns:
-The potentially modified response string.
+The potentially modified response (string or Pydantic model).
"""
if executor_context and executor_context.after_llm_call_hooks:
from crewai.hooks.llm_hooks import LLMCallHookContext
original_messages = executor_context.messages
-hook_context = LLMCallHookContext(executor_context, response=answer)
+# For Pydantic models, serialize to JSON for hooks
+if isinstance(answer, BaseModel):
+pydantic_answer = answer
+hook_response: str = pydantic_answer.model_dump_json()
+original_json: str = hook_response
+else:
+pydantic_answer = None
+hook_response = str(answer)
+hook_context = LLMCallHookContext(executor_context, response=hook_response)
try:
for hook in executor_context.after_llm_call_hooks:
modified_response = hook(hook_context)
if modified_response is not None and isinstance(modified_response, str):
-answer = modified_response
+hook_response = modified_response
except Exception as e:
if verbose:
@@ -1035,4 +1045,21 @@ def _setup_after_llm_call_hooks(
else:
executor_context.messages = []
+# If hooks modified the response, update answer accordingly
+if pydantic_answer is not None:
+# For Pydantic models, reparse the JSON if it was modified
+if hook_response != original_json:
+try:
+model_class: type[BaseModel] = type(pydantic_answer)
+answer = model_class.model_validate_json(hook_response)
+except Exception as e:
+if verbose:
+printer.print(
+content=f"Warning: Hook modified response but failed to reparse as {type(pydantic_answer).__name__}: {e}. Using original model.",
+color="yellow",
+)
+else:
+# For string responses, use the hook-modified response
+answer = hook_response
return answer

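The net effect for hook authors: an after_llm_call hook always receives a JSON string, even when the LLM returned a Pydantic model, and any string it returns is reparsed into the original model type. A hypothetical hook illustrating that contract (redact_hook and the internal_notes field are invented; the response attribute and the str return convention follow the code above):

import json

def redact_hook(ctx) -> str:
    # ctx.response is always a str here; for structured outputs it is the
    # model_dump_json() serialization of the Pydantic instance.
    data = json.loads(ctx.response)
    data.pop("internal_notes", None)  # hypothetical field to strip
    # Returning a str signals a modified response; it is re-validated
    # against the original model class before being returned to the caller.
    return json.dumps(data)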
View File

@@ -62,7 +62,10 @@ class Converter(OutputConverter):
],
response_model=self.model,
)
-result = self.model.model_validate_json(response)
+if isinstance(response, BaseModel):
+result = response
+else:
+result = self.model.model_validate_json(response)
else:
response = self.llm.call(
[

View File

@@ -157,10 +157,10 @@ async def test_anthropic_async_with_response_model():
"Say hello in French",
response_model=GreetingResponse
)
-model = GreetingResponse.model_validate_json(result)
-assert isinstance(model, GreetingResponse)
-assert isinstance(model.greeting, str)
-assert isinstance(model.language, str)
+# When response_model is provided, the result is already a parsed Pydantic model instance
+assert isinstance(result, GreetingResponse)
+assert isinstance(result.greeting, str)
+assert isinstance(result.language, str)
@pytest.mark.vcr()

View File

@@ -799,3 +799,131 @@ def test_google_express_mode_works() -> None:
assert result.token_usage.prompt_tokens > 0
assert result.token_usage.completion_tokens > 0
assert result.token_usage.successful_requests >= 1
def test_gemini_2_0_model_detection():
"""Test that Gemini 2.0 models are properly detected."""
# Test Gemini 2.0 models
llm_2_0 = LLM(model="google/gemini-2.0-flash-001")
from crewai.llms.providers.gemini.completion import GeminiCompletion
assert isinstance(llm_2_0, GeminiCompletion)
assert llm_2_0.is_gemini_2_0 is True
llm_2_5 = LLM(model="google/gemini-2.5-flash")
assert isinstance(llm_2_5, GeminiCompletion)
assert llm_2_5.is_gemini_2_0 is True
# Test non-2.0 models
llm_1_5 = LLM(model="google/gemini-1.5-pro")
assert isinstance(llm_1_5, GeminiCompletion)
assert llm_1_5.is_gemini_2_0 is False
def test_add_property_ordering_to_schema():
"""Test that _add_property_ordering correctly adds propertyOrdering to schemas."""
from crewai.llms.providers.gemini.completion import GeminiCompletion
# Test simple object schema
simple_schema = {
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"},
"email": {"type": "string"}
}
}
result = GeminiCompletion._add_property_ordering(simple_schema)
assert "propertyOrdering" in result
assert result["propertyOrdering"] == ["name", "age", "email"]
# Test nested object schema
nested_schema = {
"type": "object",
"properties": {
"user": {
"type": "object",
"properties": {
"name": {"type": "string"},
"contact": {
"type": "object",
"properties": {
"email": {"type": "string"},
"phone": {"type": "string"}
}
}
}
},
"id": {"type": "integer"}
}
}
result = GeminiCompletion._add_property_ordering(nested_schema)
assert "propertyOrdering" in result
assert result["propertyOrdering"] == ["user", "id"]
assert "propertyOrdering" in result["properties"]["user"]
assert result["properties"]["user"]["propertyOrdering"] == ["name", "contact"]
assert "propertyOrdering" in result["properties"]["user"]["properties"]["contact"]
assert result["properties"]["user"]["properties"]["contact"]["propertyOrdering"] == ["email", "phone"]
def test_gemini_2_0_response_model_with_property_ordering():
"""Test that Gemini 2.0 models include propertyOrdering in response schemas."""
from pydantic import BaseModel, Field
class TestResponse(BaseModel):
"""Test response model."""
name: str = Field(..., description="The name")
age: int = Field(..., description="The age")
email: str = Field(..., description="The email")
llm = LLM(model="google/gemini-2.0-flash-001")
# Prepare generation config with response model
config = llm._prepare_generation_config(response_model=TestResponse)
# Verify that the config has response_json_schema
assert hasattr(config, 'response_json_schema') or 'response_json_schema' in config.__dict__
# Get the schema
if hasattr(config, 'response_json_schema'):
schema = config.response_json_schema
else:
schema = config.__dict__.get('response_json_schema', {})
# Verify propertyOrdering is present for Gemini 2.0
assert "propertyOrdering" in schema
assert "name" in schema["propertyOrdering"]
assert "age" in schema["propertyOrdering"]
assert "email" in schema["propertyOrdering"]
def test_gemini_1_5_response_model_uses_response_schema():
"""Test that Gemini 1.5 models use response_schema parameter (not response_json_schema)."""
from pydantic import BaseModel, Field
class TestResponse(BaseModel):
"""Test response model."""
name: str = Field(..., description="The name")
age: int = Field(..., description="The age")
llm = LLM(model="google/gemini-1.5-pro")
# Prepare generation config with response model
config = llm._prepare_generation_config(response_model=TestResponse)
# Verify that the config uses response_schema (not response_json_schema)
assert hasattr(config, 'response_schema') or 'response_schema' in config.__dict__
assert not (hasattr(config, 'response_json_schema') and config.response_json_schema is not None)
# Get the schema
if hasattr(config, 'response_schema'):
schema = config.response_schema
else:
schema = config.__dict__.get('response_schema')
# For Gemini 1.5, response_schema should be the Pydantic model itself
# The SDK handles conversion internally
assert schema is TestResponse or isinstance(schema, type)

View File

@@ -540,7 +540,9 @@ def test_openai_streaming_with_response_model():
result = llm.call("Test question", response_model=TestResponse)
assert result is not None
-assert isinstance(result, str)
+assert isinstance(result, TestResponse)
+assert result.answer == "test"
+assert result.confidence == 0.95
assert mock_stream.called
call_kwargs = mock_stream.call_args[1]