import logging
import os
from time import sleep
from unittest.mock import MagicMock, patch

from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.events.event_types import (
    LLMCallCompletedEvent,
    LLMStreamChunkEvent,
    ToolUsageErrorEvent,
    ToolUsageFinishedEvent,
    ToolUsageStartedEvent,
)
from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM
from crewai.utilities.token_counter_callback import TokenCalcHandler
from pydantic import BaseModel
import pytest


# TODO: This test fails without a print statement, which suggests something is happening asynchronously; we should dig into and fix this at a later date.
@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_callback_replacement():
    llm1 = LLM(model="gpt-4o-mini", is_litellm=True)
    llm2 = LLM(model="gpt-4o-mini", is_litellm=True)

    calc_handler_1 = TokenCalcHandler(token_cost_process=TokenProcess())
    calc_handler_2 = TokenCalcHandler(token_cost_process=TokenProcess())

    llm1.call(
        messages=[{"role": "user", "content": "Hello, world!"}],
        callbacks=[calc_handler_1],
    )
    usage_metrics_1 = calc_handler_1.token_cost_process.get_summary()

    llm2.call(
        messages=[{"role": "user", "content": "Hello, world from another agent!"}],
        callbacks=[calc_handler_2],
    )
    sleep(5)
    usage_metrics_2 = calc_handler_2.token_cost_process.get_summary()

    # The first handler should not have been updated
    assert usage_metrics_1.successful_requests == 1
    assert usage_metrics_2.successful_requests == 1
    assert usage_metrics_1 == calc_handler_1.token_cost_process.get_summary()


@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_call_with_string_input():
    llm = LLM(model="gpt-4o-mini")

    # Test the call method with a string input
    result = llm.call("Return the name of a random city in the world.")
    assert isinstance(result, str)
    assert len(result.strip()) > 0  # Ensure the response is not empty


@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_call_with_string_input_and_callbacks():
    llm = LLM(model="gpt-4o-mini", is_litellm=True)
    calc_handler = TokenCalcHandler(token_cost_process=TokenProcess())

    # Test the call method with a string input and callbacks
    result = llm.call(
        "Tell me a joke.",
        callbacks=[calc_handler],
    )
    usage_metrics = calc_handler.token_cost_process.get_summary()

    assert isinstance(result, str)
    assert len(result.strip()) > 0
    assert usage_metrics.successful_requests == 1


@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_call_with_message_list():
    llm = LLM(model="gpt-4o-mini")
    messages = [{"role": "user", "content": "What is the capital of France?"}]

    # Test the call method with a list of messages
    result = llm.call(messages)
    assert isinstance(result, str)
    assert "Paris" in result


@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_call_with_tool_and_string_input():
    llm = LLM(model="gpt-4o-mini")

    def get_current_year() -> str:
        """Returns the current year as a string."""
        from datetime import datetime

        return str(datetime.now().year)

    # Create tool schema
    tool_schema = {
        "type": "function",
        "function": {
            "name": "get_current_year",
            "description": "Returns the current year as a string.",
            "parameters": {
                "type": "object",
                "properties": {},
                "required": [],
            },
        },
    }

    # Available functions mapping
    available_functions = {"get_current_year": get_current_year}

    # Test the call method with a string input and tool
    result = llm.call(
        "What is the current year?",
        tools=[tool_schema],
        available_functions=available_functions,
    )

    assert isinstance(result, str)
    assert result == get_current_year()


@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_call_with_tool_and_message_list():
    llm = LLM(model="gpt-4o-mini", is_litellm=True)

    def square_number(number: int) -> int:
        """Returns the square of a number."""
        return number * number

    # Create tool schema
    tool_schema = {
        "type": "function",
        "function": {
            "name": "square_number",
            "description": "Returns the square of a number.",
            "parameters": {
                "type": "object",
                "properties": {
                    "number": {"type": "integer", "description": "The number to square"}
                },
                "required": ["number"],
            },
        },
    }

    # Available functions mapping
    available_functions = {"square_number": square_number}

    messages = [{"role": "user", "content": "What is the square of 5?"}]

    # Test the call method with messages and tool
    result = llm.call(
        messages,
        tools=[tool_schema],
        available_functions=available_functions,
    )

    assert isinstance(result, int)
    assert result == 25


