From 793cdca75436fee6fe5f17583ee05c5a5f56fc50 Mon Sep 17 00:00:00 2001
From: Brandon Hancock
Date: Fri, 31 Jan 2025 12:33:50 -0500
Subject: [PATCH] make sure additional params are getting passed to llm

---
 src/crewai/llm.py |  2 ++
 tests/llm_test.py | 48 +++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 50 insertions(+)

diff --git a/src/crewai/llm.py b/src/crewai/llm.py
index bbf8e35d9..e6d3d5e1e 100644
--- a/src/crewai/llm.py
+++ b/src/crewai/llm.py
@@ -248,6 +248,8 @@ class LLM:
         # Remove None values from params
         params = {k: v for k, v in params.items() if v is not None}
 
+        logging.debug("Params for LLM call: %s", params)
+
         # --- 2) Make the completion call
         response = litellm.completion(**params)
         response_message = cast(Choices, cast(ModelResponse, response).choices)[
diff --git a/tests/llm_test.py b/tests/llm_test.py
index 6d1e6a188..8db8726d0 100644
--- a/tests/llm_test.py
+++ b/tests/llm_test.py
@@ -1,4 +1,5 @@
 from time import sleep
+from unittest.mock import MagicMock, patch
 
 import pytest
 
@@ -154,3 +155,50 @@ def test_llm_call_with_tool_and_message_list():
 
     assert isinstance(result, int)
     assert result == 25
+
+
+@pytest.mark.vcr(filter_headers=["authorization"])
+def test_llm_passes_additional_params():
+    llm = LLM(
+        model="gpt-4o-mini",
+        vertex_credentials="test_credentials",
+        vertex_project="test_project",
+    )
+
+    messages = [{"role": "user", "content": "Hello, world!"}]
+
+    with patch("litellm.completion") as mocked_completion:
+        # Create mocks for the response structure
+        mock_message = MagicMock()
+        mock_message.content = "Test response"
+        mock_choice = MagicMock()
+        mock_choice.message = mock_message
+        mock_response = MagicMock()
+        mock_response.choices = [mock_choice]
+        mock_response.usage = {
+            "prompt_tokens": 5,
+            "completion_tokens": 5,
+            "total_tokens": 10,
+        }
+
+        # Set up the mocked completion to return the mock response
+        mocked_completion.return_value = mock_response
+
+        result = llm.call(messages)
+
+        # Assert that litellm.completion was called exactly once
+        mocked_completion.assert_called_once()
+
+        # Retrieve the actual arguments with which litellm.completion was called
+        _, kwargs = mocked_completion.call_args
+
+        # Check that the additional params were passed through to litellm.completion
+        assert kwargs["vertex_credentials"] == "test_credentials"
+        assert kwargs["vertex_project"] == "test_project"
+
+        # Also verify that the other expected parameters are present
+        assert kwargs["model"] == "gpt-4o-mini"
+        assert kwargs["messages"] == messages
+
+        # Check the result from llm.call
+        assert result == "Test response"
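
For context, the behavior this patch pins down is that provider-specific kwargs the LLM constructor does not consume (here, vertex_credentials and vertex_project) must reach litellm.completion unmodified. Below is a minimal sketch of that pass-through, assuming unrecognized constructor kwargs are collected on the instance and merged into the params dict built inside call(); the attribute name additional_params and the trimmed-down call() are illustrative only, not CrewAI's actual implementation:

    import litellm

    class LLM:
        def __init__(self, model: str, **kwargs):
            self.model = model
            # Kwargs without a named parameter (e.g. vertex_credentials,
            # vertex_project) are kept and forwarded on every call.
            self.additional_params = kwargs

        def call(self, messages: list) -> str:
            params = {
                "model": self.model,
                "messages": messages,
                **self.additional_params,  # pass-through kwargs
            }
            # Remove None values from params
            params = {k: v for k, v in params.items() if v is not None}
            response = litellm.completion(**params)
            return response.choices[0].message.content

With that shape, the new test only has to intercept litellm.completion and assert that both extra kwargs arrive in the call's keyword arguments.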