Mirror of https://github.com/crewAIInc/crewAI.git
Synced 2026-01-09 16:18:30 +00:00
Brandon/provide llm additional params (#2018)
Some checks failed
Mark stale issues and pull requests / stale (push): Has been cancelled
* Clean up to match enterprise
* Add additional params to LLM calls
* Make sure additional params are getting passed to llm
* Update docs
* Drop print
Committed by GitHub
parent ddb7958da7
commit 23b9e10323
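At a glance, the change lets callers hand provider-specific keyword arguments to `LLM` and have them forwarded untouched to `litellm.completion`. A minimal sketch of the intended usage, assembled from the docs and test hunks below (the file path is a placeholder, and the Vertex AI parameters are simply the ones the new test exercises):

```python
import json

from crewai import LLM

# Serialize service-account credentials to a JSON string, as the docs
# change below demonstrates.
with open("path/to/vertex_ai_service_account.json") as f:
    vertex_credentials_json = json.dumps(json.load(f))

# Keyword arguments that LLM does not declare (here, vertex_credentials)
# are captured via **kwargs and replayed into every litellm.completion call.
llm = LLM(
    model="gemini/gemini-1.5-pro-latest",
    temperature=0.7,
    vertex_credentials=vertex_credentials_json,
)

# Messages use the OpenAI-style role/content shape, as in the new test.
response = llm.call([{"role": "user", "content": "Hello, world!"}])
```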
@@ -465,11 +465,22 @@ Learn how to get the most out of your LLM configuration:
 # https://cloud.google.com/vertex-ai/generative-ai/docs/overview
 ```

+## GET CREDENTIALS
+file_path = 'path/to/vertex_ai_service_account.json'
+
+# Load the JSON file
+with open(file_path, 'r') as file:
+    vertex_credentials = json.load(file)
+
+# Convert to JSON string
+vertex_credentials_json = json.dumps(vertex_credentials)
+
 Example usage:
 ```python Code
 llm = LLM(
     model="gemini/gemini-1.5-pro-latest",
-    temperature=0.7
+    temperature=0.7,
+    vertex_credentials=vertex_credentials_json
 )
 ```
 </Accordion>
@@ -137,6 +137,7 @@ class LLM:
         api_version: Optional[str] = None,
         api_key: Optional[str] = None,
         callbacks: List[Any] = [],
+        **kwargs,
     ):
         self.model = model
         self.timeout = timeout
@@ -158,6 +159,7 @@ class LLM:
         self.api_key = api_key
         self.callbacks = callbacks
         self.context_window_size = 0
+        self.additional_params = kwargs

         litellm.drop_params = True

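Taken together, the two constructor hunks implement a simple capture pattern: named parameters are bound as before, and everything else falls into `kwargs` and is stashed on the instance. A stripped-down sketch of just that mechanic (the class here is illustrative, not crewAI's):

```python
from typing import Any, Optional


class MiniLLM:
    def __init__(self, model: str, temperature: Optional[float] = None, **kwargs: Any):
        self.model = model
        self.temperature = temperature
        # Unrecognized keyword arguments are kept verbatim for later calls.
        self.additional_params = kwargs


llm = MiniLLM("gpt-4o-mini", vertex_project="test_project")
assert llm.additional_params == {"vertex_project": "test_project"}
```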
@@ -240,6 +242,7 @@ class LLM:
             "api_key": self.api_key,
             "stream": False,
             "tools": tools,
+            **self.additional_params,
         }

         # Remove None values from params
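The spread in this hunk is an ordinary dict merge: because `**self.additional_params` appears after the named keys, a caller-supplied extra can in principle shadow one of them, and the None-filtering pass mentioned in the trailing comment then strips unset keys. A hedged sketch of that flow (the dict-comprehension filter is a paraphrase of the step the comment refers to, not the verbatim source):

```python
additional_params = {"vertex_credentials": "test_credentials"}

params = {
    "model": "gpt-4o-mini",
    "api_key": None,      # named parameter left unset
    "stream": False,
    "tools": None,
    **additional_params,  # later keys win, so extras can shadow named ones
}

# The "Remove None values from params" step:
params = {k: v for k, v in params.items() if v is not None}

assert "api_key" not in params
assert params["vertex_credentials"] == "test_credentials"
```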
@@ -1,4 +1,5 @@
 from time import sleep
+from unittest.mock import MagicMock, patch

 import pytest

@@ -154,3 +155,50 @@ def test_llm_call_with_tool_and_message_list():

     assert isinstance(result, int)
     assert result == 25
+
+
+@pytest.mark.vcr(filter_headers=["authorization"])
+def test_llm_passes_additional_params():
+    llm = LLM(
+        model="gpt-4o-mini",
+        vertex_credentials="test_credentials",
+        vertex_project="test_project",
+    )
+
+    messages = [{"role": "user", "content": "Hello, world!"}]
+
+    with patch("litellm.completion") as mocked_completion:
+        # Create mocks for response structure
+        mock_message = MagicMock()
+        mock_message.content = "Test response"
+        mock_choice = MagicMock()
+        mock_choice.message = mock_message
+        mock_response = MagicMock()
+        mock_response.choices = [mock_choice]
+        mock_response.usage = {
+            "prompt_tokens": 5,
+            "completion_tokens": 5,
+            "total_tokens": 10,
+        }
+
+        # Set up the mocked completion to return the mock response
+        mocked_completion.return_value = mock_response
+
+        result = llm.call(messages)
+
+        # Assert that litellm.completion was called once
+        mocked_completion.assert_called_once()
+
+        # Retrieve the actual arguments with which litellm.completion was called
+        _, kwargs = mocked_completion.call_args
+
+        # Check that the additional_params were passed to litellm.completion
+        assert kwargs["vertex_credentials"] == "test_credentials"
+        assert kwargs["vertex_project"] == "test_project"
+
+        # Also verify that other expected parameters are present
+        assert kwargs["model"] == "gpt-4o-mini"
+        assert kwargs["messages"] == messages
+
+        # Check the result from llm.call
+        assert result == "Test response"
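One note on the mock plumbing above: `Mock.call_args` is an `(args, kwargs)` pair, which is why the test unpacks it as `_, kwargs` before asserting on individual keyword arguments. A tiny self-contained illustration:

```python
from unittest.mock import MagicMock

completion = MagicMock()
completion(model="gpt-4o-mini", vertex_project="test_project")

args, kwargs = completion.call_args  # (positional args, keyword args)
assert args == ()
assert kwargs["vertex_project"] == "test_project"
```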