@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_passes_additional_params():
    llm = LLM(
        model="gpt-4o-mini",
        vertex_credentials="test_credentials",
        vertex_project="test_project",
        is_litellm=True,
    )

    messages = [{"role": "user", "content": "Hello, world!"}]

    with patch("litellm.completion") as mocked_completion:
        # Create mocks for response structure
        mock_message = MagicMock()
        mock_message.content = "Test response"
        mock_choice = MagicMock()
        mock_choice.message = mock_message
        mock_response = MagicMock()
        mock_response.choices = [mock_choice]
        mock_response.usage = {
            "prompt_tokens": 5,
            "completion_tokens": 5,
            "total_tokens": 10,
        }

        # Set up the mocked completion to return the mock response
        mocked_completion.return_value = mock_response

        result = llm.call(messages)

        # Assert that litellm.completion was called once
        mocked_completion.assert_called_once()

        # Retrieve the actual arguments with which litellm.completion was called
        _, kwargs = mocked_completion.call_args

        # Check that the additional_params were passed to litellm.completion
        assert kwargs["vertex_credentials"] == "test_credentials"
        assert kwargs["vertex_project"] == "test_project"

        # Also verify that other expected parameters are present
        assert kwargs["model"] == "gpt-4o-mini"
        assert kwargs["messages"] == messages

        # Check the result from llm.call
        assert result == "Test response"


def test_get_custom_llm_provider_openrouter():
    llm = LLM(model="openrouter/deepseek/deepseek-chat")
    assert llm._get_custom_llm_provider() == "openrouter"


def test_get_custom_llm_provider_gemini():
    llm = LLM(model="gemini/gemini-1.5-pro", is_litellm=True)
    assert llm._get_custom_llm_provider() == "gemini"


def test_get_custom_llm_provider_openai():
    llm = LLM(model="gpt-4", is_litellm=True)
    assert llm._get_custom_llm_provider() is None


def test_validate_call_params_supported():
    class DummyResponse(BaseModel):
        a: int

    # Patch supports_response_schema to simulate a supported model.
    with patch("crewai.llm.supports_response_schema", return_value=True):
        llm = LLM(
            model="openrouter/deepseek/deepseek-chat", response_format=DummyResponse
        )
        # Should not raise any error.
        llm._validate_call_params()


def test_validate_call_params_not_supported():
    class DummyResponse(BaseModel):
        a: int

    # Patch supports_response_schema to simulate an unsupported model.
    with patch("crewai.llm.supports_response_schema", return_value=False):
        llm = LLM(model="gemini/gemini-1.5-pro", response_format=DummyResponse, is_litellm=True)
        with pytest.raises(ValueError) as excinfo:
            llm._validate_call_params()
        assert "does not support response_format" in str(excinfo.value)


def test_validate_call_params_no_response_format():
    # When no response_format is provided, no validation error should occur.
    llm = LLM(model="gemini/gemini-1.5-pro", response_format=None, is_litellm=True)
    llm._validate_call_params()


