Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-09 16:18:30 +00:00
Add thinking_budget parameter support for Gemini models
- Add thinking_budget parameter to LLM class constructor with validation
- Pass thinking_budget to litellm as thinking={'budget_tokens': value}
- Add comprehensive tests for validation and parameter passing
- Resolves issue #3299
Co-Authored-By: João <joao@crewai.com>
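
For context, a minimal usage sketch of the new parameter. Only the thinking_budget keyword and the model string come from this commit; the top-level LLM import and the prompt are illustrative assumptions.

from crewai import LLM  # assumed top-level export, for illustration

llm = LLM(
    model="gemini/gemini-2.0-flash-thinking-exp-01-21",
    thinking_budget=1024,  # cap, in tokens, on the model's internal "thinking"
)
print(llm.call("What is the capital of France?"))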
@@ -313,6 +313,7 @@ class LLM(BaseLLM):
         api_key: Optional[str] = None,
         callbacks: List[Any] = [],
         reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
+        thinking_budget: Optional[int] = None,
         stream: bool = False,
         **kwargs,
     ):
@@ -337,10 +338,14 @@ class LLM(BaseLLM):
         self.callbacks = callbacks
         self.context_window_size = 0
         self.reasoning_effort = reasoning_effort
+        self.thinking_budget = thinking_budget
         self.additional_params = kwargs
         self.is_anthropic = self._is_anthropic_model(model)
         self.stream = stream

+        if self.thinking_budget is not None and (not isinstance(self.thinking_budget, int) or self.thinking_budget <= 0):
+            raise ValueError("thinking_budget must be a positive integer")
+
         litellm.drop_params = True

         # Normalize self.stop to always be a List[str]
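
The guard above fails fast at construction time rather than at call time. A quick illustration (hypothetical snippet, reusing the assumed import from above):

from crewai import LLM  # assumed import, as above

MODEL = "gemini/gemini-2.0-flash-thinking-exp-01-21"

LLM(model=MODEL, thinking_budget=1024)  # accepted: a positive int

for bad in (0, -1, 1024.5):
    try:
        LLM(model=MODEL, thinking_budget=bad)
    except ValueError as exc:
        print(bad, "->", exc)  # "thinking_budget must be a positive integer"

One caveat: because bool is a subclass of int in Python, thinking_budget=True would slip past the isinstance check.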
@@ -414,6 +419,9 @@ class LLM(BaseLLM):
             **self.additional_params,
         }

+        if self.thinking_budget is not None:
+            params["thinking"] = {"budget_tokens": self.thinking_budget}
+
         # Remove None values from params
         return {k: v for k, v in params.items() if v is not None}

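Downstream, the prepared params dict gains a "thinking" key whenever a budget was set; since the constructor sets litellm.drop_params = True, providers that do not understand the key should have it dropped rather than raising. An illustrative shape (every field except "thinking" is an example, not taken from this diff):

params = {
    "model": "gemini/gemini-2.0-flash-thinking-exp-01-21",  # example field
    "messages": [{"role": "user", "content": "..."}],       # example field
    "thinking": {"budget_tokens": 1024},                    # added by the branch above
}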
@@ -354,6 +354,63 @@ def test_context_window_validation():
     assert "must be between 1024 and 2097152" in str(excinfo.value)


+@pytest.mark.vcr(filter_headers=["authorization"], filter_query_parameters=["key"])
+def test_gemini_thinking_budget():
+    llm = LLM(
+        model="gemini/gemini-2.0-flash-thinking-exp-01-21",
+        thinking_budget=1024,
+    )
+    result = llm.call("What is the capital of France?")
+    assert isinstance(result, str)
+    assert "Paris" in result
+
+
+def test_thinking_budget_validation():
+    # Test valid thinking_budget
+    llm = LLM(model="gemini/gemini-2.0-flash-thinking-exp-01-21", thinking_budget=1024)
+    assert llm.thinking_budget == 1024
+
+    # Test invalid thinking_budget (negative)
+    with pytest.raises(ValueError, match="thinking_budget must be a positive integer"):
+        LLM(model="gemini/gemini-2.0-flash-thinking-exp-01-21", thinking_budget=-1)
+
+    # Test invalid thinking_budget (zero)
+    with pytest.raises(ValueError, match="thinking_budget must be a positive integer"):
+        LLM(model="gemini/gemini-2.0-flash-thinking-exp-01-21", thinking_budget=0)
+
+    # Test invalid thinking_budget (non-integer)
+    with pytest.raises(ValueError, match="thinking_budget must be a positive integer"):
+        LLM(model="gemini/gemini-2.0-flash-thinking-exp-01-21", thinking_budget=1024.5)
+
+
+def test_thinking_budget_parameter_passing():
+    llm = LLM(
+        model="gemini/gemini-2.0-flash-thinking-exp-01-21",
+        thinking_budget=2048,
+    )
+
+    with patch("litellm.completion") as mocked_completion:
+        mock_message = MagicMock()
+        mock_message.content = "Test response"
+        mock_choice = MagicMock()
+        mock_choice.message = mock_message
+        mock_response = MagicMock()
+        mock_response.choices = [mock_choice]
+        mock_response.usage = {
+            "prompt_tokens": 5,
+            "completion_tokens": 5,
+            "total_tokens": 10,
+        }
+        mocked_completion.return_value = mock_response
+
+        result = llm.call("Test message")
+
+        _, kwargs = mocked_completion.call_args
+        assert "thinking" in kwargs
+        assert kwargs["thinking"]["budget_tokens"] == 2048
+        assert result == "Test response"
+
+
 @pytest.fixture
 def get_weather_tool_schema():
     return {
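
To exercise just the new tests locally, something like the following should work (a sketch; the diff does not show the test module's path, so run it from the repository's test directory):

import pytest

# Select the tests added in this commit by keyword; adjust paths as needed.
pytest.main(["-q", "-k", "thinking_budget"])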