chore: refactor llms to base models

This commit is contained in:
Greyson LaLonde
2025-11-10 14:22:09 -05:00
parent 0f1c173d02
commit 46785adf58
60 changed files with 706 additions and 612 deletions

View File

@@ -34,7 +34,7 @@ def test_anthropic_completion_is_used_when_claude_provider():
"""
llm = LLM(model="claude/claude-3-5-sonnet-20241022")
from crewai.llms.providers.anthropic.completion import AnthropicCompletion
from crewai.llm.providers.anthropic.completion import AnthropicCompletion
assert isinstance(llm, AnthropicCompletion)
assert llm.provider == "anthropic"
assert llm.model == "claude-3-5-sonnet-20241022"
@@ -47,7 +47,7 @@ def test_anthropic_tool_use_conversation_flow():
Test that the Anthropic completion properly handles tool use conversation flow
"""
from unittest.mock import Mock, patch
from crewai.llms.providers.anthropic.completion import AnthropicCompletion
from crewai.llm.providers.anthropic.completion import AnthropicCompletion
from anthropic.types.tool_use_block import ToolUseBlock
# Create AnthropicCompletion instance
@@ -123,7 +123,7 @@ def test_anthropic_completion_module_is_imported():
"""
Test that the completion module is properly imported when using Anthropic provider
"""
module_name = "crewai.llms.providers.anthropic.completion"
module_name = "crewai.llm.providers.anthropic.completion"
# Remove module from cache if it exists
if module_name in sys.modules:
@@ -175,7 +175,7 @@ def test_anthropic_completion_initialization_parameters():
api_key="test-key"
)
from crewai.llms.providers.anthropic.completion import AnthropicCompletion
from crewai.llm.providers.anthropic.completion import AnthropicCompletion
assert isinstance(llm, AnthropicCompletion)
assert llm.model == "claude-3-5-sonnet-20241022"
assert llm.temperature == 0.7
@@ -195,7 +195,7 @@ def test_anthropic_specific_parameters():
timeout=60
)
from crewai.llms.providers.anthropic.completion import AnthropicCompletion
from crewai.llm.providers.anthropic.completion import AnthropicCompletion
assert isinstance(llm, AnthropicCompletion)
assert llm.stop_sequences == ["Human:", "Assistant:"]
assert llm.stream == True
@@ -390,7 +390,7 @@ def test_anthropic_raises_error_when_model_not_supported():
"""Test that AnthropicCompletion raises ValueError when model not supported"""
# Mock the Anthropic client to raise an error
with patch('crewai.llms.providers.anthropic.completion.Anthropic') as mock_anthropic_class:
with patch('crewai.llm.providers.anthropic.completion.Anthropic') as mock_anthropic_class:
mock_client = MagicMock()
mock_anthropic_class.return_value = mock_client
@@ -427,7 +427,7 @@ def test_anthropic_client_params_setup():
client_params=custom_client_params
)
from crewai.llms.providers.anthropic.completion import AnthropicCompletion
from crewai.llm.providers.anthropic.completion import AnthropicCompletion
assert isinstance(llm, AnthropicCompletion)
assert llm.client_params == custom_client_params
@@ -462,7 +462,7 @@ def test_anthropic_client_params_override_defaults():
)
# Verify this is actually AnthropicCompletion, not LiteLLM fallback
from crewai.llms.providers.anthropic.completion import AnthropicCompletion
from crewai.llm.providers.anthropic.completion import AnthropicCompletion
assert isinstance(llm, AnthropicCompletion)
merged_params = llm._get_client_params()
@@ -487,7 +487,7 @@ def test_anthropic_client_params_none():
client_params=None
)
from crewai.llms.providers.anthropic.completion import AnthropicCompletion
from crewai.llm.providers.anthropic.completion import AnthropicCompletion
assert isinstance(llm, AnthropicCompletion)
assert llm.client_params is None
@@ -515,7 +515,7 @@ def test_anthropic_client_params_empty_dict():
client_params={}
)
from crewai.llms.providers.anthropic.completion import AnthropicCompletion
from crewai.llm.providers.anthropic.completion import AnthropicCompletion
assert isinstance(llm, AnthropicCompletion)
assert llm.client_params == {}
@@ -538,7 +538,7 @@ def test_anthropic_model_detection():
for model_name in anthropic_test_cases:
llm = LLM(model=model_name)
from crewai.llms.providers.anthropic.completion import AnthropicCompletion
from crewai.llm.providers.anthropic.completion import AnthropicCompletion
assert isinstance(llm, AnthropicCompletion), f"Failed for model: {model_name}"

