Compare commits

...

2 Commits

Author: Devin AI
SHA1: 1ae3a003b6
Date: 2025-10-15 03:01:54 +00:00
Message: Fix lint errors in test_custom_llm.py

- Add noqa comment for hardcoded test JWT token
- Add return statement to satisfy ruff RET503 check

Co-Authored-By: João <joao@crewai.com>

Author: Devin AI
SHA1: fc4b0dd923
Date: 2025-10-15 02:57:03 +00:00
Message: Fix function_calling_llm support for custom models

- Add supports_function_calling() method to BaseLLM class with default True
- Add supports_function_calling parameter to LLM class to allow override of litellm check
- Add tests for both BaseLLM default and LLM override functionality
- Fixes #3708: Custom models not in litellm's list can now use function calling

Co-Authored-By: João <joao@crewai.com>
4 changed files with 90 additions and 40 deletions
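
Taken together, the two commits let a model that litellm does not recognize be used as a function_calling_llm: an explicit override is consulted before litellm's capability table. A minimal usage sketch, assuming an OpenAI-compatible custom deployment (the model id and base_url are hypothetical placeholders, not part of this change):

from crewai import Agent, LLM

# Hypothetical in-house model that litellm's capability table does not know;
# the new parameter asserts tool-calling support instead of it being reported
# as unsupported.
custom_llm = LLM(
    model="openai/my-internal-model",
    base_url="http://localhost:8000/v1",
    supports_function_calling=True,
)

agent = Agent(
    role="Researcher",
    goal="Answer questions using tools",
    backstory="Backed by an in-house model",
    function_calling_llm=custom_llm,  # previously reported as unsupported (#3708)
)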

View File

@@ -299,6 +299,7 @@ class LLM(BaseLLM):
         callbacks: list[Any] | None = None,
         reasoning_effort: Literal["none", "low", "medium", "high"] | None = None,
         stream: bool = False,
+        supports_function_calling: bool | None = None,
         **kwargs,
     ):
         self.model = model
@@ -325,6 +326,7 @@ class LLM(BaseLLM):
         self.additional_params = kwargs
         self.is_anthropic = self._is_anthropic_model(model)
         self.stream = stream
+        self._supports_function_calling_override = supports_function_calling
         litellm.drop_params = True
@@ -1197,6 +1199,9 @@
         )

     def supports_function_calling(self) -> bool:
+        if self._supports_function_calling_override is not None:
+            return self._supports_function_calling_override
         try:
             provider = self._get_custom_llm_provider()
             return litellm.utils.supports_function_calling(
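
The override short-circuits the litellm capability lookup; when left unset, behavior is unchanged. For example (mirroring the tests added below):

llm = LLM(model="gpt-4o-mini", supports_function_calling=False)
assert llm.supports_function_calling() is False  # explicit override wins

llm = LLM(model="gpt-4o-mini")
assert llm.supports_function_calling() is True  # falls through to litellm's capability data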

View File

@@ -9,6 +9,7 @@ from typing import Any, Final

 DEFAULT_CONTEXT_WINDOW_SIZE: Final[int] = 4096
 DEFAULT_SUPPORTS_STOP_WORDS: Final[bool] = True
+DEFAULT_SUPPORTS_FUNCTION_CALLING: Final[bool] = True


 class BaseLLM(ABC):
@@ -82,6 +83,14 @@ class BaseLLM(ABC):
             RuntimeError: If the LLM request fails for other reasons.
         """

+    def supports_function_calling(self) -> bool:
+        """Check if the LLM supports function calling.
+
+        Returns:
+            True if the LLM supports function calling, False otherwise.
+        """
+        return DEFAULT_SUPPORTS_FUNCTION_CALLING
+
     def supports_stop_words(self) -> bool:
         """Check if the LLM supports stop words.

View File

@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional, Union
+from typing import Any

 import pytest
@@ -159,11 +159,11 @@ class JWTAuthLLM(BaseLLM):

     def call(
         self,
-        messages: Union[str, List[Dict[str, str]]],
-        tools: Optional[List[dict]] = None,
-        callbacks: Optional[List[Any]] = None,
-        available_functions: Optional[Dict[str, Any]] = None,
-    ) -> Union[str, Any]:
+        messages: str | list[dict[str, str]],
+        tools: list[dict] | None = None,
+        callbacks: list[Any] | None = None,
+        available_functions: dict[str, Any] | None = None,
+    ) -> str | Any:
         """Record the call and return a predefined response."""
         self.calls.append(
             {
@@ -192,7 +192,7 @@ class JWTAuthLLM(BaseLLM):

 def test_custom_llm_with_jwt_auth():
     """Test a custom LLM implementation with JWT authentication."""
-    jwt_llm = JWTAuthLLM(jwt_token="example.jwt.token")
+    jwt_llm = JWTAuthLLM(jwt_token="example.jwt.token")  # noqa: S106

     # Test that create_llm returns the JWT-authenticated LLM instance directly
     result_llm = create_llm(jwt_llm)
@@ -238,11 +238,11 @@ class TimeoutHandlingLLM(BaseLLM):

     def call(
         self,
-        messages: Union[str, List[Dict[str, str]]],
-        tools: Optional[List[dict]] = None,
-        callbacks: Optional[List[Any]] = None,
-        available_functions: Optional[Dict[str, Any]] = None,
-    ) -> Union[str, Any]:
+        messages: str | list[dict[str, str]],
+        tools: list[dict] | None = None,
+        callbacks: list[Any] | None = None,
+        available_functions: dict[str, Any] | None = None,
+    ) -> str | Any:
         """Simulate API calls with timeout handling and retry logic.

         Args:
@@ -282,35 +282,34 @@ class TimeoutHandlingLLM(BaseLLM):
                         )
                     # Otherwise, continue to the next attempt (simulating backoff)
                     continue
-                else:
-                    # Success on first attempt
-                    return "First attempt response"
-            else:
-                # This is a retry attempt (attempt > 0)
-                # Always record retry attempts
-                self.calls.append(
-                    {
-                        "retry_attempt": attempt,
-                        "messages": messages,
-                        "tools": tools,
-                        "callbacks": callbacks,
-                        "available_functions": available_functions,
-                    }
-                )
+                # Success on first attempt
+                return "First attempt response"
+            # This is a retry attempt (attempt > 0)
+            # Always record retry attempts
+            self.calls.append(
+                {
+                    "retry_attempt": attempt,
+                    "messages": messages,
+                    "tools": tools,
+                    "callbacks": callbacks,
+                    "available_functions": available_functions,
+                }
+            )
-                # Simulate a failure if fail_count > 0
-                if self.fail_count > 0:
-                    self.fail_count -= 1
-                    # If we've used all retries, raise an error
-                    if attempt == self.max_retries - 1:
-                        raise TimeoutError(
-                            f"LLM request failed after {self.max_retries} attempts"
-                        )
-                    # Otherwise, continue to the next attempt (simulating backoff)
-                    continue
-                else:
-                    # Success on retry
-                    return "Response after retry"
+            # Simulate a failure if fail_count > 0
+            if self.fail_count > 0:
+                self.fail_count -= 1
+                # If we've used all retries, raise an error
+                if attempt == self.max_retries - 1:
+                    raise TimeoutError(
+                        f"LLM request failed after {self.max_retries} attempts"
+                    )
+                # Otherwise, continue to the next attempt (simulating backoff)
+                continue
+            # Success on retry
+            return "Response after retry"
+        return "Response after retry"

     def supports_function_calling(self) -> bool:
         """Return True to indicate that function calling is supported.
@@ -358,3 +357,25 @@ def test_timeout_handling_llm():
     with pytest.raises(TimeoutError, match="LLM request failed after 2 attempts"):
         llm.call("Test message")
     assert len(llm.calls) == 2  # Initial call + failed retry attempt
+
+
+class MinimalCustomLLM(BaseLLM):
+    """Minimal custom LLM implementation that doesn't override supports_function_calling."""
+
+    def __init__(self):
+        super().__init__(model="minimal-model")
+
+    def call(
+        self,
+        messages: str | list[dict[str, str]],
+        tools: list[dict] | None = None,
+        callbacks: list[Any] | None = None,
+        available_functions: dict[str, Any] | None = None,
+    ) -> str | Any:
+        return "Minimal response"
+
+
+def test_base_llm_supports_function_calling_default():
+    """Test that BaseLLM supports function calling by default."""
+    llm = MinimalCustomLLM()
+    assert llm.supports_function_calling() is True

View File

@@ -711,3 +711,18 @@ def test_ollama_does_not_modify_when_last_is_user(ollama_llm):
     formatted = ollama_llm._format_messages_for_provider(original_messages)

     assert formatted == original_messages
+
+
+def test_supports_function_calling_with_override_true():
+    llm = LLM(model="custom-model/my-model", supports_function_calling=True)
+    assert llm.supports_function_calling() is True
+
+
+def test_supports_function_calling_with_override_false():
+    llm = LLM(model="gpt-4o-mini", supports_function_calling=False)
+    assert llm.supports_function_calling() is False
+
+
+def test_supports_function_calling_without_override():
+    llm = LLM(model="gpt-4o-mini")
+    assert llm.supports_function_calling() is True