Brandon/improve llm structured output (#2029)

* code and tests work

* update docs

---------

Co-authored-by: Lorenze Jay <63378463+lorenzejay@users.noreply.github.com>
Brandon Hancock (bhancock_ai)
2025-02-04 16:46:48 -05:00
committed by GitHub
parent 515478473a
commit f4bb040ad8
3 changed files with 108 additions and 2 deletions


@@ -5,15 +5,17 @@ import sys
import threading
import warnings
from contextlib import contextmanager
-from typing import Any, Dict, List, Literal, Optional, Union, cast
+from typing import Any, Dict, List, Literal, Optional, Type, Union, cast
from dotenv import load_dotenv
from pydantic import BaseModel
with warnings.catch_warnings():
    warnings.simplefilter("ignore", UserWarning)
    import litellm
    from litellm import Choices, get_supported_openai_params
    from litellm.types.utils import ModelResponse
    from litellm.utils import supports_response_schema
from crewai.utilities.exceptions.context_window_exceeding_exception import (
@@ -128,7 +130,7 @@ class LLM:
        presence_penalty: Optional[float] = None,
        frequency_penalty: Optional[float] = None,
        logit_bias: Optional[Dict[int, float]] = None,
-       response_format: Optional[Dict[str, Any]] = None,
+       response_format: Optional[Type[BaseModel]] = None,
        seed: Optional[int] = None,
        logprobs: Optional[int] = None,
        top_logprobs: Optional[int] = None,
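With this change, response_format is typed as a Pydantic model class rather than a plain dict, so callers can hand the schema straight to the constructor. A minimal sketch of the intended usage, assuming the top-level crewai LLM import; the Dog schema and the "gpt-4o" model name are illustrative:

from pydantic import BaseModel
from crewai import LLM

class Dog(BaseModel):
    name: str
    age: int
    breed: str

# Pass the model class itself (not an instance); the response is constrained to this schema.
llm = LLM(model="gpt-4o", response_format=Dog)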
@@ -213,6 +215,9 @@ class LLM:
            response = llm.call(messages)
            print(response)
        """
        # Validate parameters before proceeding with the call.
        self._validate_call_params()

        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]
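call() now runs _validate_call_params() before anything else and still accepts either a plain prompt string or a list of role/content messages; the string form is normalized into the list form shown above. A short usage sketch (prompt text is illustrative):

# Both forms reach the same code path after normalization.
response = llm.call("Name a famous dog from film.")
response = llm.call([{"role": "user", "content": "Name a famous dog from film."}])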
@@ -312,6 +317,36 @@ class LLM:
logging.error(f"LiteLLM call failed: {str(e)}")
raise
def _get_custom_llm_provider(self) -> str:
"""
Derives the custom_llm_provider from the model string.
- For example, if the model is "openrouter/deepseek/deepseek-chat", returns "openrouter".
- If the model is "gemini/gemini-1.5-pro", returns "gemini".
- If there is no '/', defaults to "openai".
"""
if "/" in self.model:
return self.model.split("/")[0]
return "openai"
def _validate_call_params(self) -> None:
"""
Validate parameters before making a call. Currently this only checks if
a response_format is provided and whether the model supports it.
The custom_llm_provider is dynamically determined from the model:
- E.g., "openrouter/deepseek/deepseek-chat" yields "openrouter"
- "gemini/gemini-1.5-pro" yields "gemini"
- If no slash is present, "openai" is assumed.
"""
provider = self._get_custom_llm_provider()
if self.response_format is not None and not supports_response_schema(
model=self.model,
custom_llm_provider=provider,
):
raise ValueError(
f"The model {self.model} does not support response_format for provider '{provider}'. "
"Please remove response_format or use a supported model."
)
def supports_function_calling(self) -> bool:
try:
params = get_supported_openai_params(model=self.model)
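Together, _get_custom_llm_provider() and _validate_call_params() make an unsupported model/response_format pairing fail fast with a ValueError at call time instead of surfacing a raw provider error. A sketch of the expected behavior, assuming a model that litellm reports as not supporting response schemas (the "ollama/llama2" string is illustrative; Dog is the Pydantic model from the earlier sketch):

unsupported = LLM(model="ollama/llama2", response_format=Dog)
try:
    unsupported.call("Describe a dog as JSON.")
except ValueError as e:
    # e.g. "The model ollama/llama2 does not support response_format for provider 'ollama'. ..."
    print(e)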