Mirror of https://github.com/crewAIInc/crewAI.git
Synced 2026-01-22 14:48:13 +00:00
Compare commits: devin/1768 ... devin/1753 (2 commits)

| Author | SHA1 | Date |
|---|---|---|
| | 1bdce4cc76 | |
| | 22761d74ba | |
@@ -308,6 +308,7 @@ class LLM(BaseLLM):
         api_version: Optional[str] = None,
         api_key: Optional[str] = None,
         callbacks: List[Any] = [],
+        reasoning: Optional[bool] = None,
         reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
         stream: bool = False,
         **kwargs,
@@ -332,6 +333,7 @@ class LLM(BaseLLM):
         self.api_key = api_key
         self.callbacks = callbacks
         self.context_window_size = 0
+        self.reasoning = reasoning
         self.reasoning_effort = reasoning_effort
         self.additional_params = kwargs
         self.is_anthropic = self._is_anthropic_model(model)
@@ -406,10 +408,15 @@ class LLM(BaseLLM):
             "api_key": self.api_key,
             "stream": self.stream,
             "tools": tools,
-            "reasoning_effort": self.reasoning_effort,
             **self.additional_params,
         }
 
+        if self.reasoning is False:
+            # When reasoning is explicitly disabled, don't include reasoning_effort
+            pass
+        elif self.reasoning is True or self.reasoning_effort is not None:
+            params["reasoning_effort"] = self.reasoning_effort
+
         # Remove None values from params
         return {k: v for k, v in params.items() if v is not None}
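For reference, a minimal usage sketch of the gating these hunks introduce (the model name and effort values are illustrative, taken from the tests further down):

    from crewai.llm import LLM

    # reasoning left unset (None): reasoning_effort passes through when provided
    llm = LLM(model="ollama/qwen", reasoning_effort="high")

    # reasoning=False: reasoning_effort is dropped from the completion params,
    # even if a value was supplied
    llm_off = LLM(model="ollama/qwen", reasoning=False, reasoning_effort="high")

    # reasoning=True: reasoning_effort is forwarded as given
    llm_on = LLM(model="ollama/qwen", reasoning=True, reasoning_effort="medium")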
@@ -2,6 +2,7 @@
 
 import json
 import pytest
+from unittest.mock import patch
 
 from crewai import Agent, Task
 from crewai.llm import LLM
@@ -259,3 +260,31 @@ def test_agent_with_function_calling_fallback():
     assert result == "4"
     assert "Reasoning Plan:" in task.description
     assert "Invalid JSON that will trigger fallback" in task.description
+
+
+def test_agent_with_llm_reasoning_disabled():
+    """Test agent with LLM reasoning disabled."""
+    llm = LLM("gpt-3.5-turbo", reasoning=False)
+
+    agent = Agent(
+        role="Test Agent",
+        goal="To test the LLM reasoning parameter",
+        backstory="I am a test agent created to verify the LLM reasoning parameter works correctly.",
+        llm=llm,
+        reasoning=False,
+        verbose=True
+    )
+
+    task = Task(
+        description="Simple math task: What's 3+3?",
+        expected_output="The answer should be a number.",
+        agent=agent
+    )
+
+    with patch.object(agent.llm, 'call') as mock_call:
+        mock_call.return_value = "6"
+
+        result = agent.execute_task(task)
+
+        assert result == "6"
+        assert "Reasoning Plan:" not in task.description
@@ -711,3 +711,99 @@ def test_ollama_does_not_modify_when_last_is_user(ollama_llm):
     formatted = ollama_llm._format_messages_for_provider(original_messages)
 
     assert formatted == original_messages
+
+
+def test_llm_reasoning_parameter_false():
+    """Test that reasoning=False disables reasoning mode."""
+    llm = LLM(model="ollama/qwen", reasoning=False)
+
+    with patch("litellm.completion") as mock_completion:
+        mock_message = MagicMock()
+        mock_message.content = "Test response"
+        mock_choice = MagicMock()
+        mock_choice.message = mock_message
+        mock_response = MagicMock()
+        mock_response.choices = [mock_choice]
+        mock_response.usage = {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10}
+        mock_completion.return_value = mock_response
+
+        llm.call("Test message")
+
+        _, kwargs = mock_completion.call_args
+        assert "reasoning_effort" not in kwargs
+
+
+def test_llm_reasoning_parameter_true():
+    """Test that reasoning=True enables reasoning mode."""
+    llm = LLM(model="ollama/qwen", reasoning=True, reasoning_effort="medium")
+
+    with patch("litellm.completion") as mock_completion:
+        mock_message = MagicMock()
+        mock_message.content = "Test response"
+        mock_choice = MagicMock()
+        mock_choice.message = mock_message
+        mock_response = MagicMock()
+        mock_response.choices = [mock_choice]
+        mock_response.usage = {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10}
+        mock_completion.return_value = mock_response
+
+        llm.call("Test message")
+
+        _, kwargs = mock_completion.call_args
+        assert kwargs["reasoning_effort"] == "medium"
+
+
+def test_llm_reasoning_parameter_none_with_reasoning_effort():
+    """Test that reasoning=None with reasoning_effort still includes reasoning_effort."""
+    llm = LLM(model="ollama/qwen", reasoning=None, reasoning_effort="high")
+
+    with patch("litellm.completion") as mock_completion:
+        mock_message = MagicMock()
+        mock_message.content = "Test response"
+        mock_choice = MagicMock()
+        mock_choice.message = mock_message
+        mock_response = MagicMock()
+        mock_response.choices = [mock_choice]
+        mock_response.usage = {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10}
+        mock_completion.return_value = mock_response
+
+        llm.call("Test message")
+
+        _, kwargs = mock_completion.call_args
+        assert kwargs["reasoning_effort"] == "high"
+
+
+def test_llm_reasoning_false_overrides_reasoning_effort():
+    """Test that reasoning=False overrides reasoning_effort."""
+    llm = LLM(model="ollama/qwen", reasoning=False, reasoning_effort="high")
+
+    with patch("litellm.completion") as mock_completion:
+        mock_message = MagicMock()
+        mock_message.content = "Test response"
+        mock_choice = MagicMock()
+        mock_choice.message = mock_message
+        mock_response = MagicMock()
+        mock_response.choices = [mock_choice]
+        mock_response.usage = {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10}
+        mock_completion.return_value = mock_response
+
+        llm.call("Test message")
+
+        _, kwargs = mock_completion.call_args
+        assert "reasoning_effort" not in kwargs
+
+
+@pytest.mark.vcr(filter_headers=["authorization"])
+def test_ollama_qwen_with_reasoning_disabled():
+    """Test Ollama Qwen model with reasoning disabled."""
+    if not os.getenv("OLLAMA_BASE_URL"):
+        pytest.skip("OLLAMA_BASE_URL not set; skipping test.")
+
+    llm = LLM(
+        model="ollama/qwen",
+        base_url=os.getenv("OLLAMA_BASE_URL", "http://localhost:11434"),
+        reasoning=False
+    )
+    result = llm.call("What is 2+2?")
+    assert isinstance(result, str)
+    assert len(result.strip()) > 0
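The four mocked tests above share identical scaffolding; a condensed sketch of that pattern follows, under the same assumption the tests make (LLM.call forwards its params as keyword arguments to litellm.completion). The helper name is illustrative, not part of the diff:

    from typing import Optional
    from unittest.mock import MagicMock, patch

    from crewai.llm import LLM

    def reasoning_effort_sent(llm: LLM) -> Optional[str]:
        """Call the LLM with litellm patched out; return what reached the provider."""
        with patch("litellm.completion") as mock_completion:
            mock_message = MagicMock()
            mock_message.content = "ok"
            mock_response = MagicMock()
            mock_response.choices = [MagicMock(message=mock_message)]
            mock_response.usage = {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}
            mock_completion.return_value = mock_response

            llm.call("ping")

            _, kwargs = mock_completion.call_args
            return kwargs.get("reasoning_effort")

    # reasoning=False drops the parameter even when an effort level is set;
    # leaving reasoning unset forwards the effort level as-is.
    assert reasoning_effort_sent(LLM(model="ollama/qwen", reasoning=False, reasoning_effort="high")) is None
    assert reasoning_effort_sent(LLM(model="ollama/qwen", reasoning_effort="high")) == "high"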