From d0c5120b32bc639ba9463479b0402ea197a7939f Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Sat, 9 Aug 2025 10:15:57 +0000 Subject: [PATCH] Add thinking_budget parameter support for Gemini models MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add thinking_budget parameter to LLM class constructor with validation - Pass thinking_budget to litellm as thinking={'budget_tokens': value} - Add comprehensive tests for validation and parameter passing - Resolves issue #3299 Co-Authored-By: João --- src/crewai/llm.py | 8 +++++++ tests/llm_test.py | 57 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+) diff --git a/src/crewai/llm.py b/src/crewai/llm.py index c701ddf0b..040a3a8d4 100644 --- a/src/crewai/llm.py +++ b/src/crewai/llm.py @@ -313,6 +313,7 @@ class LLM(BaseLLM): api_key: Optional[str] = None, callbacks: List[Any] = [], reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None, + thinking_budget: Optional[int] = None, stream: bool = False, **kwargs, ): @@ -337,10 +338,14 @@ class LLM(BaseLLM): self.callbacks = callbacks self.context_window_size = 0 self.reasoning_effort = reasoning_effort + self.thinking_budget = thinking_budget self.additional_params = kwargs self.is_anthropic = self._is_anthropic_model(model) self.stream = stream + if self.thinking_budget is not None and (not isinstance(self.thinking_budget, int) or self.thinking_budget <= 0): + raise ValueError("thinking_budget must be a positive integer") + litellm.drop_params = True # Normalize self.stop to always be a List[str] @@ -414,6 +419,9 @@ class LLM(BaseLLM): **self.additional_params, } + if self.thinking_budget is not None: + params["thinking"] = {"budget_tokens": self.thinking_budget} + # Remove None values from params return {k: v for k, v in params.items() if v is not None} diff --git a/tests/llm_test.py b/tests/llm_test.py index 20f1d8108..7859fe82b 100644 --- a/tests/llm_test.py +++ b/tests/llm_test.py @@ -354,6 +354,63 @@ def test_context_window_validation(): assert "must be between 1024 and 2097152" in str(excinfo.value) +@pytest.mark.vcr(filter_headers=["authorization"], filter_query_parameters=["key"]) +def test_gemini_thinking_budget(): + llm = LLM( + model="gemini/gemini-2.0-flash-thinking-exp-01-21", + thinking_budget=1024, + ) + result = llm.call("What is the capital of France?") + assert isinstance(result, str) + assert "Paris" in result + + +def test_thinking_budget_validation(): + # Test valid thinking_budget + llm = LLM(model="gemini/gemini-2.0-flash-thinking-exp-01-21", thinking_budget=1024) + assert llm.thinking_budget == 1024 + + # Test invalid thinking_budget (negative) + with pytest.raises(ValueError, match="thinking_budget must be a positive integer"): + LLM(model="gemini/gemini-2.0-flash-thinking-exp-01-21", thinking_budget=-1) + + # Test invalid thinking_budget (zero) + with pytest.raises(ValueError, match="thinking_budget must be a positive integer"): + LLM(model="gemini/gemini-2.0-flash-thinking-exp-01-21", thinking_budget=0) + + # Test invalid thinking_budget (non-integer) + with pytest.raises(ValueError, match="thinking_budget must be a positive integer"): + LLM(model="gemini/gemini-2.0-flash-thinking-exp-01-21", thinking_budget=1024.5) + + +def test_thinking_budget_parameter_passing(): + llm = LLM( + model="gemini/gemini-2.0-flash-thinking-exp-01-21", + thinking_budget=2048, + ) + + with patch("litellm.completion") as mocked_completion: + mock_message = MagicMock() + mock_message.content = "Test response" + mock_choice = MagicMock() + mock_choice.message = mock_message + mock_response = MagicMock() + mock_response.choices = [mock_choice] + mock_response.usage = { + "prompt_tokens": 5, + "completion_tokens": 5, + "total_tokens": 10, + } + mocked_completion.return_value = mock_response + + result = llm.call("Test message") + + _, kwargs = mocked_completion.call_args + assert "thinking" in kwargs + assert kwargs["thinking"]["budget_tokens"] == 2048 + assert result == "Test response" + + @pytest.fixture def get_weather_tool_schema(): return {