Mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-01-10 00:28:31 +00:00)
code and tests work
@@ -5,15 +5,17 @@ import sys
 import threading
 import warnings
 from contextlib import contextmanager
-from typing import Any, Dict, List, Optional, Union, cast
+from typing import Any, Dict, List, Literal, Optional, Type, Union, cast

 from dotenv import load_dotenv
+from pydantic import BaseModel

 with warnings.catch_warnings():
     warnings.simplefilter("ignore", UserWarning)
     import litellm
     from litellm import Choices, get_supported_openai_params
     from litellm.types.utils import ModelResponse
+    from litellm.utils import supports_response_schema


 from crewai.utilities.exceptions.context_window_exceeding_exception import (
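The two new imports above are what the rest of the commit builds on: Pydantic's BaseModel for typed structured output, and litellm's supports_response_schema helper for the capability check. As a rough sketch, the helper can be queried directly; the actual return values depend on litellm's provider metadata, so treat the printed results as an assumption:

from litellm.utils import supports_response_schema

# True or False depending on litellm's knowledge of the provider/model pair.
print(supports_response_schema(model="gemini/gemini-1.5-pro", custom_llm_provider="gemini"))
print(supports_response_schema(model="gpt-3.5-turbo", custom_llm_provider="openai"))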
@@ -128,7 +130,7 @@ class LLM:
         presence_penalty: Optional[float] = None,
         frequency_penalty: Optional[float] = None,
         logit_bias: Optional[Dict[int, float]] = None,
-        response_format: Optional[Dict[str, Any]] = None,
+        response_format: Optional[Type[BaseModel]] = None,
         seed: Optional[int] = None,
         logprobs: Optional[int] = None,
         top_logprobs: Optional[int] = None,
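With the signature change above, response_format is now a Pydantic model class rather than a raw dict. A minimal sketch of how a caller might pass it; the Invoice model and the chosen model name are illustrative only:

from pydantic import BaseModel

from crewai.llm import LLM


class Invoice(BaseModel):
    # Hypothetical schema, used only to show the new parameter type.
    vendor: str
    total: float


# The class itself is passed, not an instance and not a dict.
llm = LLM(model="gpt-4o", response_format=Invoice)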
@@ -211,6 +213,9 @@ class LLM:
             response = llm.call(messages)
             print(response)
         """
+        # Validate parameters before proceeding with the call.
+        self._validate_call_params()
+
         if isinstance(messages, str):
             messages = [{"role": "user", "content": messages}]

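The two added lines run parameter validation before any message handling. A small standalone sketch of the normalization step that follows them, mirroring the isinstance branch already present in LLM.call:

def normalize_messages(messages):
    # A bare string becomes a single user message in chat format,
    # matching the isinstance check in LLM.call above.
    if isinstance(messages, str):
        return [{"role": "user", "content": messages}]
    return messages


assert normalize_messages("hello") == [{"role": "user", "content": "hello"}]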
@@ -309,6 +314,36 @@ class LLM:
             logging.error(f"LiteLLM call failed: {str(e)}")
             raise

+    def _get_custom_llm_provider(self) -> str:
+        """
+        Derives the custom_llm_provider from the model string.
+        - For example, if the model is "openrouter/deepseek/deepseek-chat", returns "openrouter".
+        - If the model is "gemini/gemini-1.5-pro", returns "gemini".
+        - If there is no '/', defaults to "openai".
+        """
+        if "/" in self.model:
+            return self.model.split("/")[0]
+        return "openai"
+
+    def _validate_call_params(self) -> None:
+        """
+        Validate parameters before making a call. Currently this only checks if
+        a response_format is provided and whether the model supports it.
+        The custom_llm_provider is dynamically determined from the model:
+        - E.g., "openrouter/deepseek/deepseek-chat" yields "openrouter"
+        - "gemini/gemini-1.5-pro" yields "gemini"
+        - If no slash is present, "openai" is assumed.
+        """
+        provider = self._get_custom_llm_provider()
+        if self.response_format is not None and not supports_response_schema(
+            model=self.model,
+            custom_llm_provider=provider,
+        ):
+            raise ValueError(
+                f"The model {self.model} does not support response_format for provider '{provider}'. "
+                "Please remove response_format or use a supported model."
+            )
+
     def supports_function_calling(self) -> bool:
         try:
             params = get_supported_openai_params(model=self.model)
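Taken together, the two helpers added above gate structured output on provider support: the provider is read as the prefix before the first "/" in the model string, and litellm is asked whether that provider/model pair honors a response schema. A hedged end-to-end sketch follows; whether a given pair passes depends on litellm's metadata, so the happy path here is an assumption:

from pydantic import BaseModel

from crewai.llm import LLM


class Answer(BaseModel):
    text: str


llm = LLM(model="openrouter/deepseek/deepseek-chat", response_format=Answer)

# Prefix before the first "/" -> "openrouter"; a bare model name would yield "openai".
print(llm._get_custom_llm_provider())

# Raises ValueError if litellm reports the pair cannot honor a response schema;
# otherwise returns None and the call proceeds.
llm._validate_call_params()

The hunks that follow are from the accompanying test module.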
@@ -2,6 +2,7 @@ from time import sleep
 from unittest.mock import MagicMock, patch

 import pytest
+from pydantic import BaseModel

 from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
 from crewai.llm import LLM
@@ -202,3 +203,49 @@ def test_llm_passes_additional_params():

     # Check the result from llm.call
     assert result == "Test response"
+
+
+def test_get_custom_llm_provider_openrouter():
+    llm = LLM(model="openrouter/deepseek/deepseek-chat")
+    assert llm._get_custom_llm_provider() == "openrouter"
+
+
+def test_get_custom_llm_provider_gemini():
+    llm = LLM(model="gemini/gemini-1.5-pro")
+    assert llm._get_custom_llm_provider() == "gemini"
+
+
+def test_get_custom_llm_provider_openai():
+    llm = LLM(model="gpt-4")
+    assert llm._get_custom_llm_provider() == "openai"
+
+
+def test_validate_call_params_supported():
+    class DummyResponse(BaseModel):
+        a: int
+
+    # Patch supports_response_schema to simulate a supported model.
+    with patch("crewai.llm.supports_response_schema", return_value=True):
+        llm = LLM(
+            model="openrouter/deepseek/deepseek-chat", response_format=DummyResponse
+        )
+        # Should not raise any error.
+        llm._validate_call_params()
+
+
+def test_validate_call_params_not_supported():
+    class DummyResponse(BaseModel):
+        a: int
+
+    # Patch supports_response_schema to simulate an unsupported model.
+    with patch("crewai.llm.supports_response_schema", return_value=False):
+        llm = LLM(model="gemini/gemini-1.5-pro", response_format=DummyResponse)
+        with pytest.raises(ValueError) as excinfo:
+            llm._validate_call_params()
+        assert "does not support response_format" in str(excinfo.value)
+
+
+def test_validate_call_params_no_response_format():
+    # When no response_format is provided, no validation error should occur.
+    llm = LLM(model="gemini/gemini-1.5-pro", response_format=None)
+    llm._validate_call_params()
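One detail worth flagging in the tests above: supports_response_schema is patched at "crewai.llm", the namespace the name was imported into, rather than at "litellm.utils" where it is defined. That is the standard unittest.mock pattern for a from-import; a minimal standalone illustration:

from unittest.mock import patch

# Patching the consuming module's binding is what the code under test resolves
# at call time; patching litellm.utils would leave that binding untouched.
with patch("crewai.llm.supports_response_schema", return_value=False):
    ...  # exercise LLM._validate_call_params() here, as the tests above do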