mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-01 07:13:00 +00:00
fixes gemini
This commit is contained in:
@@ -34,6 +34,9 @@ except ImportError:
|
|||||||
) from None
|
) from None
|
||||||
|
|
||||||
|
|
||||||
|
STRUCTURED_OUTPUT_TOOL_NAME = "structured_output"
|
||||||
|
|
||||||
|
|
||||||
class GeminiCompletion(BaseLLM):
|
class GeminiCompletion(BaseLLM):
|
||||||
"""Google Gemini native completion implementation.
|
"""Google Gemini native completion implementation.
|
||||||
|
|
||||||
@@ -447,6 +450,9 @@ class GeminiCompletion(BaseLLM):
|
|||||||
Structured output support varies by model version:
|
Structured output support varies by model version:
|
||||||
- Gemini 1.5 and earlier: Uses response_schema (Pydantic model)
|
- Gemini 1.5 and earlier: Uses response_schema (Pydantic model)
|
||||||
- Gemini 2.0+: Uses response_json_schema (JSON Schema) with propertyOrdering
|
- Gemini 2.0+: Uses response_json_schema (JSON Schema) with propertyOrdering
|
||||||
|
|
||||||
|
When both tools AND response_model are present, we add a structured_output
|
||||||
|
pseudo-tool since Gemini doesn't support tools + response_schema together.
|
||||||
"""
|
"""
|
||||||
self.tools = tools
|
self.tools = tools
|
||||||
config_params: dict[str, Any] = {}
|
config_params: dict[str, Any] = {}
|
||||||
@@ -472,7 +478,30 @@ class GeminiCompletion(BaseLLM):
|
|||||||
config_params["stop_sequences"] = self.stop_sequences
|
config_params["stop_sequences"] = self.stop_sequences
|
||||||
|
|
||||||
if tools and self.supports_tools:
|
if tools and self.supports_tools:
|
||||||
config_params["tools"] = self._convert_tools_for_interference(tools)
|
gemini_tools = self._convert_tools_for_interference(tools)
|
||||||
|
|
||||||
|
if response_model:
|
||||||
|
schema_output = generate_model_description(response_model)
|
||||||
|
schema = schema_output.get("json_schema", {}).get("schema", {})
|
||||||
|
if self.is_gemini_2_0:
|
||||||
|
schema = self._add_property_ordering(schema)
|
||||||
|
|
||||||
|
structured_output_tool = types.Tool(
|
||||||
|
function_declarations=[
|
||||||
|
types.FunctionDeclaration(
|
||||||
|
name=STRUCTURED_OUTPUT_TOOL_NAME,
|
||||||
|
description=(
|
||||||
|
"Use this tool to provide your final structured response. "
|
||||||
|
"Call this tool when you have gathered all necessary information "
|
||||||
|
"and are ready to provide the final answer in the required format."
|
||||||
|
),
|
||||||
|
parameters_json_schema=schema,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
gemini_tools.append(structured_output_tool)
|
||||||
|
|
||||||
|
config_params["tools"] = gemini_tools
|
||||||
elif response_model:
|
elif response_model:
|
||||||
config_params["response_mime_type"] = "application/json"
|
config_params["response_mime_type"] = "application/json"
|
||||||
schema_output = generate_model_description(response_model)
|
schema_output = generate_model_description(response_model)
|
||||||
@@ -719,6 +748,47 @@ class GeminiCompletion(BaseLLM):
|
|||||||
messages_for_event, content, from_agent
|
messages_for_event, content, from_agent
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _handle_structured_output_tool_call(
    self,
    structured_data: dict[str, Any],
    response_model: type[BaseModel],
    contents: list[types.Content],
    from_task: Any | None = None,
    from_agent: Any | None = None,
) -> BaseModel:
    """Turn a structured_output pseudo-tool call into a validated model.

    The tool-call arguments are validated with ``response_model.model_validate``;
    on success a call-completed event (type ``LLM_CALL``) carrying the model's
    JSON dump is emitted and the validated instance is returned.

    Args:
        structured_data: Arguments the model supplied to the structured_output tool.
        response_model: Pydantic model class the arguments must satisfy.
        contents: Conversation contents, converted to dicts for the emitted event.
        from_task: Task that initiated the call, forwarded to the event.
        from_agent: Agent that initiated the call, forwarded to the event.

    Returns:
        The validated ``response_model`` instance.

    Raises:
        ValueError: If validation fails. Note that any exception raised while
            serializing the result or emitting the event is wrapped the same way,
            because the whole sequence runs under a single ``try``.
    """
    try:
        instance = response_model.model_validate(structured_data)
        serialized = instance.model_dump_json()
        event_messages = self._convert_contents_to_dict(contents)
        self._emit_call_completed_event(
            response=serialized,
            call_type=LLMCallType.LLM_CALL,
            from_task=from_task,
            from_agent=from_agent,
            messages=event_messages,
        )
    except Exception as exc:
        # Log first, then surface a single ValueError that keeps the cause chained.
        failure = (
            f"Failed to validate {STRUCTURED_OUTPUT_TOOL_NAME} tool response "
            f"with model {response_model.__name__}: {exc}"
        )
        logging.error(failure)
        raise ValueError(failure) from exc
    else:
        return instance
|
||||||
|
|
||||||
def _process_response_with_tools(
|
def _process_response_with_tools(
|
||||||
self,
|
self,
|
||||||
response: GenerateContentResponse,
|
response: GenerateContentResponse,
|
||||||
@@ -749,17 +819,47 @@ class GeminiCompletion(BaseLLM):
|
|||||||
part for part in candidate.content.parts if part.function_call
|
part for part in candidate.content.parts if part.function_call
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Check for structured_output pseudo-tool call (used when tools + response_model)
|
||||||
|
if response_model and function_call_parts:
|
||||||
|
for part in function_call_parts:
|
||||||
|
if (
|
||||||
|
part.function_call
|
||||||
|
and part.function_call.name == STRUCTURED_OUTPUT_TOOL_NAME
|
||||||
|
):
|
||||||
|
structured_data = (
|
||||||
|
dict(part.function_call.args)
|
||||||
|
if part.function_call.args
|
||||||
|
else {}
|
||||||
|
)
|
||||||
|
return self._handle_structured_output_tool_call(
|
||||||
|
structured_data=structured_data,
|
||||||
|
response_model=response_model,
|
||||||
|
contents=contents,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Filter out structured_output from function calls returned to executor
|
||||||
|
non_structured_output_parts = [
|
||||||
|
part
|
||||||
|
for part in function_call_parts
|
||||||
|
if not (
|
||||||
|
part.function_call
|
||||||
|
and part.function_call.name == STRUCTURED_OUTPUT_TOOL_NAME
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
# If there are function calls but no available_functions,
|
# If there are function calls but no available_functions,
|
||||||
# return them for the executor to handle (like OpenAI/Anthropic)
|
# return them for the executor to handle (like OpenAI/Anthropic)
|
||||||
if function_call_parts and not available_functions:
|
if non_structured_output_parts and not available_functions:
|
||||||
self._emit_call_completed_event(
|
self._emit_call_completed_event(
|
||||||
response=function_call_parts,
|
response=non_structured_output_parts,
|
||||||
call_type=LLMCallType.TOOL_CALL,
|
call_type=LLMCallType.TOOL_CALL,
|
||||||
from_task=from_task,
|
from_task=from_task,
|
||||||
from_agent=from_agent,
|
from_agent=from_agent,
|
||||||
messages=self._convert_contents_to_dict(contents),
|
messages=self._convert_contents_to_dict(contents),
|
||||||
)
|
)
|
||||||
return function_call_parts
|
return non_structured_output_parts
|
||||||
|
|
||||||
# Otherwise execute the tools internally
|
# Otherwise execute the tools internally
|
||||||
for part in candidate.content.parts:
|
for part in candidate.content.parts:
|
||||||
@@ -767,6 +867,9 @@ class GeminiCompletion(BaseLLM):
|
|||||||
function_name = part.function_call.name
|
function_name = part.function_call.name
|
||||||
if function_name is None:
|
if function_name is None:
|
||||||
continue
|
continue
|
||||||
|
# Skip structured_output - it's handled above
|
||||||
|
if function_name == STRUCTURED_OUTPUT_TOOL_NAME:
|
||||||
|
continue
|
||||||
function_args = (
|
function_args = (
|
||||||
dict(part.function_call.args)
|
dict(part.function_call.args)
|
||||||
if part.function_call.args
|
if part.function_call.args
|
||||||
@@ -899,9 +1002,27 @@ class GeminiCompletion(BaseLLM):
|
|||||||
"""
|
"""
|
||||||
self._track_token_usage_internal(usage_data)
|
self._track_token_usage_internal(usage_data)
|
||||||
|
|
||||||
|
if response_model and function_calls:
|
||||||
|
for call_data in function_calls.values():
|
||||||
|
if call_data.get("name") == STRUCTURED_OUTPUT_TOOL_NAME:
|
||||||
|
structured_data = call_data.get("args", {})
|
||||||
|
return self._handle_structured_output_tool_call(
|
||||||
|
structured_data=structured_data,
|
||||||
|
response_model=response_model,
|
||||||
|
contents=contents,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
|
)
|
||||||
|
|
||||||
|
non_structured_output_calls = {
|
||||||
|
idx: call_data
|
||||||
|
for idx, call_data in function_calls.items()
|
||||||
|
if call_data.get("name") != STRUCTURED_OUTPUT_TOOL_NAME
|
||||||
|
}
|
||||||
|
|
||||||
# If there are function calls but no available_functions,
|
# If there are function calls but no available_functions,
|
||||||
# return them for the executor to handle
|
# return them for the executor to handle
|
||||||
if function_calls and not available_functions:
|
if non_structured_output_calls and not available_functions:
|
||||||
formatted_function_calls = [
|
formatted_function_calls = [
|
||||||
{
|
{
|
||||||
"id": call_data["id"],
|
"id": call_data["id"],
|
||||||
@@ -911,7 +1032,7 @@ class GeminiCompletion(BaseLLM):
|
|||||||
},
|
},
|
||||||
"type": "function",
|
"type": "function",
|
||||||
}
|
}
|
||||||
for call_data in function_calls.values()
|
for call_data in non_structured_output_calls.values()
|
||||||
]
|
]
|
||||||
self._emit_call_completed_event(
|
self._emit_call_completed_event(
|
||||||
response=formatted_function_calls,
|
response=formatted_function_calls,
|
||||||
@@ -922,9 +1043,9 @@ class GeminiCompletion(BaseLLM):
|
|||||||
)
|
)
|
||||||
return formatted_function_calls
|
return formatted_function_calls
|
||||||
|
|
||||||
# Handle completed function calls
|
# Handle completed function calls (excluding structured_output)
|
||||||
if function_calls and available_functions:
|
if non_structured_output_calls and available_functions:
|
||||||
for call_data in function_calls.values():
|
for call_data in non_structured_output_calls.values():
|
||||||
function_name = call_data["name"]
|
function_name = call_data["name"]
|
||||||
function_args = call_data["args"]
|
function_args = call_data["args"]
|
||||||
|
|
||||||
@@ -948,6 +1069,9 @@ class GeminiCompletion(BaseLLM):
|
|||||||
if result is not None:
|
if result is not None:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
# When tools are present, structured output should come via the structured_output
|
||||||
|
# pseudo-tool, not via direct text response. If we reach here with tools present,
|
||||||
|
# the LLM chose to return plain text instead of calling structured_output.
|
||||||
effective_response_model = None if self.tools else response_model
|
effective_response_model = None if self.tools else response_model
|
||||||
|
|
||||||
return self._finalize_completion_response(
|
return self._finalize_completion_response(
|
||||||
|
|||||||
Reference in New Issue
Block a user