chore: remove duplication in azure client

This commit is contained in:
Greyson LaLonde
2025-11-11 18:01:27 -05:00
parent 93f1fbd75e
commit 8b83bf3e54
22 changed files with 370 additions and 209 deletions

View File

@@ -60,7 +60,7 @@ def test_anthropic_tool_use_conversation_flow():
available_functions = {"get_weather": mock_weather_tool}
# Mock the Anthropic client responses
with patch.object(completion.client.messages, 'create') as mock_create:
with patch.object(completion._client.messages, 'create') as mock_create:
# Mock initial response with tool use - need to properly mock ToolUseBlock
mock_tool_use = Mock(spec=ToolUseBlock)
mock_tool_use.id = "tool_123"
@@ -199,8 +199,8 @@ def test_anthropic_specific_parameters():
assert isinstance(llm, AnthropicCompletion)
assert llm.stop == ["Human:", "Assistant:"]
assert llm.stream == True
assert llm.client.max_retries == 5
assert llm.client.timeout == 60
assert llm._client.max_retries == 5
assert llm._client.timeout == 60
def test_anthropic_completion_call():
@@ -637,8 +637,8 @@ def test_anthropic_environment_variable_api_key():
with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-anthropic-key"}):
llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")
assert llm.client is not None
assert hasattr(llm.client, 'messages')
assert llm._client is not None
assert hasattr(llm._client, 'messages')
def test_anthropic_token_usage_tracking():
@@ -648,7 +648,7 @@ def test_anthropic_token_usage_tracking():
llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")
# Mock the Anthropic response with usage information
with patch.object(llm.client.messages, 'create') as mock_create:
with patch.object(llm._client.messages, 'create') as mock_create:
mock_response = MagicMock()
mock_response.content = [MagicMock(text="test response")]
mock_response.usage = MagicMock(input_tokens=50, output_tokens=25)

View File

@@ -64,7 +64,7 @@ def test_azure_tool_use_conversation_flow():
available_functions = {"get_weather": mock_weather_tool}
# Mock the Azure client responses
with patch.object(completion.client, 'complete') as mock_complete:
with patch.object(completion._client, 'complete') as mock_complete:
# Mock tool call in response with proper type
mock_tool_call = MagicMock(spec=ChatCompletionsToolCall)
mock_tool_call.function.name = "get_weather"
@@ -602,7 +602,7 @@ def test_azure_environment_variable_endpoint():
}):
llm = LLM(model="azure/gpt-4")
assert llm.client is not None
assert llm._client is not None
assert llm.endpoint == "https://test.openai.azure.com/openai/deployments/gpt-4"
@@ -613,7 +613,7 @@ def test_azure_token_usage_tracking():
llm = LLM(model="azure/gpt-4")
# Mock the Azure response with usage information
with patch.object(llm.client, 'complete') as mock_complete:
with patch.object(llm._client, 'complete') as mock_complete:
mock_message = MagicMock()
mock_message.content = "test response"
mock_message.tool_calls = None
@@ -651,7 +651,7 @@ def test_azure_http_error_handling():
llm = LLM(model="azure/gpt-4")
# Mock an HTTP error
with patch.object(llm.client, 'complete') as mock_complete:
with patch.object(llm._client, 'complete') as mock_complete:
mock_complete.side_effect = HttpResponseError(message="Rate limit exceeded", response=MagicMock(status_code=429))
with pytest.raises(HttpResponseError):
@@ -668,7 +668,7 @@ def test_azure_streaming_completion():
llm = LLM(model="azure/gpt-4", stream=True)
# Mock streaming response
with patch.object(llm.client, 'complete') as mock_complete:
with patch.object(llm._client, 'complete') as mock_complete:
# Create mock streaming updates with proper type
mock_updates = []
for chunk in ["Hello", " ", "world", "!"]:
@@ -891,7 +891,7 @@ def test_azure_improved_error_messages():
llm = LLM(model="azure/gpt-4")
with patch.object(llm.client, 'complete') as mock_complete:
with patch.object(llm._client, 'complete') as mock_complete:
error_401 = HttpResponseError(message="Unauthorized")
error_401.status_code = 401
mock_complete.side_effect = error_401

View File