View File

@@ -37,7 +37,7 @@ def test_azure_completion_is_used_when_azure_openai_provider():
"""
llm = LLM(model="azure_openai/gpt-4")
from crewai.llms.providers.azure.completion import AzureCompletion
from crewai.llm.providers.azure.completion import AzureCompletion
assert isinstance(llm, AzureCompletion)
assert llm.provider == "azure"
assert llm.model == "gpt-4"
@@ -47,7 +47,7 @@ def test_azure_tool_use_conversation_flow():
"""
Test that the Azure completion properly handles tool use conversation flow
"""
from crewai.llms.providers.azure.completion import AzureCompletion
from crewai.llm.providers.azure.completion import AzureCompletion
from azure.ai.inference.models import ChatCompletionsToolCall
# Create AzureCompletion instance
@@ -105,7 +105,7 @@ def test_azure_completion_module_is_imported():
"""
Test that the completion module is properly imported when using Azure provider
"""
module_name = "crewai.llms.providers.azure.completion"
module_name = "crewai.llm.providers.azure.completion"
# Remove module from cache if it exists
if module_name in sys.modules:
@@ -160,7 +160,7 @@ def test_azure_completion_initialization_parameters():
endpoint="https://test.openai.azure.com"
)
from crewai.llms.providers.azure.completion import AzureCompletion
from crewai.llm.providers.azure.completion import AzureCompletion
assert isinstance(llm, AzureCompletion)
assert llm.model == "gpt-4"
assert llm.temperature == 0.7
@@ -182,7 +182,7 @@ def test_azure_specific_parameters():
endpoint="https://test.openai.azure.com"
)
from crewai.llms.providers.azure.completion import AzureCompletion
from crewai.llm.providers.azure.completion import AzureCompletion
assert isinstance(llm, AzureCompletion)
assert llm.stop == ["Human:", "Assistant:"]
assert llm.stream == True
@@ -374,7 +374,7 @@ def test_azure_completion_with_tools():
def test_azure_raises_error_when_endpoint_missing():
"""Test that AzureCompletion raises ValueError when endpoint is missing"""
from crewai.llms.providers.azure.completion import AzureCompletion
from crewai.llm.providers.azure.completion import AzureCompletion
# Clear environment variables
with patch.dict(os.environ, {}, clear=True):
@@ -383,7 +383,7 @@ def test_azure_raises_error_when_endpoint_missing():
def test_azure_raises_error_when_api_key_missing():
"""Test that AzureCompletion raises ValueError when API key is missing"""
from crewai.llms.providers.azure.completion import AzureCompletion
from crewai.llm.providers.azure.completion import AzureCompletion
# Clear environment variables
with patch.dict(os.environ, {}, clear=True):
@@ -400,7 +400,7 @@ def test_azure_endpoint_configuration():
}):
llm = LLM(model="azure/gpt-4")
from crewai.llms.providers.azure.completion import AzureCompletion
from crewai.llm.providers.azure.completion import AzureCompletion
assert isinstance(llm, AzureCompletion)
assert llm.endpoint == "https://test1.openai.azure.com/openai/deployments/gpt-4"
@@ -426,7 +426,7 @@ def test_azure_api_key_configuration():
}):
llm = LLM(model="azure/gpt-4")
from crewai.llms.providers.azure.completion import AzureCompletion
from crewai.llm.providers.azure.completion import AzureCompletion
assert isinstance(llm, AzureCompletion)
assert llm.api_key == "test-azure-key"
@@ -437,7 +437,7 @@ def test_azure_model_capabilities():
"""
# Test GPT-4 model (supports function calling)
llm_gpt4 = LLM(model="azure/gpt-4")
from crewai.llms.providers.azure.completion import AzureCompletion
from crewai.llm.providers.azure.completion import AzureCompletion
assert isinstance(llm_gpt4, AzureCompletion)
assert llm_gpt4.is_openai_model == True
assert llm_gpt4.supports_function_calling() == True
@@ -466,7 +466,7 @@ def test_azure_completion_params_preparation():
max_tokens=1000
)
from crewai.llms.providers.azure.completion import AzureCompletion
from crewai.llm.providers.azure.completion import AzureCompletion
assert isinstance(llm, AzureCompletion)
messages = [{"role": "user", "content": "Hello"}]
@@ -494,7 +494,7 @@ def test_azure_model_detection():
for model_name in azure_test_cases:
llm = LLM(model=model_name)
from crewai.llms.providers.azure.completion import AzureCompletion
from crewai.llm.providers.azure.completion import AzureCompletion
assert isinstance(llm, AzureCompletion), f"Failed for model: {model_name}"
@@ -662,7 +662,7 @@ def test_azure_streaming_completion():
"""
Test that streaming completions work properly
"""
from crewai.llms.providers.azure.completion import AzureCompletion
from crewai.llm.providers.azure.completion import AzureCompletion
from azure.ai.inference.models import StreamingChatCompletionsUpdate
llm = LLM(model="azure/gpt-4", stream=True)
@@ -698,7 +698,7 @@ def test_azure_api_version_default():
"""
llm = LLM(model="azure/gpt-4")
from crewai.llms.providers.azure.completion import AzureCompletion
from crewai.llm.providers.azure.completion import AzureCompletion
assert isinstance(llm, AzureCompletion)
# Should use default or environment variable
assert llm.api_version is not None
@@ -721,7 +721,7 @@ def test_azure_openai_endpoint_url_construction():
"""
Test that Azure OpenAI endpoint URLs are automatically constructed correctly
"""
from crewai.llms.providers.azure.completion import AzureCompletion
from crewai.llm.providers.azure.completion import AzureCompletion
with patch.dict(os.environ, {
"AZURE_API_KEY": "test-key",
@@ -738,7 +738,7 @@ def test_azure_openai_endpoint_url_with_trailing_slash():
"""
Test that trailing slashes are handled correctly in endpoint URLs
"""
from crewai.llms.providers.azure.completion import AzureCompletion
from crewai.llm.providers.azure.completion import AzureCompletion
with patch.dict(os.environ, {
"AZURE_API_KEY": "test-key",
@@ -804,7 +804,7 @@ def test_non_azure_openai_model_parameter_included():
"""
Test that model parameter IS included for non-Azure OpenAI endpoints
"""
from crewai.llms.providers.azure.completion import AzureCompletion
from crewai.llm.providers.azure.completion import AzureCompletion
with patch.dict(os.environ, {
"AZURE_API_KEY": "test-key",
@@ -824,7 +824,7 @@ def test_azure_message_formatting_with_role():
"""
Test that messages are formatted with both 'role' and 'content' fields
"""
from crewai.llms.providers.azure.completion import AzureCompletion
from crewai.llm.providers.azure.completion import AzureCompletion
llm = LLM(model="azure/gpt-4")
@@ -886,7 +886,7 @@ def test_azure_improved_error_messages():
"""
Test that improved error messages are provided for common HTTP errors
"""
from crewai.llms.providers.azure.completion import AzureCompletion
from crewai.llm.providers.azure.completion import AzureCompletion
from azure.core.exceptions import HttpResponseError
llm = LLM(model="azure/gpt-4")
@@ -918,7 +918,7 @@ def test_azure_api_version_properly_passed():
"""
Test that api_version is properly passed to the client
"""
from crewai.llms.providers.azure.completion import AzureCompletion
from crewai.llm.providers.azure.completion import AzureCompletion
with patch.dict(os.environ, {
"AZURE_API_KEY": "test-key",
@@ -940,7 +940,7 @@ def test_azure_timeout_and_max_retries_stored():
"""
Test that timeout and max_retries parameters are stored
"""
from crewai.llms.providers.azure.completion import AzureCompletion
from crewai.llm.providers.azure.completion import AzureCompletion
with patch.dict(os.environ, {
"AZURE_API_KEY": "test-key",
@@ -960,7 +960,7 @@ def test_azure_complete_params_include_optional_params():
"""
Test that optional parameters are included in completion params when set
"""
from crewai.llms.providers.azure.completion import AzureCompletion
from crewai.llm.providers.azure.completion import AzureCompletion
with patch.dict(os.environ, {
"AZURE_API_KEY": "test-key",
@@ -992,7 +992,7 @@ def test_azure_endpoint_validation_with_azure_prefix():
"""
Test that 'azure/' prefix is properly stripped when constructing endpoint
"""
from crewai.llms.providers.azure.completion import AzureCompletion
from crewai.llm.providers.azure.completion import AzureCompletion
with patch.dict(os.environ, {
"AZURE_API_KEY": "test-key",
@@ -1009,7 +1009,7 @@ def test_azure_message_formatting_preserves_all_roles():
"""
Test that all message roles (system, user, assistant) are preserved correctly
"""
from crewai.llms.providers.azure.completion import AzureCompletion
from crewai.llm.providers.azure.completion import AzureCompletion
llm = LLM(model="azure/gpt-4")

View File

@@ -19,7 +19,7 @@ def mock_aws_credentials():
"AWS_DEFAULT_REGION": "us-east-1"
}):
# Mock boto3 Session to prevent actual AWS connections
with patch('crewai.llms.providers.bedrock.completion.Session') as mock_session_class:
with patch('crewai.llm.providers.bedrock.completion.Session') as mock_session_class:
# Create mock session instance
mock_session_instance = MagicMock()
mock_client = MagicMock()
@@ -67,7 +67,7 @@ def test_bedrock_completion_module_is_imported():
"""
Test that the completion module is properly imported when using Bedrock provider
"""
module_name = "crewai.llms.providers.bedrock.completion"
module_name = "crewai.llm.providers.bedrock.completion"
# Remove module from cache if it exists
if module_name in sys.modules:
@@ -124,7 +124,7 @@ def test_bedrock_completion_initialization_parameters():
region_name="us-west-2"
)
from crewai.llms.providers.bedrock.completion import BedrockCompletion
from crewai.llm.providers.bedrock.completion import BedrockCompletion
assert isinstance(llm, BedrockCompletion)
assert llm.model == "anthropic.claude-3-5-sonnet-20241022-v2:0"
assert llm.temperature == 0.7
@@ -145,7 +145,7 @@ def test_bedrock_specific_parameters():
region_name="us-east-1"
)
from crewai.llms.providers.bedrock.completion import BedrockCompletion
from crewai.llm.providers.bedrock.completion import BedrockCompletion
assert isinstance(llm, BedrockCompletion)
assert llm.stop_sequences == ["Human:", "Assistant:"]
assert llm.stream == True
@@ -369,7 +369,7 @@ def test_bedrock_aws_credentials_configuration():
}):
llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")
from crewai.llms.providers.bedrock.completion import BedrockCompletion
from crewai.llm.providers.bedrock.completion import BedrockCompletion
assert isinstance(llm, BedrockCompletion)
assert llm.region_name == "us-east-1"
@@ -390,7 +390,7 @@ def test_bedrock_model_capabilities():
"""
# Test Claude model
llm_claude = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")
from crewai.llms.providers.bedrock.completion import BedrockCompletion
from crewai.llm.providers.bedrock.completion import BedrockCompletion
assert isinstance(llm_claude, BedrockCompletion)
assert llm_claude.is_claude_model == True
assert llm_claude.supports_tools == True
@@ -413,7 +413,7 @@ def test_bedrock_inference_config():
max_tokens=1000
)
from crewai.llms.providers.bedrock.completion import BedrockCompletion
from crewai.llm.providers.bedrock.completion import BedrockCompletion
assert isinstance(llm, BedrockCompletion)
# Test config preparation
@@ -444,7 +444,7 @@ def test_bedrock_model_detection():
for model_name in bedrock_test_cases:
llm = LLM(model=model_name)
from crewai.llms.providers.bedrock.completion import BedrockCompletion
from crewai.llm.providers.bedrock.completion import BedrockCompletion
assert isinstance(llm, BedrockCompletion), f"Failed for model: {model_name}"

