Check the right property for tool calling (#2160)

* Check the right property

* Fix failing tests

* Update cassettes

* Update cassettes again

* Update cassettes again 2

* Update cassettes again 3

* fix other test that fails in ci/cd

* Fix issues pointed out by lorenze
This commit is contained in:
Brandon Hancock (bhancock_ai)
2025-02-20 12:12:52 -05:00
committed by GitHub
parent 14503bc43b
commit e2ce65fc5b
6 changed files with 108 additions and 2913 deletions

View File

@@ -31,11 +31,11 @@ class OutputConverter(BaseModel, ABC):
)
@abstractmethod
def to_pydantic(self, current_attempt=1):
def to_pydantic(self, current_attempt=1) -> BaseModel:
"""Convert text to pydantic."""
pass
@abstractmethod
def to_json(self, current_attempt=1):
def to_json(self, current_attempt=1) -> dict:
"""Convert text to json."""
pass

View File

@@ -26,9 +26,9 @@ from crewai.utilities.events.tool_usage_events import ToolExecutionErrorEvent
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
import litellm
from litellm import Choices, get_supported_openai_params
from litellm import Choices
from litellm.types.utils import ModelResponse
from litellm.utils import supports_response_schema
from litellm.utils import get_supported_openai_params, supports_response_schema
from crewai.traces.unified_trace_controller import trace_llm_call
@@ -449,7 +449,7 @@ class LLM:
def supports_function_calling(self) -> bool:
try:
params = get_supported_openai_params(model=self.model)
return "response_format" in params
return params is not None and "tools" in params
except Exception as e:
logging.error(f"Failed to get supported params: {str(e)}")
return False
@@ -457,7 +457,7 @@ class LLM:
def supports_stop_words(self) -> bool:
try:
params = get_supported_openai_params(model=self.model)
return "stop" in params
return params is not None and "stop" in params
except Exception as e:
logging.error(f"Failed to get supported params: {str(e)}")
return False

View File

@@ -20,11 +20,11 @@ class ConverterError(Exception):
class Converter(OutputConverter):
"""Class that converts text into either pydantic or json."""
def to_pydantic(self, current_attempt=1):
def to_pydantic(self, current_attempt=1) -> BaseModel:
"""Convert text to pydantic."""
try:
if self.llm.supports_function_calling():
return self._create_instructor().to_pydantic()
result = self._create_instructor().to_pydantic()
else:
response = self.llm.call(
[
@@ -32,18 +32,40 @@ class Converter(OutputConverter):
{"role": "user", "content": self.text},
]
)
return self.model.model_validate_json(response)
try:
# Try to directly validate the response JSON
result = self.model.model_validate_json(response)
except ValidationError:
# If direct validation fails, attempt to extract valid JSON
result = handle_partial_json(response, self.model, False, None)
# Ensure result is a BaseModel instance
if not isinstance(result, BaseModel):
if isinstance(result, dict):
result = self.model.parse_obj(result)
elif isinstance(result, str):
try:
parsed = json.loads(result)
result = self.model.parse_obj(parsed)
except Exception as parse_err:
raise ConverterError(
f"Failed to convert partial JSON result into Pydantic: {parse_err}"
)
else:
raise ConverterError(
"handle_partial_json returned an unexpected type."
)
return result
except ValidationError as e:
if current_attempt < self.max_attempts:
return self.to_pydantic(current_attempt + 1)
raise ConverterError(
f"Failed to convert text into a Pydantic model due to the following validation error: {e}"
f"Failed to convert text into a Pydantic model due to validation error: {e}"
)
except Exception as e:
if current_attempt < self.max_attempts:
return self.to_pydantic(current_attempt + 1)
raise ConverterError(
f"Failed to convert text into a Pydantic model due to the following error: {e}"
f"Failed to convert text into a Pydantic model due to error: {e}"
)
def to_json(self, current_attempt=1):
@@ -194,14 +216,20 @@ def convert_with_instructions(
def get_conversion_instructions(model: Type[BaseModel], llm: Any) -> str:
instructions = "Please convert the following text into valid JSON."
print("Using function calling: ", llm.supports_function_calling())
if llm.supports_function_calling():
model_schema = PydanticSchemaParser(model=model).get_schema()
instructions += (
f"\n\nThe JSON should follow this schema:\n```json\n{model_schema}\n```"
f"\n\nOutput ONLY the valid JSON and nothing else.\n\n"
f"The JSON must follow this schema exactly:\n```json\n{model_schema}\n```"
)
else:
model_description = generate_model_description(model)
instructions += f"\n\nThe JSON should follow this format:\n{model_description}"
print("Model description: ", model_description)
instructions += (
f"\n\nOutput ONLY the valid JSON and nothing else.\n\n"
f"The JSON must follow this format exactly:\n{model_description}"
)
return instructions