Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-23 07:08:14 +00:00

Compare commits: gl/feat/na ... devin/1760 (2 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | 1ae3a003b6 |  |
|  | fc4b0dd923 |  |
@@ -299,6 +299,7 @@ class LLM(BaseLLM):
         callbacks: list[Any] | None = None,
         reasoning_effort: Literal["none", "low", "medium", "high"] | None = None,
         stream: bool = False,
+        supports_function_calling: bool | None = None,
         **kwargs,
     ):
         self.model = model
@@ -325,6 +326,7 @@ class LLM(BaseLLM):
         self.additional_params = kwargs
         self.is_anthropic = self._is_anthropic_model(model)
         self.stream = stream
+        self._supports_function_calling_override = supports_function_calling
 
         litellm.drop_params = True
 
@@ -1197,6 +1199,9 @@ class LLM(BaseLLM):
         )
 
     def supports_function_calling(self) -> bool:
+        if self._supports_function_calling_override is not None:
+            return self._supports_function_calling_override
+
         try:
             provider = self._get_custom_llm_provider()
             return litellm.utils.supports_function_calling(
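Together, the three hunks above thread a caller-supplied `supports_function_calling` flag from `LLM.__init__` into `supports_function_calling()`, short-circuiting the litellm capability lookup whenever the override is set. A minimal usage sketch with this branch installed (the model ids mirror the new tests further down):

```python
from crewai import LLM

# Explicit override: litellm's capability table is never consulted.
llm = LLM(model="custom-model/my-model", supports_function_calling=True)
assert llm.supports_function_calling() is True

# The override can also disable tool calling for a model litellm would report as capable.
llm = LLM(model="gpt-4o-mini", supports_function_calling=False)
assert llm.supports_function_calling() is False

# No override: behavior is unchanged, falling back to litellm.utils.supports_function_calling.
llm = LLM(model="gpt-4o-mini")
assert llm.supports_function_calling() is True
```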
@@ -9,6 +9,7 @@ from typing import Any, Final
 
 DEFAULT_CONTEXT_WINDOW_SIZE: Final[int] = 4096
 DEFAULT_SUPPORTS_STOP_WORDS: Final[bool] = True
+DEFAULT_SUPPORTS_FUNCTION_CALLING: Final[bool] = True
 
 
 class BaseLLM(ABC):
@@ -82,6 +83,14 @@ class BaseLLM(ABC):
             RuntimeError: If the LLM request fails for other reasons.
         """
 
+    def supports_function_calling(self) -> bool:
+        """Check if the LLM supports function calling.
+
+        Returns:
+            True if the LLM supports function calling, False otherwise.
+        """
+        return DEFAULT_SUPPORTS_FUNCTION_CALLING
+
     def supports_stop_words(self) -> bool:
         """Check if the LLM supports stop words.
 
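With this default in place, a custom `BaseLLM` subclass only has to implement `call`; it reports function-calling support unless it explicitly opts out. A sketch of such a subclass, assuming the top-level `BaseLLM` export (`EchoLLM` and its model id are hypothetical):

```python
from typing import Any

from crewai import BaseLLM  # top-level export assumed


class EchoLLM(BaseLLM):
    """Hypothetical subclass that relies on the inherited default."""

    def __init__(self) -> None:
        super().__init__(model="echo-model")

    def call(
        self,
        messages: str | list[dict[str, str]],
        tools: list[dict] | None = None,
        callbacks: list[Any] | None = None,
        available_functions: dict[str, Any] | None = None,
    ) -> str | Any:
        # A real implementation would hit a provider API here.
        return "echo"


# Inherits DEFAULT_SUPPORTS_FUNCTION_CALLING without overriding anything.
assert EchoLLM().supports_function_calling() is True
```

A provider that cannot do tool calls would instead override `supports_function_calling` to return `False`.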
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional, Union
+from typing import Any
 
 import pytest
 
@@ -159,11 +159,11 @@ class JWTAuthLLM(BaseLLM):
 
     def call(
         self,
-        messages: Union[str, List[Dict[str, str]]],
-        tools: Optional[List[dict]] = None,
-        callbacks: Optional[List[Any]] = None,
-        available_functions: Optional[Dict[str, Any]] = None,
-    ) -> Union[str, Any]:
+        messages: str | list[dict[str, str]],
+        tools: list[dict] | None = None,
+        callbacks: list[Any] | None = None,
+        available_functions: dict[str, Any] | None = None,
+    ) -> str | Any:
         """Record the call and return a predefined response."""
         self.calls.append(
             {
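The signature rewrite above is purely mechanical: PEP 585 builtin generics and PEP 604 unions replace the `typing` spellings, which is why the `Dict`/`List`/`Optional`/`Union` imports could be dropped. On Python >= 3.10 the two forms compare equal, as this quick check illustrates:

```python
from typing import Any, Dict, List, Optional, Union

# PEP 585: builtin containers subscript directly and compare equal to the aliases.
assert list[dict[str, str]] == List[Dict[str, str]]
assert dict[str, Any] == Dict[str, Any]

# PEP 604: X | Y builds a union, and Optional[X] is just Union[X, None].
assert (str | None) == Optional[str]
assert Optional[str] == Union[str, None]
```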
@@ -192,7 +192,7 @@ class JWTAuthLLM(BaseLLM):
 
 def test_custom_llm_with_jwt_auth():
     """Test a custom LLM implementation with JWT authentication."""
-    jwt_llm = JWTAuthLLM(jwt_token="example.jwt.token")
+    jwt_llm = JWTAuthLLM(jwt_token="example.jwt.token")  # noqa: S106
 
     # Test that create_llm returns the JWT-authenticated LLM instance directly
     result_llm = create_llm(jwt_llm)
@@ -238,11 +238,11 @@ class TimeoutHandlingLLM(BaseLLM):
 
     def call(
         self,
-        messages: Union[str, List[Dict[str, str]]],
-        tools: Optional[List[dict]] = None,
-        callbacks: Optional[List[Any]] = None,
-        available_functions: Optional[Dict[str, Any]] = None,
-    ) -> Union[str, Any]:
+        messages: str | list[dict[str, str]],
+        tools: list[dict] | None = None,
+        callbacks: list[Any] | None = None,
+        available_functions: dict[str, Any] | None = None,
+    ) -> str | Any:
         """Simulate API calls with timeout handling and retry logic.
 
         Args:
@@ -282,35 +282,34 @@ class TimeoutHandlingLLM(BaseLLM):
                         )
                     # Otherwise, continue to the next attempt (simulating backoff)
                     continue
-                else:
-                    # Success on first attempt
-                    return "First attempt response"
-            else:
-                # This is a retry attempt (attempt > 0)
-                # Always record retry attempts
-                self.calls.append(
-                    {
-                        "retry_attempt": attempt,
-                        "messages": messages,
-                        "tools": tools,
-                        "callbacks": callbacks,
-                        "available_functions": available_functions,
-                    }
-                )
+                # Success on first attempt
+                return "First attempt response"
+            # This is a retry attempt (attempt > 0)
+            # Always record retry attempts
+            self.calls.append(
+                {
+                    "retry_attempt": attempt,
+                    "messages": messages,
+                    "tools": tools,
+                    "callbacks": callbacks,
+                    "available_functions": available_functions,
+                }
+            )
 
-                # Simulate a failure if fail_count > 0
-                if self.fail_count > 0:
-                    self.fail_count -= 1
-                    # If we've used all retries, raise an error
-                    if attempt == self.max_retries - 1:
-                        raise TimeoutError(
-                            f"LLM request failed after {self.max_retries} attempts"
-                        )
-                    # Otherwise, continue to the next attempt (simulating backoff)
-                    continue
-                else:
-                    # Success on retry
-                    return "Response after retry"
+            # Simulate a failure if fail_count > 0
+            if self.fail_count > 0:
+                self.fail_count -= 1
+                # If we've used all retries, raise an error
+                if attempt == self.max_retries - 1:
+                    raise TimeoutError(
+                        f"LLM request failed after {self.max_retries} attempts"
+                    )
+                # Otherwise, continue to the next attempt (simulating backoff)
+                continue
+            # Success on retry
+            return "Response after retry"
+
+        return "Response after retry"
 
     def supports_function_calling(self) -> bool:
         """Return True to indicate that function calling is supported.
@@ -358,3 +357,25 @@ def test_timeout_handling_llm():
     with pytest.raises(TimeoutError, match="LLM request failed after 2 attempts"):
         llm.call("Test message")
     assert len(llm.calls) == 2  # Initial call + failed retry attempt
+
+
+class MinimalCustomLLM(BaseLLM):
+    """Minimal custom LLM implementation that doesn't override supports_function_calling."""
+
+    def __init__(self):
+        super().__init__(model="minimal-model")
+
+    def call(
+        self,
+        messages: str | list[dict[str, str]],
+        tools: list[dict] | None = None,
+        callbacks: list[Any] | None = None,
+        available_functions: dict[str, Any] | None = None,
+    ) -> str | Any:
+        return "Minimal response"
+
+
+def test_base_llm_supports_function_calling_default():
+    """Test that BaseLLM supports function calling by default."""
+    llm = MinimalCustomLLM()
+    assert llm.supports_function_calling() is True
@@ -711,3 +711,18 @@ def test_ollama_does_not_modify_when_last_is_user(ollama_llm):
     formatted = ollama_llm._format_messages_for_provider(original_messages)
 
     assert formatted == original_messages
+
+
+def test_supports_function_calling_with_override_true():
+    llm = LLM(model="custom-model/my-model", supports_function_calling=True)
+    assert llm.supports_function_calling() is True
+
+
+def test_supports_function_calling_with_override_false():
+    llm = LLM(model="gpt-4o-mini", supports_function_calling=False)
+    assert llm.supports_function_calling() is False
+
+
+def test_supports_function_calling_without_override():
+    llm = LLM(model="gpt-4o-mini")
+    assert llm.supports_function_calling() is True
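The evident use case for the new override is a provider that litellm's capability table knows nothing about, such as a self-hosted OpenAI-compatible endpoint. A hedged sketch (the model id and endpoint are hypothetical; `Agent` fields follow crewAI's documented constructor):

```python
from crewai import Agent, LLM

# A self-hosted endpoint litellm has no capability entry for:
llm = LLM(
    model="openai/my-finetune",           # hypothetical model id
    base_url="http://localhost:8000/v1",  # hypothetical endpoint
    supports_function_calling=True,       # declare the capability explicitly
)

agent = Agent(
    role="Researcher",
    goal="Answer questions using the provided tools",
    backstory="A test agent for the capability override",
    llm=llm,
)
```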