diff --git a/docs/concepts/llms.mdx b/docs/concepts/llms.mdx
index 261a1fdd8..0358308f4 100644
--- a/docs/concepts/llms.mdx
+++ b/docs/concepts/llms.mdx
@@ -465,11 +465,26 @@ Learn how to get the most out of your LLM configuration:
     # https://cloud.google.com/vertex-ai/generative-ai/docs/overview
     ```
 
+    Get credentials:
+    ```python Code
+    import json
+
+    file_path = 'path/to/vertex_ai_service_account.json'
+
+    # Load the JSON file
+    with open(file_path, 'r') as file:
+        vertex_credentials = json.load(file)
+
+    # Convert the credentials to a JSON string
+    vertex_credentials_json = json.dumps(vertex_credentials)
+    ```
+
     Example usage:
     ```python Code
     llm = LLM(
         model="gemini/gemini-1.5-pro-latest",
-        temperature=0.7
+        temperature=0.7,
+        vertex_credentials=vertex_credentials_json
     )
     ```
 
diff --git a/src/crewai/llm.py b/src/crewai/llm.py
index ef8746fd5..bbf8e35d9 100644
--- a/src/crewai/llm.py
+++ b/src/crewai/llm.py
@@ -137,6 +137,7 @@ class LLM:
         api_version: Optional[str] = None,
         api_key: Optional[str] = None,
         callbacks: List[Any] = [],
+        **kwargs,
     ):
         self.model = model
         self.timeout = timeout
@@ -158,6 +159,7 @@ class LLM:
         self.api_key = api_key
         self.callbacks = callbacks
         self.context_window_size = 0
+        self.additional_params = kwargs
 
         litellm.drop_params = True
 
@@ -240,6 +242,7 @@ class LLM:
             "api_key": self.api_key,
             "stream": False,
             "tools": tools,
+            **self.additional_params,
         }
 
         # Remove None values from params
diff --git a/tests/llm_test.py b/tests/llm_test.py
index 6d1e6a188..8db8726d0 100644
--- a/tests/llm_test.py
+++ b/tests/llm_test.py
@@ -1,4 +1,5 @@
 from time import sleep
+from unittest.mock import MagicMock, patch
 
 import pytest
 
@@ -154,3 +155,50 @@ def test_llm_call_with_tool_and_message_list():
 
     assert isinstance(result, int)
     assert result == 25
+
+
+@pytest.mark.vcr(filter_headers=["authorization"])
+def test_llm_passes_additional_params():
+    llm = LLM(
+        model="gpt-4o-mini",
+        vertex_credentials="test_credentials",
+        vertex_project="test_project",
+    )
+
+    messages = [{"role": "user", "content": "Hello, world!"}]
+
+    with patch("litellm.completion") as mocked_completion:
+        # Create mocks for the response structure
+        mock_message = MagicMock()
+        mock_message.content = "Test response"
+        mock_choice = MagicMock()
+        mock_choice.message = mock_message
+        mock_response = MagicMock()
+        mock_response.choices = [mock_choice]
+        mock_response.usage = {
+            "prompt_tokens": 5,
+            "completion_tokens": 5,
+            "total_tokens": 10,
+        }
+
+        # Set up the mocked completion to return the mock response
+        mocked_completion.return_value = mock_response
+
+        result = llm.call(messages)
+
+        # Assert that litellm.completion was called once
+        mocked_completion.assert_called_once()
+
+        # Retrieve the actual arguments with which litellm.completion was called
+        _, kwargs = mocked_completion.call_args
+
+        # Check that the additional params were passed to litellm.completion
+        assert kwargs["vertex_credentials"] == "test_credentials"
+        assert kwargs["vertex_project"] == "test_project"
+
+        # Also verify that other expected parameters are present
+        assert kwargs["model"] == "gpt-4o-mini"
+        assert kwargs["messages"] == messages
+
+        # Check the result from llm.call
+        assert result == "Test response"
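
A minimal end-to-end sketch of how the new pass-through behaves (this assumes the `LLM` class from `src/crewai/llm.py` above; the `vertex_credentials` key follows litellm's Vertex AI convention, as exercised in the test):

```python
import json

from crewai import LLM

# Load a service-account key and serialize it, as in the docs change above.
with open("path/to/vertex_ai_service_account.json", "r") as file:
    vertex_credentials_json = json.dumps(json.load(file))

# Any keyword argument the constructor does not name is captured by **kwargs,
# stored in self.additional_params, and forwarded to litellm.completion().
llm = LLM(
    model="gemini/gemini-1.5-pro-latest",
    temperature=0.7,
    vertex_credentials=vertex_credentials_json,
)
print(llm.call([{"role": "user", "content": "Hello, world!"}]))
```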
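
One consequence of the merge order worth noting: because `**self.additional_params` is unpacked last when the `params` dict is built in `call()`, a keyword passed at construction time overrides any named entry built before it. That is plain dict-literal semantics:

```python
# Later unpacking wins for duplicate keys (standard dict-literal behavior),
# so additional_params can shadow named entries such as "temperature".
additional_params = {"temperature": 0.2, "vertex_project": "test_project"}
params = {"temperature": 0.7, "stream": False, **additional_params}
assert params["temperature"] == 0.2
assert params["vertex_project"] == "test_project"
```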