def test_response_format_pydantic_model_conversion():
    """Test that response_format with Pydantic model is converted to json_schema format."""
    class TestResponse(BaseModel):
        answer: str
        confidence: float

    llm = LLM(model="gpt-4o-mini", response_format=TestResponse, is_litellm=True)

    with patch("litellm.completion") as mocked_completion:
        mock_message = MagicMock()
        mock_message.content = '{"answer": "Paris", "confidence": 0.95}'
        mock_message.tool_calls = []
        mock_choice = MagicMock()
        mock_choice.message = mock_message
        mock_response = MagicMock()
        mock_response.choices = [mock_choice]
        mock_response.usage = {
            "prompt_tokens": 10,
            "completion_tokens": 10,
            "total_tokens": 20,
        }

        mocked_completion.return_value = mock_response

        result = llm.call("What is the capital of France?")

        mocked_completion.assert_called_once()
        _, kwargs = mocked_completion.call_args

        assert "response_format" in kwargs
        assert isinstance(kwargs["response_format"], dict)
        assert kwargs["response_format"]["type"] == "json_schema"
        assert "json_schema" in kwargs["response_format"]
        assert kwargs["response_format"]["json_schema"]["name"] == "TestResponse"
        assert "schema" in kwargs["response_format"]["json_schema"]

        import json
        result_dict = json.loads(result)
        assert result_dict["answer"] == "Paris"
        assert result_dict["confidence"] == 0.95


def test_response_format_dict_passthrough():
    """Test that response_format with dict is passed through unchanged."""
    response_format_dict = {"type": "json_object"}

    llm = LLM(model="gpt-4o-mini", response_format=response_format_dict, is_litellm=True)

    with patch("litellm.completion") as mocked_completion:
        mock_message = MagicMock()
        mock_message.content = '{"result": "test"}'
        mock_message.tool_calls = []
        mock_choice = MagicMock()
        mock_choice.message = mock_message
        mock_response = MagicMock()
        mock_response.choices = [mock_choice]
        mock_response.usage = {
            "prompt_tokens": 5,
            "completion_tokens": 5,
            "total_tokens": 10,
        }

        mocked_completion.return_value = mock_response

        llm.call("Test message")

        mocked_completion.assert_called_once()
        _, kwargs = mocked_completion.call_args

        assert kwargs["response_format"] == response_format_dict


def test_response_model_overrides_response_format():
    """Test that response_model passed to call() overrides response_format from init."""
    class InitResponse(BaseModel):
        init_field: str

    class CallResponse(BaseModel):
        call_field: str

    llm = LLM(model="gpt-4o-mini", response_format=InitResponse, is_litellm=True)

    with patch("litellm.completion") as mocked_completion:
        mock_message = MagicMock()
        mock_message.content = '{"init_field": "value"}'
        mock_message.tool_calls = []
        mock_choice = MagicMock()
        mock_choice.message = mock_message
        mock_response = MagicMock()
        mock_response.choices = [mock_choice]
        mock_response.usage = {
            "prompt_tokens": 5,
            "completion_tokens": 5,
            "total_tokens": 10,
        }

        mocked_completion.return_value = mock_response

        result = llm.call("Test message")

        mocked_completion.assert_called_once()
        _, kwargs = mocked_completion.call_args

        assert "response_format" in kwargs
        assert kwargs["response_format"]["type"] == "json_schema"
        assert kwargs["response_format"]["json_schema"]["name"] == "InitResponse"


@pytest.mark.vcr(filter_headers=["authorization"], filter_query_parameters=["key"])
@pytest.mark.parametrize(
    "model",
    [
        "gemini/gemini-2.0-flash-thinking-exp-01-21",
        "gemini/gemini-2.0-flash-001",
        "gemini/gemini-2.0-flash-lite-001",
        "gemini/gemini-2.5-flash-preview-04-17",
        "gemini/gemini-2.5-pro-exp-03-25",
    ],
)
def test_gemini_models(model):
    # Use LiteLLM for VCR compatibility (VCR can intercept HTTP calls but not native SDK calls)
    llm = LLM(model=model, is_litellm=True)
    result = llm.call("What is the capital of France?")
    assert isinstance(result, str)
    assert "Paris" in result