@@ -579,7 +579,7 @@ def test_bedrock_token_usage_tracking():
llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")
# Mock the Bedrock response with usage information
with patch.object(llm.client, 'converse') as mock_converse:
with patch.object(llm._client, 'converse') as mock_converse:
mock_response = {
'output': {
'message': {
@@ -624,7 +624,7 @@ def test_bedrock_tool_use_conversation_flow():
available_functions = {"get_weather": mock_weather_tool}
# Mock the Bedrock client responses
with patch.object(llm.client, 'converse') as mock_converse:
with patch.object(llm._client, 'converse') as mock_converse:
# First response: tool use request
tool_use_response = {
'output': {
@@ -710,7 +710,7 @@ def test_bedrock_client_error_handling():
llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")
# Test ValidationException
with patch.object(llm.client, 'converse') as mock_converse:
with patch.object(llm._client, 'converse') as mock_converse:
error_response = {
'Error': {
'Code': 'ValidationException',
@@ -724,7 +724,7 @@ def test_bedrock_client_error_handling():
assert "validation" in str(exc_info.value).lower()
# Test ThrottlingException
with patch.object(llm.client, 'converse') as mock_converse:
with patch.object(llm._client, 'converse') as mock_converse:
error_response = {
'Error': {
'Code': 'ThrottlingException',
@@ -762,7 +762,7 @@ def test_bedrock_stop_sequences_sent_to_api():
llm.stop = ["\nObservation:", "\nThought:"]
# Patch the API call to capture parameters without making real call
with patch.object(llm.client, 'converse') as mock_converse:
with patch.object(llm._client, 'converse') as mock_converse:
mock_response = {
'output': {
'message': {

View File

@@ -59,7 +59,7 @@ def test_gemini_tool_use_conversation_flow():
available_functions = {"get_weather": mock_weather_tool}
# Mock the Google Gemini client responses
with patch.object(completion.client.models, 'generate_content') as mock_generate:
with patch.object(completion._client.models, 'generate_content') as mock_generate:
# Mock function call in response
mock_function_call = Mock()
mock_function_call.name = "get_weather"
@@ -614,8 +614,8 @@ def test_gemini_environment_variable_api_key():
with patch.dict(os.environ, {"GOOGLE_API_KEY": "test-google-key"}):
llm = LLM(model="google/gemini-2.0-flash-001")
assert llm.client is not None
assert hasattr(llm.client, 'models')
assert llm._client is not None
assert hasattr(llm._client, 'models')
assert llm.api_key == "test-google-key"
@@ -626,7 +626,7 @@ def test_gemini_token_usage_tracking():
llm = LLM(model="google/gemini-2.0-flash-001")
# Mock the Gemini response with usage information
with patch.object(llm.client.models, 'generate_content') as mock_generate:
with patch.object(llm._client.models, 'generate_content') as mock_generate:
mock_response = MagicMock()
mock_response.text = "test response"
mock_response.candidates = []
@@ -675,7 +675,7 @@ def test_gemini_stop_sequences_sent_to_api():
llm.stop = ["\nObservation:", "\nThought:"]
# Patch the API call to capture parameters without making real call
with patch.object(llm.client.models, 'generate_content') as mock_generate:
with patch.object(llm._client.models, 'generate_content') as mock_generate:
mock_response = MagicMock()
mock_response.text = "Hello"
mock_response.candidates = []

View File

@@ -369,11 +369,11 @@ def test_openai_client_setup_with_extra_arguments():
assert llm.top_p == 0.5
# Check that client parameters are properly configured
assert llm.client.max_retries == 3
assert llm.client.timeout == 30
assert llm._client.max_retries == 3
assert llm._client.timeout == 30
# Test that parameters are properly used in API calls
with patch.object(llm.client.chat.completions, 'create') as mock_create:
with patch.object(llm._client.chat.completions, 'create') as mock_create:
mock_create.return_value = MagicMock(
choices=[MagicMock(message=MagicMock(content="test response", tool_calls=None))],
usage=MagicMock(prompt_tokens=10, completion_tokens=20, total_tokens=30)
@@ -394,7 +394,7 @@ def test_extra_arguments_are_passed_to_openai_completion():
"""
llm = LLM(model="gpt-4o", temperature=0.7, max_tokens=1000, top_p=0.5, max_retries=3)
with patch.object(llm.client.chat.completions, 'create') as mock_create:
with patch.object(llm._client.chat.completions, 'create') as mock_create:
mock_create.return_value = MagicMock(
choices=[MagicMock(message=MagicMock(content="test response", tool_calls=None))],
usage=MagicMock(prompt_tokens=10, completion_tokens=20, total_tokens=30)
@@ -501,7 +501,7 @@ def test_openai_streaming_with_response_model():
llm = LLM(model="openai/gpt-4o", stream=True)
with patch.object(llm.client.chat.completions, "create") as mock_create:
with patch.object(llm._client.chat.completions, "create") as mock_create:
mock_chunk1 = MagicMock()
mock_chunk1.choices = [
MagicMock(delta=MagicMock(content='{"answer": "test", ', tool_calls=None))