Compare commits

...

1 Commit

Author SHA1 Message Date
Devin AI
affe5709c1 fix: capture thought output from Gemini thinking models (issue #4647)
- Add thinking_config parameter to GeminiCompletion.__init__ (accepts dict or ThinkingConfig)
- Include thinking_config in _prepare_generation_config when set
- Rewrite _process_stream_chunk to iterate over parts directly instead of using chunk.text, avoiding warnings when non-text parts (thought, function_call) are present
- Convert _extract_text_from_response from staticmethod to instance method; separate thought parts from text parts and store thoughts in self.previous_thoughts
- Add 11 tests covering thinking config initialization, generation config integration, thought part extraction in streaming and non-streaming paths

Co-Authored-By: João <joao@crewai.com>
2026-02-28 12:14:10 +00:00
2 changed files with 348 additions and 17 deletions

View File

@@ -34,6 +34,8 @@ except ImportError:
) from None
logger = logging.getLogger(__name__)
STRUCTURED_OUTPUT_TOOL_NAME = "structured_output"
@@ -61,6 +63,7 @@ class GeminiCompletion(BaseLLM):
interceptor: BaseInterceptor[Any, Any] | None = None,
use_vertexai: bool | None = None,
response_format: type[BaseModel] | None = None,
thinking_config: types.ThinkingConfig | dict[str, Any] | None = None,
**kwargs: Any,
):
"""Initialize Google Gemini chat completion client.
@@ -93,6 +96,14 @@ class GeminiCompletion(BaseLLM):
api_version="v1" is automatically configured.
response_format: Pydantic model for structured output. Used as default when
response_model is not passed to call()/acall() methods.
thinking_config: Configuration for Gemini thinking models (e.g. gemini-2.5-pro).
Can be a ThinkingConfig object or a dict with 'include_thoughts'
and optionally 'thinking_budget' keys.
When enabled, the model's reasoning/thought output is captured
and logged. Example:
thinking_config={"include_thoughts": True}
thinking_config=ThinkingConfig(include_thoughts=True,
thinking_budget=10000)
**kwargs: Additional parameters
"""
if interceptor is not None:
@@ -130,6 +141,17 @@ class GeminiCompletion(BaseLLM):
self.tools: list[dict[str, Any]] | None = None
self.response_format = response_format
# Thinking config for Gemini thinking models (e.g. gemini-2.5-pro)
if isinstance(thinking_config, dict):
self.thinking_config: types.ThinkingConfig | None = types.ThinkingConfig(
**thinking_config
)
else:
self.thinking_config = thinking_config
# Store previous thought content for multi-turn conversations
self.previous_thoughts: list[str] = []
# Model-specific settings
version_match = re.search(r"gemini-(\d+(?:\.\d+)?)", model.lower())
self.supports_tools = bool(
@@ -481,6 +503,10 @@ class GeminiCompletion(BaseLLM):
if self.stop_sequences:
config_params["stop_sequences"] = self.stop_sequences
# Add thinking config for thinking models (e.g. gemini-2.5-pro)
if self.thinking_config is not None:
config_params["thinking_config"] = self.thinking_config
if tools and self.supports_tools:
gemini_tools = self._convert_tools_for_interference(tools)
@@ -916,6 +942,11 @@ class GeminiCompletion(BaseLLM):
) -> tuple[str, dict[int, dict[str, Any]], dict[str, int]]:
"""Process a single streaming chunk.
Instead of using ``chunk.text`` (which triggers a warning when non-text
parts like ``function_call`` or ``thought_signature`` are present), this
method iterates over the candidate parts directly to extract text,
thought content, and function calls without side effects.
Args:
chunk: The streaming chunk response
full_response: Accumulated response text
@@ -931,19 +962,31 @@ class GeminiCompletion(BaseLLM):
if chunk.usage_metadata:
usage_data = self._extract_token_usage(chunk)
if chunk.text:
full_response += chunk.text
self._emit_stream_chunk_event(
chunk=chunk.text,
from_task=from_task,
from_agent=from_agent,
response_id=response_id,
)
# Iterate over parts directly to avoid the warning triggered by chunk.text
# when non-text parts (function_call, thought_signature) are present.
if chunk.candidates:
candidate = chunk.candidates[0]
if candidate.content and candidate.content.parts:
for part in candidate.content.parts:
# Handle thought parts from thinking models
if getattr(part, "thought", False) and part.text:
logger.debug(
"Gemini thinking model thought: %s", part.text
)
self.previous_thoughts.append(part.text)
continue
# Handle regular text parts
if hasattr(part, "text") and part.text and not part.function_call:
full_response += part.text
self._emit_stream_chunk_event(
chunk=part.text,
from_task=from_task,
from_agent=from_agent,
response_id=response_id,
)
# Handle function call parts
if part.function_call:
call_index = len(function_calls)
call_id = f"call_{call_index}"
@@ -1305,19 +1348,21 @@ class GeminiCompletion(BaseLLM):
}
return {"total_tokens": 0}
@staticmethod
def _extract_text_from_response(response: GenerateContentResponse) -> str:
def _extract_text_from_response(self, response: GenerateContentResponse) -> str:
"""Extract text content from Gemini response without triggering warnings.
This method directly accesses the response parts to extract text content,
avoiding the warning that occurs when using response.text on responses
containing non-text parts (e.g., 'thought_signature' from thinking models).
Thought parts (where ``part.thought == True``) are separated from regular
text and stored in ``self.previous_thoughts`` for downstream access.
Args:
response: The Gemini API response
Returns:
Concatenated text content from all text parts
Concatenated text content from all non-thought text parts
"""
if not response.candidates:
return ""
@@ -1326,11 +1371,13 @@ class GeminiCompletion(BaseLLM):
if not candidate.content or not candidate.content.parts:
return ""
text_parts = [
part.text
for part in candidate.content.parts
if hasattr(part, "text") and part.text
]
text_parts: list[str] = []
for part in candidate.content.parts:
if getattr(part, "thought", False) and part.text:
logger.debug("Gemini thinking model thought: %s", part.text)
self.previous_thoughts.append(part.text)
elif hasattr(part, "text") and part.text:
text_parts.append(part.text)
return "".join(text_parts)

View File

@@ -1190,3 +1190,287 @@ def test_gemini_cached_prompt_tokens_with_tools():
# cached_prompt_tokens should be populated (may be 0 if Gemini
# doesn't cache for this particular request, but the field should exist)
assert usage.cached_prompt_tokens >= 0
# ────────────────────────────────────────────────────────────────────────────
# Tests for Gemini thinking model support (issue #4647)
# ────────────────────────────────────────────────────────────────────────────
def test_gemini_thinking_config_dict_initialization():
    """A plain dict passed as thinking_config is coerced into a ThinkingConfig."""
    from google.genai import types as genai_types
    from crewai.llms.providers.gemini.completion import GeminiCompletion

    client = LLM(
        model="google/gemini-2.5-flash",
        thinking_config={"include_thoughts": True},
    )
    assert isinstance(client, GeminiCompletion)

    cfg = client.thinking_config
    assert cfg is not None
    assert isinstance(cfg, genai_types.ThinkingConfig)
    assert cfg.include_thoughts is True
def test_gemini_thinking_config_object_initialization():
    """A ThinkingConfig instance passed as thinking_config is stored unchanged."""
    from google.genai import types as genai_types
    from crewai.llms.providers.gemini.completion import GeminiCompletion

    config = genai_types.ThinkingConfig(include_thoughts=True, thinking_budget=10000)
    client = LLM(model="google/gemini-2.5-flash", thinking_config=config)

    assert isinstance(client, GeminiCompletion)
    # The very same object must be kept -- no copy or re-validation.
    assert client.thinking_config is config
    assert client.thinking_config.include_thoughts is True
    assert client.thinking_config.thinking_budget == 10000
def test_gemini_thinking_config_none_by_default():
    """Omitting thinking_config leaves the attribute set to None."""
    from crewai.llms.providers.gemini.completion import GeminiCompletion

    client = LLM(model="google/gemini-2.0-flash-001")
    assert isinstance(client, GeminiCompletion)
    assert client.thinking_config is None
def test_gemini_thinking_config_in_generation_config():
    """_prepare_generation_config propagates thinking_config into its output."""
    from google.genai import types as genai_types

    client = LLM(
        model="google/gemini-2.5-flash",
        thinking_config={"include_thoughts": True},
    )
    generation_config = client._prepare_generation_config()

    thinking = generation_config.thinking_config
    assert thinking is not None
    assert isinstance(thinking, genai_types.ThinkingConfig)
    assert thinking.include_thoughts is True
def test_gemini_thinking_config_not_in_generation_config_when_none():
    """Without a thinking_config, the generation config carries none either."""
    client = LLM(model="google/gemini-2.0-flash-001")
    assert client._prepare_generation_config().thinking_config is None
def test_gemini_extract_text_filters_out_thought_parts():
    """_extract_text_from_response routes thought parts into previous_thoughts."""

    def make_part(text, thought):
        # Minimal stand-in for a google.genai content part.
        part = MagicMock()
        part.thought = thought
        part.text = text
        part.function_call = None
        return part

    client = LLM(model="google/gemini-2.5-flash")
    client.previous_thoughts = []

    candidate = MagicMock()
    candidate.content.parts = [
        make_part("Let me think about this...", thought=True),
        make_part("The answer is 42.", thought=False),
    ]
    response = MagicMock()
    response.candidates = [candidate]

    assert client._extract_text_from_response(response) == "The answer is 42."
    assert client.previous_thoughts == ["Let me think about this..."]
def test_gemini_extract_text_no_thought_parts():
    """A response holding only regular text parts leaves previous_thoughts empty."""
    client = LLM(model="google/gemini-2.0-flash-001")
    client.previous_thoughts = []

    part = MagicMock()
    part.thought = False
    part.text = "Hello world"
    part.function_call = None

    candidate = MagicMock()
    candidate.content.parts = [part]
    response = MagicMock()
    response.candidates = [candidate]

    assert client._extract_text_from_response(response) == "Hello world"
    assert client.previous_thoughts == []
def test_gemini_stream_chunk_handles_thought_parts():
    """_process_stream_chunk stores thought parts and accumulates text parts.

    The thought part must land in ``previous_thoughts`` while only the regular
    text part contributes to the returned ``full_response``.
    """
    # NOTE: the original test had an unused ``import json as json_mod``; removed.
    llm = LLM(model="google/gemini-2.5-flash")
    llm.previous_thoughts = []

    # Build a mock chunk with a thought part and a text part.
    thought_part = MagicMock()
    thought_part.thought = True
    thought_part.text = "Reasoning step 1"
    thought_part.function_call = None

    text_part = MagicMock()
    text_part.thought = False
    text_part.text = "Final answer"
    text_part.function_call = None

    chunk = MagicMock()
    chunk.response_id = "resp_123"
    chunk.usage_metadata = None
    candidate = MagicMock()
    candidate.content.parts = [thought_part, text_part]
    chunk.candidates = [candidate]

    # Patch the event emitter so no real event bus is touched.
    with patch.object(llm, "_emit_stream_chunk_event"):
        full_response, function_calls, usage_data = llm._process_stream_chunk(
            chunk=chunk,
            full_response="",
            function_calls={},
            usage_data={"total_tokens": 0},
        )

    assert full_response == "Final answer"
    assert len(llm.previous_thoughts) == 1
    assert llm.previous_thoughts[0] == "Reasoning step 1"
def test_gemini_stream_chunk_handles_function_call_without_warning():
    """Function-call chunks are collected without ever touching chunk.text."""
    client = LLM(model="google/gemini-2.5-flash")
    client.previous_thoughts = []

    # A part carrying only a function call, no text.
    call_part = MagicMock()
    call_part.thought = False
    call_part.text = None
    call_part.function_call.name = "get_weather"
    call_part.function_call.args = {"location": "Tokyo"}

    candidate = MagicMock()
    candidate.content.parts = [call_part]
    chunk = MagicMock()
    chunk.response_id = "resp_456"
    chunk.usage_metadata = None
    chunk.candidates = [candidate]

    with patch.object(client, "_emit_stream_chunk_event"):
        full_response, function_calls, _usage = client._process_stream_chunk(
            chunk=chunk,
            full_response="",
            function_calls={},
            usage_data={"total_tokens": 0},
        )

    assert full_response == ""
    assert len(function_calls) == 1
    assert function_calls[0]["name"] == "get_weather"
    assert function_calls[0]["args"] == {"location": "Tokyo"}
def test_gemini_stream_chunk_mixed_thought_text_and_function_call():
    """A chunk mixing a thought part and a function call exercises both paths."""
    client = LLM(model="google/gemini-2.5-flash")
    client.previous_thoughts = []

    thinking = MagicMock()
    thinking.thought = True
    thinking.text = "I need to use a tool"
    thinking.function_call = None

    tool_call = MagicMock()
    tool_call.thought = False
    tool_call.text = None
    tool_call.function_call.name = "search"
    tool_call.function_call.args = {"query": "hello"}

    candidate = MagicMock()
    candidate.content.parts = [thinking, tool_call]
    chunk = MagicMock()
    chunk.response_id = "resp_789"
    chunk.usage_metadata = None
    chunk.candidates = [candidate]

    with patch.object(client, "_emit_stream_chunk_event"):
        full_response, function_calls, _usage = client._process_stream_chunk(
            chunk=chunk,
            full_response="",
            function_calls={},
            usage_data={"total_tokens": 0},
        )

    assert full_response == ""
    assert len(function_calls) == 1
    assert function_calls[0]["name"] == "search"
    assert client.previous_thoughts == ["I need to use a tool"]
def test_gemini_previous_thoughts_accumulate_across_chunks():
    """Thoughts from successive streaming chunks append to previous_thoughts."""

    def make_part(text, thought):
        # Minimal stand-in for a google.genai content part.
        part = MagicMock()
        part.thought = thought
        part.text = text
        part.function_call = None
        return part

    def make_chunk(parts):
        # All chunks in one stream share a single response id.
        chunk = MagicMock()
        chunk.response_id = "resp_1"
        chunk.usage_metadata = None
        candidate = MagicMock()
        candidate.content.parts = parts
        chunk.candidates = [candidate]
        return chunk

    client = LLM(model="google/gemini-2.5-flash")
    client.previous_thoughts = []

    first = make_chunk([make_part("Step 1", thought=True)])
    second = make_chunk(
        [make_part("Step 2", thought=True), make_part("Result", thought=False)]
    )

    # Thread the (text, calls, usage) state through both chunks in order.
    state = ("", {}, {"total_tokens": 0})
    with patch.object(client, "_emit_stream_chunk_event"):
        for chunk in (first, second):
            state = client._process_stream_chunk(
                chunk=chunk,
                full_response=state[0],
                function_calls=state[1],
                usage_data=state[2],
            )

    assert state[0] == "Result"
    assert client.previous_thoughts == ["Step 1", "Step 2"]