@pytest.mark.vcr(filter_headers=["authorization"], filter_query_parameters=["key"])
@pytest.mark.parametrize(
    "model",
    [
        "gemini/gemma-3-27b-it",
    ],
)
def test_gemma3(model):
    # Use LiteLLM for VCR compatibility (VCR can intercept HTTP calls but not native SDK calls)
    llm = LLM(model=model, is_litellm=True)
    result = llm.call("What is the capital of France?")
    assert isinstance(result, str)
    assert "Paris" in result


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.parametrize(
    "model", ["gpt-4.1", "gpt-4.1-mini-2025-04-14", "gpt-4.1-nano-2025-04-14"]
)
def test_gpt_4_1(model):
    llm = LLM(model=model)
    result = llm.call("What is the capital of France?")
    assert isinstance(result, str)
    assert "Paris" in result


@pytest.mark.vcr(filter_headers=["authorization"])
def test_o3_mini_reasoning_effort_high():
    llm = LLM(
        model="o3-mini",
        reasoning_effort="high",
    )
    result = llm.call("What is the capital of France?")
    assert isinstance(result, str)
    assert "Paris" in result


@pytest.mark.vcr(filter_headers=["authorization"])
def test_o3_mini_reasoning_effort_low():
    llm = LLM(
        model="o3-mini",
        reasoning_effort="low",
    )
    result = llm.call("What is the capital of France?")
    assert isinstance(result, str)
    assert "Paris" in result


@pytest.mark.vcr(filter_headers=["authorization"])
def test_o3_mini_reasoning_effort_medium():
    llm = LLM(
        model="o3-mini",
        reasoning_effort="medium",
    )
    result = llm.call("What is the capital of France?")
    assert isinstance(result, str)
    assert "Paris" in result


def test_context_window_validation():
    """Test that context window validation works correctly."""
    # Test valid window size
    llm = LLM(model="o3-mini")
    assert llm.get_context_window_size() == int(200000 * CONTEXT_WINDOW_USAGE_RATIO)

    # Test invalid window size
    with pytest.raises(ValueError) as excinfo:
        with patch.dict(
            "crewai.llm.LLM_CONTEXT_WINDOW_SIZES",
            {"test-model": 500},  # Below minimum
            clear=True,
        ):
            llm = LLM(model="test-model")
            llm.get_context_window_size()
    assert "must be between 1024 and 2097152" in str(excinfo.value)


@pytest.fixture
def get_weather_tool_schema():
    return {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    }
                },
                "required": ["location"],
            },
        },
    }


def test_context_window_exceeded_error_handling():
    """Test that litellm.ContextWindowExceededError is converted to LLMContextLengthExceededError."""
    from crewai.utilities.exceptions.context_window_exceeding_exception import (
        LLMContextLengthExceededError,
    )
    from litellm.exceptions import ContextWindowExceededError

    llm = LLM(model="gpt-4", is_litellm=True)

    # Test non-streaming response
    with patch("litellm.completion") as mock_completion:
        mock_completion.side_effect = ContextWindowExceededError(
            "This model's maximum context length is 8192 tokens. However, your messages resulted in 10000 tokens.",
            model="gpt-4",
            llm_provider="openai",
        )

        with pytest.raises(LLMContextLengthExceededError) as excinfo:
            llm.call("This is a test message")

        assert "context length exceeded" in str(excinfo.value).lower()
        assert "8192 tokens" in str(excinfo.value)

    # Test streaming response
    llm = LLM(model="gpt-4", stream=True, is_litellm=True)
    with patch("litellm.completion") as mock_completion:
        mock_completion.side_effect = ContextWindowExceededError(
            "This model's maximum context length is 8192 tokens. However, your messages resulted in 10000 tokens.",
            model="gpt-4",
            llm_provider="openai",
        )

        with pytest.raises(LLMContextLengthExceededError) as excinfo:
            llm.call("This is a test message")

        assert "context length exceeded" in str(excinfo.value).lower()
        assert "8192 tokens" in str(excinfo.value)


@pytest.fixture
def anthropic_llm():
    """Fixture providing an Anthropic LLM instance."""
    return LLM(model="anthropic/claude-3-sonnet", is_litellm=True)


@pytest.fixture
def system_message():
    """Fixture providing a system message."""
    return {"role": "system", "content": "test"}


