Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-29 18:18:13 +00:00
fix: improve structured output handling across providers and agents
- Add Gemini 2.0 schema support using response_json_schema with propertyOrdering, while retaining backward compatibility for earlier models
- Refactor LLM completions to return validated Pydantic models when a response_model is provided, updating hooks, types, and tests for consistent structured outputs
- Extend AgentFinish and executors to support BaseModel outputs, improve Anthropic structured parsing, and clean up schema utilities, tests, and original_json handling
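Note: in practice the new contract looks like the sketch below. A minimal illustration, assuming a Pydantic response model; the model string is hypothetical, and previously callers reparsed the returned JSON string themselves:

    from pydantic import BaseModel

    from crewai import LLM


    class Greeting(BaseModel):
        greeting: str
        language: str


    # Hypothetical model string; any supported provider works the same way.
    llm = LLM(model="gpt-4o-mini")

    # Before this commit the call returned a JSON string to reparse;
    # now it returns a validated Greeting instance directly.
    result = llm.call("Say hello in French", response_model=Greeting)
    assert isinstance(result, Greeting)
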
@@ -348,18 +348,36 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         # breakpoint()
         if self.response_model is not None:
             try:
-                self.response_model.model_validate_json(answer)
-                formatted_answer = AgentFinish(
-                    thought="",
-                    output=answer,
-                    text=answer,
-                )
+                if isinstance(answer, BaseModel):
+                    output_json = answer.model_dump_json()
+                    formatted_answer = AgentFinish(
+                        thought="",
+                        output=answer,
+                        text=output_json,
+                    )
+                else:
+                    self.response_model.model_validate_json(answer)
+                    formatted_answer = AgentFinish(
+                        thought="",
+                        output=answer,
+                        text=answer,
+                    )
             except ValidationError:
+                # If validation fails, convert BaseModel to JSON string for parsing
+                answer_str = (
+                    answer.model_dump_json()
+                    if isinstance(answer, BaseModel)
+                    else str(answer)
+                )
                 formatted_answer = process_llm_response(
-                    answer, self.use_stop_words
+                    answer_str, self.use_stop_words
                 )  # type: ignore[assignment]
         else:
-            formatted_answer = process_llm_response(answer, self.use_stop_words)  # type: ignore[assignment]
+            # When no response_model, answer should be a string
+            answer_str = str(answer) if not isinstance(answer, str) else answer
+            formatted_answer = process_llm_response(
+                answer_str, self.use_stop_words
+            )  # type: ignore[assignment]
 
         if isinstance(formatted_answer, AgentAction):
             # Extract agent fingerprint if available
@@ -520,6 +538,18 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
             self._show_logs(formatted_answer)
             return formatted_answer
 
+        if isinstance(answer, BaseModel):
+            output_json = answer.model_dump_json()
+            formatted_answer = AgentFinish(
+                thought="",
+                output=answer,
+                text=output_json,
+            )
+            self._invoke_step_callback(formatted_answer)
+            self._append_message(output_json)
+            self._show_logs(formatted_answer)
+            return formatted_answer
+
         # Unexpected response type, treat as final answer
         formatted_answer = AgentFinish(
             thought="",
@@ -570,11 +600,20 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
                 verbose=self.agent.verbose,
             )
 
-        formatted_answer = AgentFinish(
-            thought="",
-            output=str(answer),
-            text=str(answer),
-        )
+        if isinstance(answer, BaseModel):
+            output_json = answer.model_dump_json()
+            formatted_answer = AgentFinish(
+                thought="",
+                output=answer,
+                text=output_json,
+            )
+        else:
+            answer_str = answer if isinstance(answer, str) else str(answer)
+            formatted_answer = AgentFinish(
+                thought="",
+                output=answer_str,
+                text=answer_str,
+            )
         self._show_logs(formatted_answer)
         return formatted_answer
 
@@ -1031,18 +1070,36 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
 
         if self.response_model is not None:
             try:
-                self.response_model.model_validate_json(answer)
-                formatted_answer = AgentFinish(
-                    thought="",
-                    output=answer,
-                    text=answer,
-                )
+                if isinstance(answer, BaseModel):
+                    output_json = answer.model_dump_json()
+                    formatted_answer = AgentFinish(
+                        thought="",
+                        output=answer,
+                        text=output_json,
+                    )
+                else:
+                    self.response_model.model_validate_json(answer)
+                    formatted_answer = AgentFinish(
+                        thought="",
+                        output=answer,
+                        text=answer,
+                    )
             except ValidationError:
+                # If validation fails, convert BaseModel to JSON string for parsing
+                answer_str = (
+                    answer.model_dump_json()
+                    if isinstance(answer, BaseModel)
+                    else str(answer)
+                )
                 formatted_answer = process_llm_response(
-                    answer, self.use_stop_words
+                    answer_str, self.use_stop_words
                 )  # type: ignore[assignment]
         else:
-            formatted_answer = process_llm_response(answer, self.use_stop_words)  # type: ignore[assignment]
+            # When no response_model, answer should be a string
+            answer_str = str(answer) if not isinstance(answer, str) else answer
+            formatted_answer = process_llm_response(
+                answer_str, self.use_stop_words
+            )  # type: ignore[assignment]
 
         if isinstance(formatted_answer, AgentAction):
             fingerprint_context = {}
@@ -1194,6 +1251,18 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
             self._show_logs(formatted_answer)
             return formatted_answer
 
+        if isinstance(answer, BaseModel):
+            output_json = answer.model_dump_json()
+            formatted_answer = AgentFinish(
+                thought="",
+                output=answer,
+                text=output_json,
+            )
+            self._invoke_step_callback(formatted_answer)
+            self._append_message(output_json)
+            self._show_logs(formatted_answer)
+            return formatted_answer
+
         # Unexpected response type, treat as final answer
         formatted_answer = AgentFinish(
             thought="",
@@ -1244,11 +1313,20 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
                 verbose=self.agent.verbose,
             )
 
-        formatted_answer = AgentFinish(
-            thought="",
-            output=str(answer),
-            text=str(answer),
-        )
+        if isinstance(answer, BaseModel):
+            output_json = answer.model_dump_json()
+            formatted_answer = AgentFinish(
+                thought="",
+                output=answer,
+                text=output_json,
+            )
+        else:
+            answer_str = answer if isinstance(answer, str) else str(answer)
+            formatted_answer = AgentFinish(
+                thought="",
+                output=answer_str,
+                text=answer_str,
+            )
         self._show_logs(formatted_answer)
         return formatted_answer
 
@@ -1421,7 +1499,12 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         Returns:
             Final answer after feedback.
         """
-        human_feedback = self._ask_human_input(formatted_answer.output)
+        output_str = (
+            formatted_answer.output
+            if isinstance(formatted_answer.output, str)
+            else formatted_answer.output.model_dump_json()
+        )
+        human_feedback = self._ask_human_input(output_str)
 
         if self._is_training_mode():
             return self._handle_training_feedback(formatted_answer, human_feedback)
@@ -1480,7 +1563,12 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
                 self.ask_for_human_input = False
             else:
                 answer = self._process_feedback_iteration(feedback)
-                feedback = self._ask_human_input(answer.output)
+                output_str = (
+                    answer.output
+                    if isinstance(answer.output, str)
+                    else answer.output.model_dump_json()
+                )
+                feedback = self._ask_human_input(output_str)
 
         return answer
 
@@ -8,6 +8,7 @@ AgentAction or AgentFinish objects.
 from dataclasses import dataclass
 
 from json_repair import repair_json  # type: ignore[import-untyped]
+from pydantic import BaseModel
 
 from crewai.agents.constants import (
     ACTION_INPUT_ONLY_REGEX,
@@ -40,7 +41,7 @@ class AgentFinish:
     """Represents the final answer from an agent."""
 
     thought: str
-    output: str
+    output: str | BaseModel
     text: str
 
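Note: with output widened to str | BaseModel, consumers of AgentFinish can no longer assume a string. A minimal sketch of the new shape, assuming AgentFinish is importable from crewai.agents.parser (Report is a hypothetical model):

    from pydantic import BaseModel

    from crewai.agents.parser import AgentFinish


    class Report(BaseModel):  # hypothetical response model
        title: str
        score: float


    report = Report(title="Q1", score=0.9)
    # Structured results keep the model in `output` and its JSON in `text`.
    finish = AgentFinish(thought="", output=report, text=report.model_dump_json())
    assert isinstance(finish.output, BaseModel)
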
@@ -23,7 +23,7 @@ if TYPE_CHECKING:
 try:
     from anthropic import Anthropic, AsyncAnthropic, transform_schema
     from anthropic.types import Message, TextBlock, ThinkingBlock, ToolUseBlock
-    from anthropic.types.beta import BetaMessage
+    from anthropic.types.beta import BetaMessage, BetaTextBlock
     import httpx
 except ImportError:
     raise ImportError(
@@ -337,6 +337,7 @@ class AnthropicCompletion(BaseLLM):
             available_functions: Available functions for tool calling
             from_task: Task that initiated the call
             from_agent: Agent that initiated the call
+            response_model: Optional response model.
 
         Returns:
             Chat completion response or tool call result
@@ -677,31 +678,31 @@ class AnthropicCompletion(BaseLLM):
         if _is_pydantic_model_class(response_model) and response.content:
             if use_native_structured_output:
                 for block in response.content:
-                    if isinstance(block, TextBlock):
-                        structured_json = block.text
+                    if isinstance(block, (TextBlock, BetaTextBlock)):
+                        structured_data = response_model.model_validate_json(block.text)
                         self._emit_call_completed_event(
-                            response=structured_json,
+                            response=structured_data.model_dump_json(),
                             call_type=LLMCallType.LLM_CALL,
                             from_task=from_task,
                             from_agent=from_agent,
                             messages=params["messages"],
                         )
-                        return structured_json
+                        return structured_data
             else:
                 for block in response.content:
                     if (
                         isinstance(block, ToolUseBlock)
                         and block.name == "structured_output"
                     ):
-                        structured_json = json.dumps(block.input)
+                        structured_data = response_model.model_validate(block.input)
                         self._emit_call_completed_event(
-                            response=structured_json,
+                            response=structured_data.model_dump_json(),
                             call_type=LLMCallType.LLM_CALL,
                             from_task=from_task,
                             from_agent=from_agent,
                             messages=params["messages"],
                         )
-                        return structured_json
+                        return structured_data
 
         # Check if Claude wants to use tools
         if response.content:
@@ -897,28 +898,29 @@ class AnthropicCompletion(BaseLLM):
 
         if _is_pydantic_model_class(response_model):
             if use_native_structured_output:
+                structured_data = response_model.model_validate_json(full_response)
                 self._emit_call_completed_event(
-                    response=full_response,
+                    response=structured_data.model_dump_json(),
                     call_type=LLMCallType.LLM_CALL,
                     from_task=from_task,
                     from_agent=from_agent,
                     messages=params["messages"],
                 )
-                return full_response
+                return structured_data
             for block in final_message.content:
                 if (
                     isinstance(block, ToolUseBlock)
                     and block.name == "structured_output"
                 ):
-                    structured_json = json.dumps(block.input)
+                    structured_data = response_model.model_validate(block.input)
                     self._emit_call_completed_event(
-                        response=structured_json,
+                        response=structured_data.model_dump_json(),
                         call_type=LLMCallType.LLM_CALL,
                         from_task=from_task,
                        from_agent=from_agent,
                         messages=params["messages"],
                     )
-                    return structured_json
+                    return structured_data
 
         if final_message.content:
             tool_uses = [
@@ -1166,31 +1168,31 @@ class AnthropicCompletion(BaseLLM):
         if _is_pydantic_model_class(response_model) and response.content:
             if use_native_structured_output:
                 for block in response.content:
-                    if isinstance(block, TextBlock):
-                        structured_json = block.text
+                    if isinstance(block, (TextBlock, BetaTextBlock)):
+                        structured_data = response_model.model_validate_json(block.text)
                         self._emit_call_completed_event(
-                            response=structured_json,
+                            response=structured_data.model_dump_json(),
                             call_type=LLMCallType.LLM_CALL,
                             from_task=from_task,
                             from_agent=from_agent,
                             messages=params["messages"],
                         )
-                        return structured_json
+                        return structured_data
             else:
                 for block in response.content:
                     if (
                         isinstance(block, ToolUseBlock)
                         and block.name == "structured_output"
                     ):
-                        structured_json = json.dumps(block.input)
+                        structured_data = response_model.model_validate(block.input)
                         self._emit_call_completed_event(
-                            response=structured_json,
+                            response=structured_data.model_dump_json(),
                             call_type=LLMCallType.LLM_CALL,
                             from_task=from_task,
                             from_agent=from_agent,
                             messages=params["messages"],
                         )
-                        return structured_json
+                        return structured_data
 
         if response.content:
             tool_uses = [
@@ -1362,28 +1364,29 @@ class AnthropicCompletion(BaseLLM):
 
         if _is_pydantic_model_class(response_model):
             if use_native_structured_output:
+                structured_data = response_model.model_validate_json(full_response)
                 self._emit_call_completed_event(
-                    response=full_response,
+                    response=structured_data.model_dump_json(),
                     call_type=LLMCallType.LLM_CALL,
                     from_task=from_task,
                     from_agent=from_agent,
                     messages=params["messages"],
                 )
-                return full_response
+                return structured_data
             for block in final_message.content:
                 if (
                     isinstance(block, ToolUseBlock)
                     and block.name == "structured_output"
                 ):
-                    structured_json = json.dumps(block.input)
+                    structured_data = response_model.model_validate(block.input)
                    self._emit_call_completed_event(
-                        response=structured_json,
+                        response=structured_data.model_dump_json(),
                         call_type=LLMCallType.LLM_CALL,
                         from_task=from_task,
                         from_agent=from_agent,
                         messages=params["messages"],
                     )
-                    return structured_json
+                    return structured_data
 
         if final_message.content:
             tool_uses = [
 
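Note: the two Anthropic paths above validate differently. The native path parses JSON text from a text block with model_validate_json, while the tool-use fallback validates the structured_output tool call's input dict with model_validate. A standalone sketch of the distinction (Greeting is hypothetical):

    from pydantic import BaseModel


    class Greeting(BaseModel):  # hypothetical response model
        greeting: str
        language: str


    # Native structured output: the model emits JSON text.
    from_text = Greeting.model_validate_json('{"greeting": "Bonjour", "language": "French"}')

    # Tool-use fallback: the "structured_output" tool call carries a dict.
    from_tool = Greeting.model_validate({"greeting": "Hola", "language": "Spanish"})

    assert from_text.language == "French" and from_tool.language == "Spanish"
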
@@ -557,7 +557,7 @@ class AzureCompletion(BaseLLM):
         params: AzureCompletionParams,
         from_task: Any | None = None,
         from_agent: Any | None = None,
-    ) -> str:
+    ) -> BaseModel:
         """Validate content against response model and emit completion event.
 
         Args:
@@ -568,24 +568,23 @@ class AzureCompletion(BaseLLM):
             from_agent: Agent that initiated the call
 
         Returns:
-            Validated and serialized JSON string
+            Validated Pydantic model instance
 
         Raises:
             ValueError: If validation fails
         """
         try:
             structured_data = response_model.model_validate_json(content)
-            structured_json = structured_data.model_dump_json()
 
             self._emit_call_completed_event(
-                response=structured_json,
+                response=structured_data.model_dump_json(),
                 call_type=LLMCallType.LLM_CALL,
                 from_task=from_task,
                 from_agent=from_agent,
                 messages=params["messages"],
             )
 
-            return structured_json
+            return structured_data
         except Exception as e:
             error_msg = f"Failed to validate structured output with model {response_model.__name__}: {e}"
             logging.error(error_msg)
 
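Note: the Azure helper now returns the validated model itself and raises on failure; callers that need JSON serialize at the edge. A free-standing sketch of that validate-and-return pattern (names are illustrative, not the class's API):

    import logging

    from pydantic import BaseModel, ValidationError


    class Answer(BaseModel):  # hypothetical response model
        text: str


    def validate_structured_output(content: str, response_model: type[BaseModel]) -> BaseModel:
        """Sketch of the pattern: validate raw LLM content, return the model instance."""
        try:
            return response_model.model_validate_json(content)
        except ValidationError as e:
            msg = f"Failed to validate structured output with model {response_model.__name__}: {e}"
            logging.error(msg)
            raise ValueError(msg) from e


    assert isinstance(validate_structured_output('{"text": "hi"}', Answer), Answer)
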
@@ -132,6 +132,9 @@ class GeminiCompletion(BaseLLM):
         self.supports_tools = bool(
             version_match and float(version_match.group(1)) >= 1.5
         )
+        self.is_gemini_2_0 = bool(
+            version_match and float(version_match.group(1)) >= 2.0
+        )
 
     @property
     def stop(self) -> list[str]:
@@ -439,6 +442,11 @@ class GeminiCompletion(BaseLLM):
 
         Returns:
             GenerateContentConfig object for Gemini API
+
+        Note:
+            Structured output support varies by model version:
+            - Gemini 1.5 and earlier: Uses response_schema (Pydantic model)
+            - Gemini 2.0+: Uses response_json_schema (JSON Schema) with propertyOrdering
         """
         self.tools = tools
         config_params: dict[str, Any] = {}
@@ -466,9 +474,13 @@ class GeminiCompletion(BaseLLM):
         if response_model:
             config_params["response_mime_type"] = "application/json"
             schema_output = generate_model_description(response_model)
-            config_params["response_schema"] = schema_output.get("json_schema", {}).get(
-                "schema", {}
-            )
+            schema = schema_output.get("json_schema", {}).get("schema", {})
+
+            if self.is_gemini_2_0:
+                schema = self._add_property_ordering(schema)
+                config_params["response_json_schema"] = schema
+            else:
+                config_params["response_schema"] = response_model
 
         # Handle tools for supported models
         if tools and self.supports_tools:
@@ -632,7 +644,7 @@ class GeminiCompletion(BaseLLM):
         messages_for_event: list[LLMMessage],
         from_task: Any | None = None,
         from_agent: Any | None = None,
-    ) -> str:
+    ) -> BaseModel:
         """Validate content against response model and emit completion event.
 
         Args:
@@ -643,24 +655,23 @@ class GeminiCompletion(BaseLLM):
             from_agent: Agent that initiated the call
 
         Returns:
-            Validated and serialized JSON string
+            Validated Pydantic model instance
 
         Raises:
             ValueError: If validation fails
         """
         try:
             structured_data = response_model.model_validate_json(content)
-            structured_json = structured_data.model_dump_json()
 
             self._emit_call_completed_event(
-                response=structured_json,
+                response=structured_data.model_dump_json(),
                 call_type=LLMCallType.LLM_CALL,
                 from_task=from_task,
                 from_agent=from_agent,
                 messages=messages_for_event,
             )
 
-            return structured_json
+            return structured_data
         except Exception as e:
             error_msg = f"Failed to validate structured output with model {response_model.__name__}: {e}"
             logging.error(error_msg)
@@ -673,7 +684,7 @@ class GeminiCompletion(BaseLLM):
         response_model: type[BaseModel] | None = None,
         from_task: Any | None = None,
         from_agent: Any | None = None,
-    ) -> str:
+    ) -> str | BaseModel:
         """Finalize completion response with validation and event emission.
 
         Args:
@@ -684,7 +695,7 @@ class GeminiCompletion(BaseLLM):
             from_agent: Agent that initiated the call
 
         Returns:
-            Final response content after processing
+            Final response content after processing (str or Pydantic model if response_model provided)
         """
         messages_for_event = self._convert_contents_to_dict(contents)
 
@@ -870,7 +881,7 @@ class GeminiCompletion(BaseLLM):
         from_task: Any | None = None,
         from_agent: Any | None = None,
         response_model: type[BaseModel] | None = None,
-    ) -> str | list[dict[str, Any]]:
+    ) -> str | BaseModel | list[dict[str, Any]]:
         """Finalize streaming response with usage tracking, function execution, and events.
 
         Args:
@@ -990,7 +1001,7 @@ class GeminiCompletion(BaseLLM):
         from_task: Any | None = None,
         from_agent: Any | None = None,
         response_model: type[BaseModel] | None = None,
-    ) -> str | Any:
+    ) -> str | BaseModel | list[dict[str, Any]] | Any:
         """Handle streaming content generation."""
         full_response = ""
         function_calls: dict[int, dict[str, Any]] = {}
@@ -1190,6 +1201,36 @@ class GeminiCompletion(BaseLLM):
 
         return "".join(text_parts)
 
+    @staticmethod
+    def _add_property_ordering(schema: dict[str, Any]) -> dict[str, Any]:
+        """Add propertyOrdering to JSON schema for Gemini 2.0 compatibility.
+
+        Gemini 2.0 models require an explicit propertyOrdering list to define
+        the preferred structure of JSON objects. This recursively adds
+        propertyOrdering to all objects in the schema.
+
+        Args:
+            schema: JSON schema dictionary.
+
+        Returns:
+            Modified schema with propertyOrdering added to all objects.
+        """
+        if isinstance(schema, dict):
+            if schema.get("type") == "object" and "properties" in schema:
+                properties = schema["properties"]
+                if properties and "propertyOrdering" not in schema:
+                    schema["propertyOrdering"] = list(properties.keys())
+
+            for value in schema.values():
+                if isinstance(value, dict):
+                    GeminiCompletion._add_property_ordering(value)
+                elif isinstance(value, list):
+                    for item in value:
+                        if isinstance(item, dict):
+                            GeminiCompletion._add_property_ordering(item)
+
+        return schema
+
     @staticmethod
     def _convert_contents_to_dict(
         contents: list[types.Content],
 
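Note: the version gate above yields two different config shapes. A rough sketch with plain dicts standing in for the SDK's GenerateContentConfig (Person is a hypothetical model):

    from typing import Any

    from pydantic import BaseModel


    class Person(BaseModel):  # hypothetical response model
        name: str
        age: int


    def sketch_config(is_gemini_2_0: bool) -> dict[str, Any]:
        """Illustrative only: mirrors the branch in _prepare_generation_config."""
        config: dict[str, Any] = {"response_mime_type": "application/json"}
        if is_gemini_2_0:
            schema = Person.model_json_schema()
            schema["propertyOrdering"] = list(schema["properties"].keys())
            config["response_json_schema"] = schema
        else:
            # Older models take the Pydantic class; the SDK converts internally.
            config["response_schema"] = Person
        return config


    assert sketch_config(True)["response_json_schema"]["propertyOrdering"] == ["name", "age"]
    assert sketch_config(False)["response_schema"] is Person
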
@@ -1570,15 +1570,14 @@ class OpenAICompletion(BaseLLM):
 
             parsed_object = parsed_response.choices[0].message.parsed
             if parsed_object:
-                structured_json = parsed_object.model_dump_json()
                 self._emit_call_completed_event(
-                    response=structured_json,
+                    response=parsed_object.model_dump_json(),
                     call_type=LLMCallType.LLM_CALL,
                     from_task=from_task,
                     from_agent=from_agent,
                     messages=params["messages"],
                 )
-                return structured_json
+                return parsed_object
 
         response: ChatCompletion = self.client.chat.completions.create(**params)
 
@@ -1692,7 +1691,7 @@ class OpenAICompletion(BaseLLM):
         from_task: Any | None = None,
         from_agent: Any | None = None,
         response_model: type[BaseModel] | None = None,
-    ) -> str:
+    ) -> str | BaseModel:
         """Handle streaming chat completion."""
         full_response = ""
         tool_calls: dict[int, dict[str, Any]] = {}
@@ -1728,15 +1727,14 @@ class OpenAICompletion(BaseLLM):
                 if final_completion.choices:
                     parsed_result = final_completion.choices[0].message.parsed
                     if parsed_result:
-                        structured_json = parsed_result.model_dump_json()
                         self._emit_call_completed_event(
-                            response=structured_json,
+                            response=parsed_result.model_dump_json(),
                             call_type=LLMCallType.LLM_CALL,
                             from_task=from_task,
                             from_agent=from_agent,
                             messages=params["messages"],
                         )
-                        return structured_json
+                        return parsed_result
 
                 logging.error("Failed to get parsed result from stream")
                 return ""
@@ -1887,15 +1885,14 @@ class OpenAICompletion(BaseLLM):
 
             parsed_object = parsed_response.choices[0].message.parsed
             if parsed_object:
-                structured_json = parsed_object.model_dump_json()
                 self._emit_call_completed_event(
-                    response=structured_json,
+                    response=parsed_object.model_dump_json(),
                     call_type=LLMCallType.LLM_CALL,
                     from_task=from_task,
                     from_agent=from_agent,
                     messages=params["messages"],
                 )
-                return structured_json
+                return parsed_object
 
         response: ChatCompletion = await self.async_client.chat.completions.create(
             **params
@@ -2006,7 +2003,7 @@ class OpenAICompletion(BaseLLM):
         from_task: Any | None = None,
         from_agent: Any | None = None,
         response_model: type[BaseModel] | None = None,
-    ) -> str:
+    ) -> str | BaseModel:
         """Handle async streaming chat completion."""
         full_response = ""
         tool_calls: dict[int, dict[str, Any]] = {}
@@ -2044,17 +2041,16 @@ class OpenAICompletion(BaseLLM):
 
             try:
                 parsed_object = response_model.model_validate_json(accumulated_content)
-                structured_json = parsed_object.model_dump_json()
 
                 self._emit_call_completed_event(
-                    response=structured_json,
+                    response=parsed_object.model_dump_json(),
                     call_type=LLMCallType.LLM_CALL,
                     from_task=from_task,
                     from_agent=from_agent,
                     messages=params["messages"],
                 )
 
-                return structured_json
+                return parsed_object
             except Exception as e:
                 logging.error(f"Failed to parse structured output from stream: {e}")
                 self._emit_call_completed_event(
 
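Note: both streaming paths accumulate chunks and validate once the stream ends. A toy sketch of that accumulate-then-validate step:

    from pydantic import BaseModel


    class Answer(BaseModel):  # hypothetical response model
        answer: str
        confidence: float


    # Chunks as a streaming API might deliver them.
    chunks = ['{"answer": "te', 'st", "confi', 'dence": 0.95}']

    parsed = Answer.model_validate_json("".join(chunks))
    assert parsed.answer == "test" and parsed.confidence == 0.95
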
@@ -327,7 +327,7 @@ def get_llm_response(
     response_model: type[BaseModel] | None = None,
     executor_context: CrewAgentExecutor | AgentExecutor | LiteAgent | None = None,
     verbose: bool = True,
-) -> str | Any:
+) -> str | BaseModel | Any:
     """Call the LLM and return the response, handling any invalid responses.
 
     Args:
@@ -341,10 +341,11 @@ def get_llm_response(
         from_agent: Optional agent context for the LLM call.
         response_model: Optional Pydantic model for structured outputs.
         executor_context: Optional executor context for hook invocation.
+        verbose: Whether to print output.
 
     Returns:
-        The response from the LLM as a string, or tool call results if
-        native function calling is used.
+        The response from the LLM as a string, Pydantic model (when response_model is provided),
+        or tool call results if native function calling is used.
 
     Raises:
         Exception: If an error occurs.
@@ -393,7 +394,7 @@ async def aget_llm_response(
     response_model: type[BaseModel] | None = None,
     executor_context: CrewAgentExecutor | AgentExecutor | None = None,
     verbose: bool = True,
-) -> str | Any:
+) -> str | BaseModel | Any:
     """Call the LLM asynchronously and return the response.
 
     Args:
@@ -409,8 +410,8 @@ async def aget_llm_response(
         executor_context: Optional executor context for hook invocation.
 
     Returns:
-        The response from the LLM as a string, or tool call results if
-        native function calling is used.
+        The response from the LLM as a string, Pydantic model (when response_model is provided),
+        or tool call results if native function calling is used.
 
     Raises:
         Exception: If an error occurs.
@@ -986,32 +987,41 @@ def _setup_before_llm_call_hooks(
 
 def _setup_after_llm_call_hooks(
     executor_context: CrewAgentExecutor | AgentExecutor | LiteAgent | None,
-    answer: str,
+    answer: str | BaseModel,
     printer: Printer,
     verbose: bool = True,
-) -> str:
+) -> str | BaseModel:
     """Setup and invoke after_llm_call hooks for the executor context.
 
     Args:
        executor_context: The executor context to setup the hooks for.
-        answer: The LLM response string.
+        answer: The LLM response (string or Pydantic model).
         printer: Printer instance for error logging.
         verbose: Whether to print output.
 
     Returns:
-        The potentially modified response string.
+        The potentially modified response (string or Pydantic model).
     """
     if executor_context and executor_context.after_llm_call_hooks:
         from crewai.hooks.llm_hooks import LLMCallHookContext
 
         original_messages = executor_context.messages
 
-        hook_context = LLMCallHookContext(executor_context, response=answer)
+        # For Pydantic models, serialize to JSON for hooks
+        if isinstance(answer, BaseModel):
+            pydantic_answer = answer
+            hook_response: str = pydantic_answer.model_dump_json()
+            original_json: str = hook_response
+        else:
+            pydantic_answer = None
+            hook_response = str(answer)
+
+        hook_context = LLMCallHookContext(executor_context, response=hook_response)
         try:
             for hook in executor_context.after_llm_call_hooks:
                 modified_response = hook(hook_context)
                 if modified_response is not None and isinstance(modified_response, str):
-                    answer = modified_response
+                    hook_response = modified_response
 
         except Exception as e:
            if verbose:
@@ -1035,4 +1045,21 @@ def _setup_after_llm_call_hooks(
         else:
             executor_context.messages = []
 
+        # If hooks modified the response, update answer accordingly
+        if pydantic_answer is not None:
+            # For Pydantic models, reparse the JSON if it was modified
+            if hook_response != original_json:
+                try:
+                    model_class: type[BaseModel] = type(pydantic_answer)
+                    answer = model_class.model_validate_json(hook_response)
+                except Exception as e:
+                    if verbose:
+                        printer.print(
+                            content=f"Warning: Hook modified response but failed to reparse as {type(pydantic_answer).__name__}: {e}. Using original model.",
+                            color="yellow",
+                        )
+        else:
+            # For string responses, use the hook-modified response
+            answer = hook_response
+
     return answer
 
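Note: after_llm_call hooks now always see JSON text; when a hook rewrites it, the executor reparses into the original model class. A standalone sketch of the round-trip (hook and model are hypothetical):

    from pydantic import BaseModel


    class Summary(BaseModel):  # hypothetical response model
        headline: str


    def redact_hook(response_json: str) -> str:
        """Toy after-LLM-call hook that edits the serialized response."""
        return response_json.replace("secret", "[redacted]")


    answer = Summary(headline="the secret plan")
    original_json = answer.model_dump_json()    # serialize for the hook
    hook_response = redact_hook(original_json)  # hook edits JSON text
    if hook_response != original_json:          # reparse only when modified
        answer = type(answer).model_validate_json(hook_response)

    assert answer.headline == "the [redacted] plan"
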
@@ -62,7 +62,10 @@ class Converter(OutputConverter):
                 ],
                 response_model=self.model,
             )
-            result = self.model.model_validate_json(response)
+            if isinstance(response, BaseModel):
+                result = response
+            else:
+                result = self.model.model_validate_json(response)
         else:
             response = self.llm.call(
                 [
 
@@ -157,10 +157,10 @@ async def test_anthropic_async_with_response_model():
         "Say hello in French",
         response_model=GreetingResponse
     )
-    model = GreetingResponse.model_validate_json(result)
-    assert isinstance(model, GreetingResponse)
-    assert isinstance(model.greeting, str)
-    assert isinstance(model.language, str)
+    # When response_model is provided, the result is already a parsed Pydantic model instance
+    assert isinstance(result, GreetingResponse)
+    assert isinstance(result.greeting, str)
+    assert isinstance(result.language, str)
 
 
 @pytest.mark.vcr()
 
@@ -799,3 +799,131 @@ def test_google_express_mode_works() -> None:
     assert result.token_usage.prompt_tokens > 0
     assert result.token_usage.completion_tokens > 0
     assert result.token_usage.successful_requests >= 1
+
+
+def test_gemini_2_0_model_detection():
+    """Test that Gemini 2.0 models are properly detected."""
+    # Test Gemini 2.0 models
+    llm_2_0 = LLM(model="google/gemini-2.0-flash-001")
+    from crewai.llms.providers.gemini.completion import GeminiCompletion
+    assert isinstance(llm_2_0, GeminiCompletion)
+    assert llm_2_0.is_gemini_2_0 is True
+
+    llm_2_5 = LLM(model="google/gemini-2.5-flash")
+    assert isinstance(llm_2_5, GeminiCompletion)
+    assert llm_2_5.is_gemini_2_0 is True
+
+    # Test non-2.0 models
+    llm_1_5 = LLM(model="google/gemini-1.5-pro")
+    assert isinstance(llm_1_5, GeminiCompletion)
+    assert llm_1_5.is_gemini_2_0 is False
+
+
+def test_add_property_ordering_to_schema():
+    """Test that _add_property_ordering correctly adds propertyOrdering to schemas."""
+    from crewai.llms.providers.gemini.completion import GeminiCompletion
+
+    # Test simple object schema
+    simple_schema = {
+        "type": "object",
+        "properties": {
+            "name": {"type": "string"},
+            "age": {"type": "integer"},
+            "email": {"type": "string"}
+        }
+    }
+
+    result = GeminiCompletion._add_property_ordering(simple_schema)
+
+    assert "propertyOrdering" in result
+    assert result["propertyOrdering"] == ["name", "age", "email"]
+
+    # Test nested object schema
+    nested_schema = {
+        "type": "object",
+        "properties": {
+            "user": {
+                "type": "object",
+                "properties": {
+                    "name": {"type": "string"},
+                    "contact": {
+                        "type": "object",
+                        "properties": {
+                            "email": {"type": "string"},
+                            "phone": {"type": "string"}
+                        }
+                    }
+                }
+            },
+            "id": {"type": "integer"}
+        }
+    }
+
+    result = GeminiCompletion._add_property_ordering(nested_schema)
+
+    assert "propertyOrdering" in result
+    assert result["propertyOrdering"] == ["user", "id"]
+    assert "propertyOrdering" in result["properties"]["user"]
+    assert result["properties"]["user"]["propertyOrdering"] == ["name", "contact"]
+    assert "propertyOrdering" in result["properties"]["user"]["properties"]["contact"]
+    assert result["properties"]["user"]["properties"]["contact"]["propertyOrdering"] == ["email", "phone"]
+
+
+def test_gemini_2_0_response_model_with_property_ordering():
+    """Test that Gemini 2.0 models include propertyOrdering in response schemas."""
+    from pydantic import BaseModel, Field
+
+    class TestResponse(BaseModel):
+        """Test response model."""
+        name: str = Field(..., description="The name")
+        age: int = Field(..., description="The age")
+        email: str = Field(..., description="The email")
+
+    llm = LLM(model="google/gemini-2.0-flash-001")
+
+    # Prepare generation config with response model
+    config = llm._prepare_generation_config(response_model=TestResponse)
+
+    # Verify that the config has response_json_schema
+    assert hasattr(config, 'response_json_schema') or 'response_json_schema' in config.__dict__
+
+    # Get the schema
+    if hasattr(config, 'response_json_schema'):
+        schema = config.response_json_schema
+    else:
+        schema = config.__dict__.get('response_json_schema', {})
+
+    # Verify propertyOrdering is present for Gemini 2.0
+    assert "propertyOrdering" in schema
+    assert "name" in schema["propertyOrdering"]
+    assert "age" in schema["propertyOrdering"]
+    assert "email" in schema["propertyOrdering"]
+
+
+def test_gemini_1_5_response_model_uses_response_schema():
+    """Test that Gemini 1.5 models use response_schema parameter (not response_json_schema)."""
+    from pydantic import BaseModel, Field
+
+    class TestResponse(BaseModel):
+        """Test response model."""
+        name: str = Field(..., description="The name")
+        age: int = Field(..., description="The age")
+
+    llm = LLM(model="google/gemini-1.5-pro")
+
+    # Prepare generation config with response model
+    config = llm._prepare_generation_config(response_model=TestResponse)
+
+    # Verify that the config uses response_schema (not response_json_schema)
+    assert hasattr(config, 'response_schema') or 'response_schema' in config.__dict__
+    assert not (hasattr(config, 'response_json_schema') and config.response_json_schema is not None)
+
+    # Get the schema
+    if hasattr(config, 'response_schema'):
+        schema = config.response_schema
+    else:
+        schema = config.__dict__.get('response_schema')
+
+    # For Gemini 1.5, response_schema should be the Pydantic model itself
+    # The SDK handles conversion internally
+    assert schema is TestResponse or isinstance(schema, type)
 
@@ -540,7 +540,9 @@ def test_openai_streaming_with_response_model():
     result = llm.call("Test question", response_model=TestResponse)
 
     assert result is not None
-    assert isinstance(result, str)
+    assert isinstance(result, TestResponse)
+    assert result.answer == "test"
+    assert result.confidence == 0.95
 
     assert mock_stream.called
     call_kwargs = mock_stream.call_args[1]