View File

@@ -34,7 +34,7 @@ def test_gemini_completion_is_used_when_gemini_provider():
"""
llm = LLM(model="gemini/gemini-2.0-flash-001")
from crewai.llms.providers.gemini.completion import GeminiCompletion
from crewai.llm.providers.gemini.completion import GeminiCompletion
assert isinstance(llm, GeminiCompletion)
assert llm.provider == "gemini"
assert llm.model == "gemini-2.0-flash-001"
@@ -47,7 +47,7 @@ def test_gemini_tool_use_conversation_flow():
Test that the Gemini completion properly handles tool use conversation flow
"""
from unittest.mock import Mock, patch
from crewai.llms.providers.gemini.completion import GeminiCompletion
from crewai.llm.providers.gemini.completion import GeminiCompletion
# Create GeminiCompletion instance
completion = GeminiCompletion(model="gemini-2.0-flash-001")
@@ -102,7 +102,7 @@ def test_gemini_completion_module_is_imported():
"""
Test that the completion module is properly imported when using Google provider
"""
module_name = "crewai.llms.providers.gemini.completion"
module_name = "crewai.llm.providers.gemini.completion"
# Remove module from cache if it exists
if module_name in sys.modules:
@@ -159,7 +159,7 @@ def test_gemini_completion_initialization_parameters():
api_key="test-key"
)
from crewai.llms.providers.gemini.completion import GeminiCompletion
from crewai.llm.providers.gemini.completion import GeminiCompletion
assert isinstance(llm, GeminiCompletion)
assert llm.model == "gemini-2.0-flash-001"
assert llm.temperature == 0.7
@@ -186,7 +186,7 @@ def test_gemini_specific_parameters():
location="us-central1"
)
from crewai.llms.providers.gemini.completion import GeminiCompletion
from crewai.llm.providers.gemini.completion import GeminiCompletion
assert isinstance(llm, GeminiCompletion)
assert llm.stop_sequences == ["Human:", "Assistant:"]
assert llm.stream == True
@@ -382,7 +382,7 @@ def test_gemini_raises_error_when_model_not_supported():
"""Test that GeminiCompletion raises ValueError when model not supported"""
# Mock the Google client to raise an error
with patch('crewai.llms.providers.gemini.completion.genai') as mock_genai:
with patch('crewai.llm.providers.gemini.completion.genai') as mock_genai:
mock_client = MagicMock()
mock_genai.Client.return_value = mock_client
@@ -420,7 +420,7 @@ def test_gemini_vertex_ai_setup():
location="us-west1"
)
from crewai.llms.providers.gemini.completion import GeminiCompletion
from crewai.llm.providers.gemini.completion import GeminiCompletion
assert isinstance(llm, GeminiCompletion)
assert llm.project == "test-project"
@@ -435,7 +435,7 @@ def test_gemini_api_key_configuration():
with patch.dict(os.environ, {"GOOGLE_API_KEY": "test-google-key"}):
llm = LLM(model="google/gemini-2.0-flash-001")
from crewai.llms.providers.gemini.completion import GeminiCompletion
from crewai.llm.providers.gemini.completion import GeminiCompletion
assert isinstance(llm, GeminiCompletion)
assert llm.api_key == "test-google-key"
@@ -453,7 +453,7 @@ def test_gemini_model_capabilities():
"""
# Test Gemini 2.0 model
llm_2_0 = LLM(model="google/gemini-2.0-flash-001")
from crewai.llms.providers.gemini.completion import GeminiCompletion
from crewai.llm.providers.gemini.completion import GeminiCompletion
assert isinstance(llm_2_0, GeminiCompletion)
assert llm_2_0.is_gemini_2 == True
assert llm_2_0.supports_tools == True
@@ -477,7 +477,7 @@ def test_gemini_generation_config():
max_output_tokens=1000
)
from crewai.llms.providers.gemini.completion import GeminiCompletion
from crewai.llm.providers.gemini.completion import GeminiCompletion
assert isinstance(llm, GeminiCompletion)
# Test config preparation
@@ -504,7 +504,7 @@ def test_gemini_model_detection():
for model_name in gemini_test_cases:
llm = LLM(model=model_name)
from crewai.llms.providers.gemini.completion import GeminiCompletion
from crewai.llm.providers.gemini.completion import GeminiCompletion
assert isinstance(llm, GeminiCompletion), f"Failed for model: {model_name}"

View File

@@ -6,7 +6,7 @@ import httpx
import pytest
from crewai.llm import LLM
from crewai.llms.hooks.base import BaseInterceptor
from crewai.llm.hooks.base import BaseInterceptor
@pytest.fixture(autouse=True)

View File

@@ -3,7 +3,7 @@
import httpx
import pytest
from crewai.llms.hooks.base import BaseInterceptor
from crewai.llm.hooks.base import BaseInterceptor
class SimpleInterceptor(BaseInterceptor[httpx.Request, httpx.Response]):

View File

@@ -4,7 +4,7 @@ import httpx
import pytest
from crewai.llm import LLM
from crewai.llms.hooks.base import BaseInterceptor
from crewai.llm.hooks.base import BaseInterceptor
class OpenAITestInterceptor(BaseInterceptor[httpx.Request, httpx.Response]):

View File

@@ -5,8 +5,8 @@ from unittest.mock import Mock
import httpx
import pytest
from crewai.llms.hooks.base import BaseInterceptor
from crewai.llms.hooks.transport import AsyncHTTPTransport, HTTPTransport
from crewai.llm.hooks.base import BaseInterceptor
from crewai.llm.hooks.transport import AsyncHTTPTransport, HTTPTransport
class TrackingInterceptor(BaseInterceptor[httpx.Request, httpx.Response]):

View File

@@ -6,7 +6,7 @@ import httpx
import pytest
from crewai.llm import LLM
from crewai.llms.hooks.base import BaseInterceptor
from crewai.llm.hooks.base import BaseInterceptor
@pytest.fixture(autouse=True)

View File

@@ -6,7 +6,7 @@ import openai
import pytest
from crewai.llm import LLM
from crewai.llms.providers.openai.completion import OpenAICompletion
from crewai.llm.providers.openai.completion import OpenAICompletion
from crewai.crew import Crew
from crewai.agent import Agent
from crewai.task import Task
@@ -29,7 +29,7 @@ def test_openai_completion_is_used_when_no_provider_prefix():
"""
llm = LLM(model="gpt-4o")
from crewai.llms.providers.openai.completion import OpenAICompletion
from crewai.llm.providers.openai.completion import OpenAICompletion
assert isinstance(llm, OpenAICompletion)
assert llm.provider == "openai"
assert llm.model == "gpt-4o"
@@ -63,7 +63,7 @@ def test_openai_completion_module_is_imported():
"""
Test that the completion module is properly imported when using OpenAI provider
"""
module_name = "crewai.llms.providers.openai.completion"
module_name = "crewai.llm.providers.openai.completion"
# Remove module from cache if it exists
if module_name in sys.modules:
@@ -114,7 +114,7 @@ def test_openai_completion_initialization_parameters():
api_key="test-key"
)
from crewai.llms.providers.openai.completion import OpenAICompletion
from crewai.llm.providers.openai.completion import OpenAICompletion
assert isinstance(llm, OpenAICompletion)
assert llm.model == "gpt-4o"
assert llm.temperature == 0.7
@@ -335,7 +335,7 @@ def test_openai_completion_call_returns_usage_metrics():
def test_openai_raises_error_when_model_not_supported():
"""Test that OpenAICompletion raises ValueError when model not supported"""
with patch('crewai.llms.providers.openai.completion.OpenAI') as mock_openai_class:
with patch('crewai.llm.providers.openai.completion.OpenAI') as mock_openai_class:
mock_client = MagicMock()
mock_openai_class.return_value = mock_client