@pytest.fixture
def user_message():
    """Fixture providing a user message."""
    return {"role": "user", "content": "test"}


def test_anthropic_message_formatting_edge_cases(anthropic_llm):
    """Test edge cases for Anthropic message formatting."""
    # Test None messages
    with pytest.raises(TypeError, match="Messages cannot be None"):
        anthropic_llm._format_messages_for_provider(None)

    # Test empty message list
    formatted = anthropic_llm._format_messages_for_provider([])
    assert len(formatted) == 1
    assert formatted[0]["role"] == "user"
    assert formatted[0]["content"] == "."

    # Test invalid message format
    with pytest.raises(TypeError, match="Invalid message format"):
        anthropic_llm._format_messages_for_provider([{"invalid": "message"}])


def test_anthropic_model_detection():
    """Test Anthropic model detection with various formats."""
    models = [
        ("anthropic/claude-3", True),
        ("claude-instant", True),
        ("claude/v1", True),
        ("gpt-4", False),
        ("anthropomorphic", False),  # Should not match partial words
    ]

    for model, expected in models:
        llm = LLM(model=model, is_litellm=True)
        assert llm.is_anthropic == expected, f"Failed for model: {model}"


def test_anthropic_message_formatting(anthropic_llm, system_message, user_message):
    """Test Anthropic message formatting with fixtures."""
    # An empty message list should be replaced with a placeholder user message
    formatted = anthropic_llm._format_messages_for_provider([])
    assert len(formatted) == 1
    assert formatted[0]["role"] == "user"
    assert formatted[0]["content"] == "."

    # Invalid message dicts should be rejected
    with pytest.raises(TypeError, match="Invalid message format"):
        anthropic_llm._format_messages_for_provider([{"invalid": "message"}])


def test_deepseek_r1_with_open_router():
    if not os.getenv("OPEN_ROUTER_API_KEY"):
        pytest.skip("OPEN_ROUTER_API_KEY not set; skipping test.")

    llm = LLM(
        model="openrouter/deepseek/deepseek-r1",
        base_url="https://openrouter.ai/api/v1",
        api_key=os.getenv("OPEN_ROUTER_API_KEY"),
        is_litellm=True,
    )
    result = llm.call("What is the capital of France?")
    assert isinstance(result, str)
    assert "Paris" in result


def assert_event_count(
    mock_emit,
    expected_completed_tool_call: int = 0,
    expected_stream_chunk: int = 0,
    expected_completed_llm_call: int = 0,
    expected_tool_usage_started: int = 0,
    expected_tool_usage_finished: int = 0,
    expected_tool_usage_error: int = 0,
    expected_final_chunk_result: str = "",
):
    event_count = {
        "completed_tool_call": 0,
        "stream_chunk": 0,
        "completed_llm_call": 0,
        "tool_usage_started": 0,
        "tool_usage_finished": 0,
        "tool_usage_error": 0,
    }
    final_chunk_result = ""
    for _call in mock_emit.call_args_list:
        event = _call[1]["event"]

        if (
            isinstance(event, LLMCallCompletedEvent)
            and event.call_type.value == "tool_call"
        ):
            event_count["completed_tool_call"] += 1
        elif isinstance(event, LLMStreamChunkEvent):
            event_count["stream_chunk"] += 1
            final_chunk_result += event.chunk
        elif (
            isinstance(event, LLMCallCompletedEvent)
            and event.call_type.value == "llm_call"
        ):
            event_count["completed_llm_call"] += 1
        elif isinstance(event, ToolUsageStartedEvent):
            event_count["tool_usage_started"] += 1
        elif isinstance(event, ToolUsageFinishedEvent):
            event_count["tool_usage_finished"] += 1
        elif isinstance(event, ToolUsageErrorEvent):
            event_count["tool_usage_error"] += 1
        else:
            continue

    assert event_count["completed_tool_call"] == expected_completed_tool_call
    assert event_count["stream_chunk"] == expected_stream_chunk
    assert event_count["completed_llm_call"] == expected_completed_llm_call
    assert event_count["tool_usage_started"] == expected_tool_usage_started
    assert event_count["tool_usage_finished"] == expected_tool_usage_finished
    assert event_count["tool_usage_error"] == expected_tool_usage_error
    assert final_chunk_result == expected_final_chunk_result


