diff --git a/lib/crewai/src/crewai/llms/providers/gemini/completion.py b/lib/crewai/src/crewai/llms/providers/gemini/completion.py index bd634c8dc..1843a839c 100644 --- a/lib/crewai/src/crewai/llms/providers/gemini/completion.py +++ b/lib/crewai/src/crewai/llms/providers/gemini/completion.py @@ -34,6 +34,8 @@ except ImportError: ) from None +logger = logging.getLogger(__name__) + STRUCTURED_OUTPUT_TOOL_NAME = "structured_output" @@ -61,6 +63,7 @@ class GeminiCompletion(BaseLLM): interceptor: BaseInterceptor[Any, Any] | None = None, use_vertexai: bool | None = None, response_format: type[BaseModel] | None = None, + thinking_config: types.ThinkingConfig | dict[str, Any] | None = None, **kwargs: Any, ): """Initialize Google Gemini chat completion client. @@ -93,6 +96,14 @@ class GeminiCompletion(BaseLLM): api_version="v1" is automatically configured. response_format: Pydantic model for structured output. Used as default when response_model is not passed to call()/acall() methods. + thinking_config: Configuration for Gemini thinking models (e.g. gemini-2.5-pro). + Can be a ThinkingConfig object or a dict with 'include_thoughts' + and optionally 'thinking_budget' keys. + When enabled, the model's reasoning/thought output is captured + and logged. Example: + thinking_config={"include_thoughts": True} + thinking_config=ThinkingConfig(include_thoughts=True, + thinking_budget=10000) **kwargs: Additional parameters """ if interceptor is not None: @@ -130,6 +141,17 @@ class GeminiCompletion(BaseLLM): self.tools: list[dict[str, Any]] | None = None self.response_format = response_format + # Thinking config for Gemini thinking models (e.g. gemini-2.5-pro) + if isinstance(thinking_config, dict): + self.thinking_config: types.ThinkingConfig | None = types.ThinkingConfig( + **thinking_config + ) + else: + self.thinking_config = thinking_config + + # Store previous thought content for multi-turn conversations + self.previous_thoughts: list[str] = [] + # Model-specific settings version_match = re.search(r"gemini-(\d+(?:\.\d+)?)", model.lower()) self.supports_tools = bool( @@ -481,6 +503,10 @@ class GeminiCompletion(BaseLLM): if self.stop_sequences: config_params["stop_sequences"] = self.stop_sequences + # Add thinking config for thinking models (e.g. gemini-2.5-pro) + if self.thinking_config is not None: + config_params["thinking_config"] = self.thinking_config + if tools and self.supports_tools: gemini_tools = self._convert_tools_for_interference(tools) @@ -916,6 +942,11 @@ class GeminiCompletion(BaseLLM): ) -> tuple[str, dict[int, dict[str, Any]], dict[str, int]]: """Process a single streaming chunk. + Instead of using ``chunk.text`` (which triggers a warning when non-text + parts like ``function_call`` or ``thought_signature`` are present), this + method iterates over the candidate parts directly to extract text, + thought content, and function calls without side effects. + Args: chunk: The streaming chunk response full_response: Accumulated response text @@ -931,19 +962,31 @@ class GeminiCompletion(BaseLLM): if chunk.usage_metadata: usage_data = self._extract_token_usage(chunk) - if chunk.text: - full_response += chunk.text - self._emit_stream_chunk_event( - chunk=chunk.text, - from_task=from_task, - from_agent=from_agent, - response_id=response_id, - ) - + # Iterate over parts directly to avoid the warning triggered by chunk.text + # when non-text parts (function_call, thought_signature) are present. if chunk.candidates: candidate = chunk.candidates[0] if candidate.content and candidate.content.parts: for part in candidate.content.parts: + # Handle thought parts from thinking models + if getattr(part, "thought", False) and part.text: + logger.debug( + "Gemini thinking model thought: %s", part.text + ) + self.previous_thoughts.append(part.text) + continue + + # Handle regular text parts + if hasattr(part, "text") and part.text and not part.function_call: + full_response += part.text + self._emit_stream_chunk_event( + chunk=part.text, + from_task=from_task, + from_agent=from_agent, + response_id=response_id, + ) + + # Handle function call parts if part.function_call: call_index = len(function_calls) call_id = f"call_{call_index}" @@ -1305,19 +1348,21 @@ class GeminiCompletion(BaseLLM): } return {"total_tokens": 0} - @staticmethod - def _extract_text_from_response(response: GenerateContentResponse) -> str: + def _extract_text_from_response(self, response: GenerateContentResponse) -> str: """Extract text content from Gemini response without triggering warnings. This method directly accesses the response parts to extract text content, avoiding the warning that occurs when using response.text on responses containing non-text parts (e.g., 'thought_signature' from thinking models). + Thought parts (where ``part.thought == True``) are separated from regular + text and stored in ``self.previous_thoughts`` for downstream access. + Args: response: The Gemini API response Returns: - Concatenated text content from all text parts + Concatenated text content from all non-thought text parts """ if not response.candidates: return "" @@ -1326,11 +1371,13 @@ class GeminiCompletion(BaseLLM): if not candidate.content or not candidate.content.parts: return "" - text_parts = [ - part.text - for part in candidate.content.parts - if hasattr(part, "text") and part.text - ] + text_parts: list[str] = [] + for part in candidate.content.parts: + if getattr(part, "thought", False) and part.text: + logger.debug("Gemini thinking model thought: %s", part.text) + self.previous_thoughts.append(part.text) + elif hasattr(part, "text") and part.text: + text_parts.append(part.text) return "".join(text_parts) diff --git a/lib/crewai/tests/llms/google/test_google.py b/lib/crewai/tests/llms/google/test_google.py index 6f475ef49..8075d0111 100644 --- a/lib/crewai/tests/llms/google/test_google.py +++ b/lib/crewai/tests/llms/google/test_google.py @@ -1190,3 +1190,287 @@ def test_gemini_cached_prompt_tokens_with_tools(): # cached_prompt_tokens should be populated (may be 0 if Gemini # doesn't cache for this particular request, but the field should exist) assert usage.cached_prompt_tokens >= 0 + + +# ──────────────────────────────────────────────────────────────────────────── +# Tests for Gemini thinking model support (issue #4647) +# ──────────────────────────────────────────────────────────────────────────── + + +def test_gemini_thinking_config_dict_initialization(): + """Test that thinking_config can be passed as a dict and is converted to ThinkingConfig.""" + from google.genai import types as genai_types + + llm = LLM( + model="google/gemini-2.5-flash", + thinking_config={"include_thoughts": True}, + ) + + from crewai.llms.providers.gemini.completion import GeminiCompletion + + assert isinstance(llm, GeminiCompletion) + assert llm.thinking_config is not None + assert isinstance(llm.thinking_config, genai_types.ThinkingConfig) + assert llm.thinking_config.include_thoughts is True + + +def test_gemini_thinking_config_object_initialization(): + """Test that thinking_config can be passed as a ThinkingConfig object.""" + from google.genai import types as genai_types + + tc = genai_types.ThinkingConfig(include_thoughts=True, thinking_budget=10000) + llm = LLM( + model="google/gemini-2.5-flash", + thinking_config=tc, + ) + + from crewai.llms.providers.gemini.completion import GeminiCompletion + + assert isinstance(llm, GeminiCompletion) + assert llm.thinking_config is tc + assert llm.thinking_config.include_thoughts is True + assert llm.thinking_config.thinking_budget == 10000 + + +def test_gemini_thinking_config_none_by_default(): + """Test that thinking_config is None when not provided.""" + llm = LLM(model="google/gemini-2.0-flash-001") + + from crewai.llms.providers.gemini.completion import GeminiCompletion + + assert isinstance(llm, GeminiCompletion) + assert llm.thinking_config is None + + +def test_gemini_thinking_config_in_generation_config(): + """Test that thinking_config is included in the GenerateContentConfig.""" + from google.genai import types as genai_types + + llm = LLM( + model="google/gemini-2.5-flash", + thinking_config={"include_thoughts": True}, + ) + + config = llm._prepare_generation_config() + assert config.thinking_config is not None + assert isinstance(config.thinking_config, genai_types.ThinkingConfig) + assert config.thinking_config.include_thoughts is True + + +def test_gemini_thinking_config_not_in_generation_config_when_none(): + """Test that thinking_config is absent from GenerateContentConfig when not set.""" + llm = LLM(model="google/gemini-2.0-flash-001") + + config = llm._prepare_generation_config() + assert config.thinking_config is None + + +def test_gemini_extract_text_filters_out_thought_parts(): + """Test that _extract_text_from_response separates thought parts from text.""" + llm = LLM(model="google/gemini-2.5-flash") + + # Build a fake response with thought + text parts + mock_response = MagicMock() + thought_part = MagicMock() + thought_part.thought = True + thought_part.text = "Let me think about this..." + thought_part.function_call = None + + text_part = MagicMock() + text_part.thought = False + text_part.text = "The answer is 42." + text_part.function_call = None + + candidate = MagicMock() + candidate.content.parts = [thought_part, text_part] + mock_response.candidates = [candidate] + + llm.previous_thoughts = [] + result = llm._extract_text_from_response(mock_response) + + assert result == "The answer is 42." + assert len(llm.previous_thoughts) == 1 + assert llm.previous_thoughts[0] == "Let me think about this..." + + +def test_gemini_extract_text_no_thought_parts(): + """Test _extract_text_from_response with no thought parts (normal response).""" + llm = LLM(model="google/gemini-2.0-flash-001") + + mock_response = MagicMock() + text_part = MagicMock() + text_part.thought = False + text_part.text = "Hello world" + text_part.function_call = None + + candidate = MagicMock() + candidate.content.parts = [text_part] + mock_response.candidates = [candidate] + + llm.previous_thoughts = [] + result = llm._extract_text_from_response(mock_response) + + assert result == "Hello world" + assert len(llm.previous_thoughts) == 0 + + +def test_gemini_stream_chunk_handles_thought_parts(): + """Test that _process_stream_chunk captures thought parts and emits text parts.""" + import json as json_mod + + llm = LLM(model="google/gemini-2.5-flash") + llm.previous_thoughts = [] + + # Build a mock chunk with a thought part and a text part + thought_part = MagicMock() + thought_part.thought = True + thought_part.text = "Reasoning step 1" + thought_part.function_call = None + + text_part = MagicMock() + text_part.thought = False + text_part.text = "Final answer" + text_part.function_call = None + + chunk = MagicMock() + chunk.response_id = "resp_123" + chunk.usage_metadata = None + candidate = MagicMock() + candidate.content.parts = [thought_part, text_part] + chunk.candidates = [candidate] + + with patch.object(llm, "_emit_stream_chunk_event"): + full_response, function_calls, usage_data = llm._process_stream_chunk( + chunk=chunk, + full_response="", + function_calls={}, + usage_data={"total_tokens": 0}, + ) + + assert full_response == "Final answer" + assert len(llm.previous_thoughts) == 1 + assert llm.previous_thoughts[0] == "Reasoning step 1" + + +def test_gemini_stream_chunk_handles_function_call_without_warning(): + """Test that _process_stream_chunk handles function calls without triggering chunk.text.""" + llm = LLM(model="google/gemini-2.5-flash") + llm.previous_thoughts = [] + + # Build a mock chunk with a function call part + func_part = MagicMock() + func_part.thought = False + func_part.text = None + func_part.function_call.name = "get_weather" + func_part.function_call.args = {"location": "Tokyo"} + + chunk = MagicMock() + chunk.response_id = "resp_456" + chunk.usage_metadata = None + candidate = MagicMock() + candidate.content.parts = [func_part] + chunk.candidates = [candidate] + + with patch.object(llm, "_emit_stream_chunk_event"): + full_response, function_calls, usage_data = llm._process_stream_chunk( + chunk=chunk, + full_response="", + function_calls={}, + usage_data={"total_tokens": 0}, + ) + + assert full_response == "" + assert len(function_calls) == 1 + assert function_calls[0]["name"] == "get_weather" + assert function_calls[0]["args"] == {"location": "Tokyo"} + + +def test_gemini_stream_chunk_mixed_thought_text_and_function_call(): + """Test _process_stream_chunk with thought, text, and function call parts.""" + llm = LLM(model="google/gemini-2.5-flash") + llm.previous_thoughts = [] + + thought_part = MagicMock() + thought_part.thought = True + thought_part.text = "I need to use a tool" + thought_part.function_call = None + + func_part = MagicMock() + func_part.thought = False + func_part.text = None + func_part.function_call.name = "search" + func_part.function_call.args = {"query": "hello"} + + chunk = MagicMock() + chunk.response_id = "resp_789" + chunk.usage_metadata = None + candidate = MagicMock() + candidate.content.parts = [thought_part, func_part] + chunk.candidates = [candidate] + + with patch.object(llm, "_emit_stream_chunk_event"): + full_response, function_calls, usage_data = llm._process_stream_chunk( + chunk=chunk, + full_response="", + function_calls={}, + usage_data={"total_tokens": 0}, + ) + + assert full_response == "" + assert len(function_calls) == 1 + assert function_calls[0]["name"] == "search" + assert len(llm.previous_thoughts) == 1 + assert llm.previous_thoughts[0] == "I need to use a tool" + + +def test_gemini_previous_thoughts_accumulate_across_chunks(): + """Test that previous_thoughts accumulate across multiple streaming chunks.""" + llm = LLM(model="google/gemini-2.5-flash") + llm.previous_thoughts = [] + + # First chunk with thought + thought1 = MagicMock() + thought1.thought = True + thought1.text = "Step 1" + thought1.function_call = None + chunk1 = MagicMock() + chunk1.response_id = "resp_1" + chunk1.usage_metadata = None + candidate1 = MagicMock() + candidate1.content.parts = [thought1] + chunk1.candidates = [candidate1] + + # Second chunk with thought + text + thought2 = MagicMock() + thought2.thought = True + thought2.text = "Step 2" + thought2.function_call = None + text_part = MagicMock() + text_part.thought = False + text_part.text = "Result" + text_part.function_call = None + chunk2 = MagicMock() + chunk2.response_id = "resp_1" + chunk2.usage_metadata = None + candidate2 = MagicMock() + candidate2.content.parts = [thought2, text_part] + chunk2.candidates = [candidate2] + + with patch.object(llm, "_emit_stream_chunk_event"): + full_response = "" + function_calls: dict = {} + usage_data = {"total_tokens": 0} + + full_response, function_calls, usage_data = llm._process_stream_chunk( + chunk=chunk1, full_response=full_response, + function_calls=function_calls, usage_data=usage_data, + ) + full_response, function_calls, usage_data = llm._process_stream_chunk( + chunk=chunk2, full_response=full_response, + function_calls=function_calls, usage_data=usage_data, + ) + + assert full_response == "Result" + assert len(llm.previous_thoughts) == 2 + assert llm.previous_thoughts[0] == "Step 1" + assert llm.previous_thoughts[1] == "Step 2"