Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-05-03 08:12:39 +00:00
Lorenze/native inference sdks (#3619)
* ruff linted
* using native sdks with litellm fallback
* drop exa
* drop print on completion
* Refactor LLM and utility functions for type consistency
  - Updated `max_tokens` parameter in `LLM` class to accept `float` in addition to `int`.
  - Modified `create_llm` function to ensure consistent type hints and return types, now returning `LLM | BaseLLM | None`.
  - Adjusted type hints for various parameters in `create_llm` and `_llm_via_environment_or_fallback` functions for improved clarity and type safety.
  - Enhanced test cases to reflect changes in type handling and ensure proper instantiation of LLM instances.
* fix agent_tests
* fix litellm tests and usagemetrics fix
* drop print
* Refactor LLM event handling and improve test coverage
  - Removed commented-out event emission for LLM call failures in `llm.py`.
  - Added `from_agent` parameter to `CrewAgentExecutor` for better context in LLM responses.
  - Enhanced test for LLM call failure to simulate OpenAI API failure and updated assertions for clarity.
  - Updated agent and task ID assertions in tests to ensure they are consistently treated as strings.
* fix test_converter
* fixed tests/agents/test_agent.py
* Refactor LLM context length exception handling and improve provider integration
  - Renamed `LLMContextLengthExceededException` to `LLMContextLengthExceededExceptionError` for clarity and consistency.
  - Updated LLM class to pass the provider parameter correctly during initialization.
  - Enhanced error handling in various LLM provider implementations to raise the new exception type.
  - Adjusted tests to reflect the updated exception name and ensure proper error handling in context length scenarios.
* Enhance LLM context window handling across providers
  - Introduced CONTEXT_WINDOW_USAGE_RATIO to adjust context window sizes dynamically for Anthropic, Azure, Gemini, and OpenAI LLMs.
  - Added validation for context window sizes in Azure and Gemini providers to ensure they fall within acceptable limits.
  - Updated context window size calculations to use the new ratio, improving consistency and adaptability across different models.
  - Removed hardcoded context window sizes in favor of ratio-based calculations for better flexibility.
* fix test agent again
* fix test agent
* feat: add native LLM providers for Anthropic, Azure, and Gemini
  - Introduced new completion implementations for Anthropic, Azure, and Gemini, integrating their respective SDKs.
  - Added utility functions for tool validation and extraction to support function calling across LLM providers.
  - Enhanced context window management and token usage extraction for each provider.
  - Created a common utility module for shared functionality among LLM providers.
* chore: update dependencies and improve context management
  - Removed direct dependency on `litellm` from the main dependencies and added it under extras for better modularity.
  - Updated the `litellm` dependency specification to allow for greater flexibility in versioning.
  - Refactored context length exception handling across various LLM providers to use a consistent error class.
  - Enhanced platform-specific dependency markers for NVIDIA packages to ensure compatibility across different systems.
* refactor(tests): update LLM instantiation to include is_litellm flag in test cases
  - Modified multiple test cases in test_llm.py to set the is_litellm parameter to True when instantiating the LLM class.
  - This change ensures that the tests are aligned with the latest LLM configuration requirements and improves consistency across test scenarios.
  - Adjusted relevant assertions and comments to reflect the updated LLM behavior.
* linter
* linted
* revert constants
* fix(tests): correct type hint in expected model description
  - Updated the expected description in the test_generate_model_description_dict_field function to use 'Dict' instead of 'dict' for consistency with type hinting conventions.
  - This change ensures that the test accurately reflects the expected output format for model descriptions.
* refactor(llm): enhance LLM instantiation and error handling
  - Updated the LLM class to include validation for the model parameter, ensuring it is a non-empty string.
  - Improved error handling by logging warnings when the native SDK fails, allowing for a fallback to LiteLLM.
  - Adjusted the instantiation of LLM in test cases to consistently include the is_litellm flag, aligning with recent changes in LLM configuration.
  - Modified relevant tests to reflect these updates, ensuring better coverage and accuracy in testing scenarios.
* fixed test
* refactor(llm): enhance token usage tracking and add copy methods
  - Updated the LLM class to track token usage and log callbacks in streaming mode, improving monitoring capabilities.
  - Introduced shallow and deep copy methods for the LLM instance, allowing for better management of LLM configurations and parameters.
  - Adjusted test cases to instantiate LLM with the is_litellm flag, ensuring alignment with recent changes in LLM configuration.
* refactor(tests): reorganize imports and enhance error messages in test cases
  - Cleaned up import statements in test_crew.py for better organization and readability.
  - Enhanced error messages in test cases to use `re.escape` for improved regex matching, ensuring more robust error handling.
  - Adjusted comments for clarity and consistency across test scenarios.
  - Ensured that all necessary modules are imported correctly to avoid potential runtime issues.
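Editor's note: the two mechanisms this squash message keeps returning to are the native-SDK-first call path with a LiteLLM fallback, and the ratio-based context-window sizing. A minimal sketch of both follows, assuming placeholder hooks (`_native_sdk_completion`, `_litellm_completion`) and an illustrative ratio value; this is not the actual crewAI implementation, whose real constant is exported as `crewai.llm.CONTEXT_WINDOW_USAGE_RATIO` (imported in the test diff below).

```python
import logging

logger = logging.getLogger(__name__)

# Assumed value for illustration only; see crewai.llm for the real constant.
CONTEXT_WINDOW_USAGE_RATIO = 0.75


def usable_context_window(model_max_tokens: int) -> int:
    # Ratio-based sizing: leave headroom under the model's hard limit
    # instead of hardcoding a per-model window.
    return int(model_max_tokens * CONTEXT_WINDOW_USAGE_RATIO)


def _native_sdk_completion(model: str, messages: list[dict]) -> str:
    # Placeholder for a provider SDK call (Anthropic, Azure, Gemini, OpenAI).
    raise NotImplementedError("provider SDK call goes here")


def _litellm_completion(model: str, messages: list[dict]) -> str:
    # Placeholder for the LiteLLM completion path.
    return "litellm response"


def call_with_fallback(model: str, messages: list[dict], is_litellm: bool = False) -> str:
    """Route to the provider's native SDK unless the caller pins LiteLLM
    with is_litellm=True; on a native-SDK failure, warn and fall back."""
    if not is_litellm:
        try:
            return _native_sdk_completion(model, messages)
        except Exception as exc:  # deliberately broad: any native failure falls back
            logger.warning("Native SDK failed (%s); falling back to LiteLLM", exc)
    return _litellm_completion(model, messages)
```

The `is_litellm=True` flag threaded through the tests below pins the LiteLLM branch directly, so the existing LiteLLM-based test cassettes keep matching.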
@@ -3,7 +3,6 @@ import os
 from time import sleep
 from unittest.mock import MagicMock, patch

 import pytest
 from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
 from crewai.events.event_types import (
     LLMCallCompletedEvent,
@@ -15,33 +14,30 @@ from crewai.events.event_types import (
 from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM
 from crewai.utilities.token_counter_callback import TokenCalcHandler
 from pydantic import BaseModel
-import pytest


-# TODO: This test fails without print statement, which makes me think that something is happening asynchronously that we need to eventually fix and dive deeper into at a later date
 @pytest.mark.vcr(filter_headers=["authorization"])
 def test_llm_callback_replacement():
-    llm1 = LLM(model="gpt-4o-mini")
-    llm2 = LLM(model="gpt-4o-mini")
+    llm1 = LLM(model="gpt-4o-mini", is_litellm=True)
+    llm2 = LLM(model="gpt-4o-mini", is_litellm=True)

     calc_handler_1 = TokenCalcHandler(token_cost_process=TokenProcess())
     calc_handler_2 = TokenCalcHandler(token_cost_process=TokenProcess())

-    result1 = llm1.call(
+    llm1.call(
         messages=[{"role": "user", "content": "Hello, world!"}],
         callbacks=[calc_handler_1],
     )
-    print("result1:", result1)
     usage_metrics_1 = calc_handler_1.token_cost_process.get_summary()
     print("usage_metrics_1:", usage_metrics_1)

-    result2 = llm2.call(
+    llm2.call(
         messages=[{"role": "user", "content": "Hello, world from another agent!"}],
         callbacks=[calc_handler_2],
     )
+    sleep(5)
-    print("result2:", result2)
     usage_metrics_2 = calc_handler_2.token_cost_process.get_summary()
     print("usage_metrics_2:", usage_metrics_2)

     # The first handler should not have been updated
     assert usage_metrics_1.successful_requests == 1
@@ -61,7 +57,7 @@ def test_llm_call_with_string_input():

 @pytest.mark.vcr(filter_headers=["authorization"])
 def test_llm_call_with_string_input_and_callbacks():
-    llm = LLM(model="gpt-4o-mini")
+    llm = LLM(model="gpt-4o-mini", is_litellm=True)
     calc_handler = TokenCalcHandler(token_cost_process=TokenProcess())

     # Test the call method with a string input and callbacks
@@ -127,7 +123,7 @@ def test_llm_call_with_tool_and_string_input():

 @pytest.mark.vcr(filter_headers=["authorization"])
 def test_llm_call_with_tool_and_message_list():
-    llm = LLM(model="gpt-4o-mini")
+    llm = LLM(model="gpt-4o-mini", is_litellm=True)

     def square_number(number: int) -> int:
         """Returns the square of a number."""
@@ -171,6 +167,7 @@ def test_llm_passes_additional_params():
         model="gpt-4o-mini",
         vertex_credentials="test_credentials",
         vertex_project="test_project",
+        is_litellm=True,
     )

     messages = [{"role": "user", "content": "Hello, world!"}]
@@ -223,7 +220,7 @@ def test_get_custom_llm_provider_gemini():


 def test_get_custom_llm_provider_openai():
-    llm = LLM(model="gpt-4")
+    llm = LLM(model="gpt-4", is_litellm=True)
     assert llm._get_custom_llm_provider() is None


@@ -380,7 +377,7 @@ def test_context_window_exceeded_error_handling():
     )
     from litellm.exceptions import ContextWindowExceededError

-    llm = LLM(model="gpt-4")
+    llm = LLM(model="gpt-4", is_litellm=True)

     # Test non-streaming response
     with patch("litellm.completion") as mock_completion:
@@ -397,7 +394,7 @@ def test_context_window_exceeded_error_handling():
     assert "8192 tokens" in str(excinfo.value)

     # Test streaming response
-    llm = LLM(model="gpt-4", stream=True)
+    llm = LLM(model="gpt-4", stream=True, is_litellm=True)
     with patch("litellm.completion") as mock_completion:
         mock_completion.side_effect = ContextWindowExceededError(
             "This model's maximum context length is 8192 tokens. However, your messages resulted in 10000 tokens.",
@@ -416,7 +413,7 @@ def test_context_window_exceeded_error_handling():
 @pytest.fixture
 def anthropic_llm():
     """Fixture providing an Anthropic LLM instance."""
-    return LLM(model="anthropic/claude-3-sonnet")
+    return LLM(model="anthropic/claude-3-sonnet", is_litellm=True)


 @pytest.fixture
@@ -455,40 +452,25 @@ def test_anthropic_model_detection():
         ("claude-instant", True),
         ("claude/v1", True),
         ("gpt-4", False),
         ("", False),
         ("anthropomorphic", False),  # Should not match partial words
     ]

     for model, expected in models:
-        llm = LLM(model=model)
+        llm = LLM(model=model, is_litellm=True)
         assert llm.is_anthropic == expected, f"Failed for model: {model}"


 def test_anthropic_message_formatting(anthropic_llm, system_message, user_message):
     """Test Anthropic message formatting with fixtures."""
     # Test when first message is system
     formatted = anthropic_llm._format_messages_for_provider([system_message])
     assert len(formatted) == 2
     assert formatted[0]["role"] == "user"
     assert formatted[0]["content"] == "."
     assert formatted[1] == system_message
-
-    # Test when first message is already user
-    formatted = anthropic_llm._format_messages_for_provider([user_message])
-    assert len(formatted) == 1
-    assert formatted[0] == user_message
-
-    # Test with empty message list
-    formatted = anthropic_llm._format_messages_for_provider([])
-    assert len(formatted) == 1
-    assert formatted[0]["role"] == "user"
-    assert formatted[0]["content"] == "."
-
-    # Test with non-Anthropic model (should not modify messages)
-    non_anthropic_llm = LLM(model="gpt-4")
-    formatted = non_anthropic_llm._format_messages_for_provider([system_message])
-    assert len(formatted) == 1
-    assert formatted[0] == system_message
+    with pytest.raises(TypeError, match="Invalid message format"):
+        anthropic_llm._format_messages_for_provider([{"invalid": "message"}])


 def test_deepseek_r1_with_open_router():
@@ -499,6 +481,7 @@ def test_deepseek_r1_with_open_router():
         model="openrouter/deepseek/deepseek-r1",
         base_url="https://openrouter.ai/api/v1",
         api_key=os.getenv("OPEN_ROUTER_API_KEY"),
+        is_litellm=True,
     )
     result = llm.call("What is the capital of France?")
     assert isinstance(result, str)
@@ -568,7 +551,7 @@ def mock_emit() -> MagicMock:

 @pytest.mark.vcr(filter_headers=["authorization"])
 def test_handle_streaming_tool_calls(get_weather_tool_schema, mock_emit):
-    llm = LLM(model="openai/gpt-4o", stream=True)
+    llm = LLM(model="openai/gpt-4o", stream=True, is_litellm=True)
     response = llm.call(
         messages=[
             {"role": "user", "content": "What is the weather in New York?"},
@@ -599,7 +582,7 @@ def test_handle_streaming_tool_calls_with_error(get_weather_tool_schema, mock_em
     def get_weather_error(location):
         raise Exception("Error")

-    llm = LLM(model="openai/gpt-4o", stream=True)
+    llm = LLM(model="openai/gpt-4o", stream=True, is_litellm=True)
     response = llm.call(
         messages=[
             {"role": "user", "content": "What is the weather in New York?"},
@@ -623,7 +606,7 @@ def test_handle_streaming_tool_calls_with_error(get_weather_tool_schema, mock_em
 def test_handle_streaming_tool_calls_no_available_functions(
     get_weather_tool_schema, mock_emit
 ):
-    llm = LLM(model="openai/gpt-4o", stream=True)
+    llm = LLM(model="openai/gpt-4o", stream=True, is_litellm=True)
     response = llm.call(
         messages=[
             {"role": "user", "content": "What is the weather in New York?"},
@@ -642,7 +625,7 @@ def test_handle_streaming_tool_calls_no_available_functions(

 @pytest.mark.vcr(filter_headers=["authorization"])
 def test_handle_streaming_tool_calls_no_tools(mock_emit):
-    llm = LLM(model="openai/gpt-4o", stream=True)
+    llm = LLM(model="openai/gpt-4o", stream=True, is_litellm=True)
     response = llm.call(
         messages=[
             {"role": "user", "content": "What is the weather in New York?"},
@@ -663,7 +646,7 @@ def test_handle_streaming_tool_calls_no_tools(mock_emit):

 @pytest.mark.vcr(filter_headers=["authorization"])
 def test_llm_call_when_stop_is_unsupported(caplog):
-    llm = LLM(model="o1-mini", stop=["stop"])
+    llm = LLM(model="o1-mini", stop=["stop"], is_litellm=True)
     with caplog.at_level(logging.INFO):
         result = llm.call("What is the capital of France?")
         assert "Retrying LLM call without the unsupported 'stop'" in caplog.text
@@ -675,7 +658,12 @@ def test_llm_call_when_stop_is_unsupported(caplog):
 def test_llm_call_when_stop_is_unsupported_when_additional_drop_params_is_provided(
     caplog,
 ):
-    llm = LLM(model="o1-mini", stop=["stop"], additional_drop_params=["another_param"])
+    llm = LLM(
+        model="o1-mini",
+        stop=["stop"],
+        additional_drop_params=["another_param"],
+        is_litellm=True,
+    )
     with caplog.at_level(logging.INFO):
         result = llm.call("What is the capital of France?")
         assert "Retrying LLM call without the unsupported 'stop'" in caplog.text
@@ -685,7 +673,7 @@ def test_llm_call_when_stop_is_unsupported_when_additional_drop_params_is_provid

 @pytest.fixture
 def ollama_llm():
-    return LLM(model="ollama/llama3.2:3b")
+    return LLM(model="ollama/llama3.2:3b", is_litellm=True)


 def test_ollama_appends_dummy_user_message_when_last_is_assistant(ollama_llm):
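Editor's note: for context on the reshaped test_anthropic_message_formatting above, the surviving assertions pin down roughly the following behavior. This is a sketch inferred from those assertions, with a hypothetical standalone function name rather than the verbatim `_format_messages_for_provider` method on the LLM class:

```python
def format_messages_for_anthropic(messages: list[dict]) -> list[dict]:
    # Reject entries that lack the expected role/content shape.
    for message in messages:
        if "role" not in message or "content" not in message:
            raise TypeError("Invalid message format")
    # Anthropic expects the conversation to open with a user turn, so an
    # empty list or a leading system message gets a placeholder "." user
    # message prepended.
    if not messages or messages[0]["role"] != "user":
        return [{"role": "user", "content": "."}, *messages]
    return messages
```

Against the test's cases: a lone system message becomes a two-item list opening with the "." user turn, a lone user message passes through unchanged, and a malformed dict raises `TypeError("Invalid message format")`.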