@pytest.fixture
def mock_emit() -> MagicMock:
    from crewai.events.event_bus import CrewAIEventsBus

    with patch.object(CrewAIEventsBus, "emit") as mock_emit:
        yield mock_emit


@pytest.mark.vcr(filter_headers=["authorization"])
def test_handle_streaming_tool_calls(get_weather_tool_schema, mock_emit):
    llm = LLM(model="openai/gpt-4o", stream=True, is_litellm=True)
    response = llm.call(
        messages=[
            {"role": "user", "content": "What is the weather in New York?"},
        ],
        tools=[get_weather_tool_schema],
        available_functions={
            "get_weather": lambda location: f"The weather in {location} is sunny"
        },
    )
    assert response == "The weather in New York, NY is sunny"

    expected_final_chunk_result = (
        '{"location":"New York, NY"}The weather in New York, NY is sunny'
    )
    assert_event_count(
        mock_emit=mock_emit,
        expected_completed_tool_call=1,
        expected_stream_chunk=10,
        expected_completed_llm_call=1,
        expected_tool_usage_started=1,
        expected_tool_usage_finished=1,
        expected_final_chunk_result=expected_final_chunk_result,
    )


@pytest.mark.vcr(filter_headers=["authorization"])
def test_handle_streaming_tool_calls_with_error(get_weather_tool_schema, mock_emit):
    def get_weather_error(location):
        raise Exception("Error")

    llm = LLM(model="openai/gpt-4o", stream=True, is_litellm=True)
    response = llm.call(
        messages=[
            {"role": "user", "content": "What is the weather in New York?"},
        ],
        tools=[get_weather_tool_schema],
        available_functions={"get_weather": get_weather_error},
    )
    assert response == ""
    expected_final_chunk_result = '{"location":"New York, NY"}'
    assert_event_count(
        mock_emit=mock_emit,
        expected_stream_chunk=9,
        expected_completed_llm_call=1,
        expected_tool_usage_started=1,
        expected_tool_usage_error=1,
        expected_final_chunk_result=expected_final_chunk_result,
    )


@pytest.mark.vcr(filter_headers=["authorization"])
def test_handle_streaming_tool_calls_no_available_functions(
    get_weather_tool_schema, mock_emit
):
    llm = LLM(model="openai/gpt-4o", stream=True, is_litellm=True)
    response = llm.call(
        messages=[
            {"role": "user", "content": "What is the weather in New York?"},
        ],
        tools=[get_weather_tool_schema],
    )
    assert response == ""

    assert_event_count(
        mock_emit=mock_emit,
        expected_stream_chunk=9,
        expected_completed_llm_call=1,
        expected_final_chunk_result='{"location":"New York, NY"}',
    )


@pytest.mark.vcr(filter_headers=["authorization"])
def test_handle_streaming_tool_calls_no_tools(mock_emit):
    llm = LLM(model="openai/gpt-4o", stream=True, is_litellm=True)
    response = llm.call(
        messages=[
            {"role": "user", "content": "What is the weather in New York?"},
        ],
    )
    assert (
        response
        == "I'm unable to provide real-time information or current weather updates. For the latest weather information in New York, I recommend checking a reliable weather website or app, such as the National Weather Service, Weather.com, or a similar service."
    )

    assert_event_count(
        mock_emit=mock_emit,
        expected_stream_chunk=46,
        expected_completed_llm_call=1,
        expected_final_chunk_result=response,
    )


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.skip(reason="Highly flaky on ci")
def test_llm_call_when_stop_is_unsupported(caplog):
    llm = LLM(model="o1-mini", stop=["stop"], is_litellm=True)
    with caplog.at_level(logging.INFO):
        result = llm.call("What is the capital of France?")
        assert "Retrying LLM call without the unsupported 'stop'" in caplog.text
        assert isinstance(result, str)
        assert "Paris" in result


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.skip(reason="Highly flaky on ci")
def test_llm_call_when_stop_is_unsupported_when_additional_drop_params_is_provided(
    caplog,
):
    llm = LLM(
        model="o1-mini",
        stop=["stop"],
        additional_drop_params=["another_param"],
    )
    with caplog.at_level(logging.INFO):
        result = llm.call("What is the capital of France?")
        assert "Retrying LLM call without the unsupported 'stop'" in caplog.text
        assert isinstance(result, str)
        assert "Paris" in result


