mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
Fix issue #2738: Exclude stop parameter for o3 model
Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -351,7 +351,6 @@ class LLM(BaseLLM):
|
|||||||
"temperature": self.temperature,
|
"temperature": self.temperature,
|
||||||
"top_p": self.top_p,
|
"top_p": self.top_p,
|
||||||
"n": self.n,
|
"n": self.n,
|
||||||
"stop": self.stop,
|
|
||||||
"max_tokens": self.max_tokens or self.max_completion_tokens,
|
"max_tokens": self.max_tokens or self.max_completion_tokens,
|
||||||
"presence_penalty": self.presence_penalty,
|
"presence_penalty": self.presence_penalty,
|
||||||
"frequency_penalty": self.frequency_penalty,
|
"frequency_penalty": self.frequency_penalty,
|
||||||
@@ -370,6 +369,9 @@ class LLM(BaseLLM):
|
|||||||
**self.additional_params,
|
**self.additional_params,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if self.stop and self.supports_stop_words():
|
||||||
|
params["stop"] = self.stop
|
||||||
|
|
||||||
# Remove None values from params
|
# Remove None values from params
|
||||||
return {k: v for k, v in params.items() if v is not None}
|
return {k: v for k, v in params.items() if v is not None}
|
||||||
|
|
||||||
|
|||||||
52
tests/unit/test_llm.py
Normal file
52
tests/unit/test_llm.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
import unittest
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
from crewai.llm import LLM
|
||||||
|
|
||||||
|
|
||||||
|
class TestLLM(unittest.TestCase):
|
||||||
|
@patch("crewai.llm.litellm.completion")
|
||||||
|
@patch("crewai.llm.LLM.supports_stop_words")
|
||||||
|
def test_call_with_supported_stop_words(self, mock_supports_stop_words, mock_completion):
|
||||||
|
mock_supports_stop_words.return_value = True
|
||||||
|
|
||||||
|
message = SimpleNamespace(content="Hello, World!")
|
||||||
|
choice = SimpleNamespace(message=message)
|
||||||
|
response = SimpleNamespace(choices=[choice])
|
||||||
|
mock_completion.return_value = response
|
||||||
|
|
||||||
|
llm = LLM(model="gpt-4", stop=["STOP"])
|
||||||
|
|
||||||
|
messages = [{"role": "user", "content": "Say Hello"}]
|
||||||
|
result = llm.call(messages)
|
||||||
|
|
||||||
|
mock_completion.assert_called_once()
|
||||||
|
call_args = mock_completion.call_args[1]
|
||||||
|
self.assertIn("stop", call_args)
|
||||||
|
self.assertEqual(call_args["stop"], ["STOP"])
|
||||||
|
self.assertEqual(result, "Hello, World!")
|
||||||
|
|
||||||
|
@patch("crewai.llm.litellm.completion")
|
||||||
|
@patch("crewai.llm.LLM.supports_stop_words")
|
||||||
|
def test_call_with_unsupported_stop_words(self, mock_supports_stop_words, mock_completion):
|
||||||
|
mock_supports_stop_words.return_value = False
|
||||||
|
|
||||||
|
message = SimpleNamespace(content="Hello, World!")
|
||||||
|
choice = SimpleNamespace(message=message)
|
||||||
|
response = SimpleNamespace(choices=[choice])
|
||||||
|
mock_completion.return_value = response
|
||||||
|
|
||||||
|
llm = LLM(model="o3", stop=["STOP"])
|
||||||
|
|
||||||
|
messages = [{"role": "user", "content": "Say Hello"}]
|
||||||
|
result = llm.call(messages)
|
||||||
|
|
||||||
|
mock_completion.assert_called_once()
|
||||||
|
call_args = mock_completion.call_args[1]
|
||||||
|
self.assertNotIn("stop", call_args)
|
||||||
|
self.assertEqual(result, "Hello, World!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user