@pytest.fixture
def ollama_llm():
    return LLM(model="ollama/llama3.2:3b", is_litellm=True)


def test_ollama_appends_dummy_user_message_when_last_is_assistant(ollama_llm):
    original_messages = [
        {"role": "user", "content": "Hi there"},
        {"role": "assistant", "content": "Hello!"},
    ]

    formatted = ollama_llm._format_messages_for_provider(original_messages)

    assert len(formatted) == len(original_messages) + 1
    assert formatted[-1]["role"] == "user"
    assert formatted[-1]["content"] == ""


def test_ollama_does_not_modify_when_last_is_user(ollama_llm):
    original_messages = [
        {"role": "user", "content": "Tell me a joke."},
    ]

    formatted = ollama_llm._format_messages_for_provider(original_messages)

    assert formatted == original_messages


def test_native_provider_raises_error_when_supported_but_fails():
    """Test that when a native provider is in SUPPORTED_NATIVE_PROVIDERS but fails to instantiate, we raise the error."""
    with patch("crewai.llm.SUPPORTED_NATIVE_PROVIDERS", ["openai"]):
        with patch("crewai.llm.LLM._get_native_provider") as mock_get_native:
            # Mock that provider exists but throws an error when instantiated
            mock_provider = MagicMock()
            mock_provider.side_effect = ValueError("Native provider initialization failed")
            mock_get_native.return_value = mock_provider

            with pytest.raises(ImportError) as excinfo:
                LLM(model="gpt-4", is_litellm=False)

            assert "Error importing native provider" in str(excinfo.value)
            assert "Native provider initialization failed" in str(excinfo.value)


def test_native_provider_falls_back_to_litellm_when_not_in_supported_list():
    """Test that when a provider is not in SUPPORTED_NATIVE_PROVIDERS, we fall back to LiteLLM."""
    with patch("crewai.llm.SUPPORTED_NATIVE_PROVIDERS", ["openai", "anthropic"]):
        # Using a provider not in the supported list
        llm = LLM(model="groq/llama-3.1-70b-versatile", is_litellm=False)

        # Should fall back to LiteLLM
        assert llm.is_litellm is True
        assert llm.model == "groq/llama-3.1-70b-versatile"


def test_prefixed_models_with_valid_constants_use_native_sdk():
    """Test that models with native provider prefixes use native SDK when model is in constants."""
    # Test openai/ prefix with actual OpenAI model in constants → Native SDK
    with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
        llm = LLM(model="openai/gpt-4o", is_litellm=False)
        assert llm.is_litellm is False
        assert llm.provider == "openai"

    # Test anthropic/ prefix with Claude model in constants → Native SDK
    with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}):
        llm2 = LLM(model="anthropic/claude-opus-4-0", is_litellm=False)
        assert llm2.is_litellm is False
        assert llm2.provider == "anthropic"

    # Test gemini/ prefix with Gemini model in constants → Native SDK
    with patch.dict(os.environ, {"GOOGLE_API_KEY": "test-key"}):
        llm3 = LLM(model="gemini/gemini-2.5-pro", is_litellm=False)
        assert llm3.is_litellm is False
        assert llm3.provider == "gemini"


def test_prefixed_models_with_invalid_constants_use_litellm():
    """Test that models with native provider prefixes use LiteLLM when model is NOT in constants."""
    # Test openai/ prefix with non-OpenAI model (not in OPENAI_MODELS) → LiteLLM
    llm = LLM(model="openai/gemini-2.5-flash", is_litellm=False)
    assert llm.is_litellm is True
    assert llm.model == "openai/gemini-2.5-flash"

    # Test openai/ prefix with unknown future model → LiteLLM
    llm2 = LLM(model="openai/gpt-future-6", is_litellm=False)
    assert llm2.is_litellm is True
    assert llm2.model == "openai/gpt-future-6"

    # Test anthropic/ prefix with non-Anthropic model → LiteLLM
    llm3 = LLM(model="anthropic/gpt-4o", is_litellm=False)
    assert llm3.is_litellm is True
    assert llm3.model == "anthropic/gpt-4o"


def test_prefixed_models_with_non_native_providers_use_litellm():
    """Test that models with non-native provider prefixes always use LiteLLM."""
    # Test groq/ prefix (not a native provider) → LiteLLM
    llm = LLM(model="groq/llama-3.3-70b", is_litellm=False)
    assert llm.is_litellm is True
    assert llm.model == "groq/llama-3.3-70b"

    # Test together/ prefix (not a native provider) → LiteLLM
    llm2 = LLM(model="together/qwen-2.5-72b", is_litellm=False)
    assert llm2.is_litellm is True
    assert llm2.model == "together/qwen-2.5-72b"


def test_unprefixed_models_use_native_sdk():
    """Test that unprefixed models use native SDK when model is in constants."""
    # gpt-4o is in OPENAI_MODELS → Native OpenAI SDK
    with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
        llm = LLM(model="gpt-4o", is_litellm=False)
        assert llm.is_litellm is False
        assert llm.provider == "openai"

    # claude-opus-4-0 is in ANTHROPIC_MODELS → Native Anthropic SDK
    with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}):
        llm2 = LLM(model="claude-opus-4-0", is_litellm=False)
        assert llm2.is_litellm is False
        assert llm2.provider == "anthropic"

    # gemini-2.5-pro is in GEMINI_MODELS → Native Gemini SDK
    with patch.dict(os.environ, {"GOOGLE_API_KEY": "test-key"}):
        llm3 = LLM(model="gemini-2.5-pro", is_litellm=False)
        assert llm3.is_litellm is False
        assert llm3.provider == "gemini"


def test_explicit_provider_kwarg_takes_priority():
    """Test that explicit provider kwarg takes priority over model name inference."""
    # Explicit provider=openai should use OpenAI even if model name suggests otherwise
    with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
        llm = LLM(model="gpt-4o", provider="openai", is_litellm=False)
        assert llm.is_litellm is False
        assert llm.provider == "openai"

    # Explicit provider for a model with "/" should still use that provider
    with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
        llm2 = LLM(model="gpt-4o", provider="openai", is_litellm=False)
        assert llm2.is_litellm is False
        assert llm2.provider == "openai"


def test_validate_model_in_constants():
    """Test the _validate_model_in_constants method."""
    # OpenAI models
    assert LLM._validate_model_in_constants("gpt-4o", "openai") is True
    assert LLM._validate_model_in_constants("gpt-future-6", "openai") is False

    # Anthropic models
    assert LLM._validate_model_in_constants("claude-opus-4-0", "claude") is True
    assert LLM._validate_model_in_constants("claude-future-5", "claude") is False

    # Gemini models
    assert LLM._validate_model_in_constants("gemini-2.5-pro", "gemini") is True
    assert LLM._validate_model_in_constants("gemini-future", "gemini") is False

    # Azure models
    assert LLM._validate_model_in_constants("gpt-4o", "azure") is True
    assert LLM._validate_model_in_constants("gpt-35-turbo", "azure") is True

    # Bedrock models
    assert LLM._validate_model_in_constants("anthropic.claude-opus-4-1-20250805-v1:0", "bedrock") is True