From 5ac7050f7af17c4f91a7fd186882939c64728bd8 Mon Sep 17 00:00:00 2001 From: Lorenze Jay <63378463+lorenzejay@users.noreply.github.com> Date: Fri, 26 Jul 2024 10:57:56 -0700 Subject: [PATCH] Patch/non gpt model pydantic output (#1003) * patching for non-gpt model * removal of json_object tool name assignment * fixed issue for smaller models due to instructions prompt * fixing for ollama llama3 models * closing brackets * removed not used and fixes --- .../utilities/crew_pydantic_output_parser.py | 17 ++++++----------- .../utilities/evaluators/task_evaluator.py | 4 ++-- src/crewai/utilities/pydantic_schema_parser.py | 10 ++++++---- 3 files changed, 14 insertions(+), 17 deletions(-) diff --git a/src/crewai/utilities/crew_pydantic_output_parser.py b/src/crewai/utilities/crew_pydantic_output_parser.py index 54025d5e3..f4e9cdd18 100644 --- a/src/crewai/utilities/crew_pydantic_output_parser.py +++ b/src/crewai/utilities/crew_pydantic_output_parser.py @@ -1,5 +1,5 @@ import json -from typing import Any, List, Type, Union +from typing import Any, List, Type import regex from langchain.output_parsers import PydanticOutputParser @@ -7,29 +7,24 @@ from langchain_core.exceptions import OutputParserException from langchain_core.outputs import Generation from langchain_core.pydantic_v1 import ValidationError from pydantic import BaseModel -from pydantic.v1 import BaseModel as V1BaseModel class CrewPydanticOutputParser(PydanticOutputParser): """Parses the text into pydantic models""" - pydantic_object: Union[Type[BaseModel], Type[V1BaseModel]] + pydantic_object: Type[BaseModel] - def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: + def parse_result(self, result: List[Generation]) -> Any: result[0].text = self._transform_in_valid_json(result[0].text) # Treating edge case of function calling llm returning the name instead of tool_name json_object = json.loads(result[0].text) - json_object["tool_name"] = ( - json_object["name"] - if "tool_name" not in json_object - else json_object["tool_name"] - ) + if "tool_name" not in json_object: + json_object["tool_name"] = json_object.get("name", "") result[0].text = json.dumps(json_object) - json_object = super().parse_result(result) try: - return self.pydantic_object.parse_obj(json_object) + return self.pydantic_object.model_validate(json_object) except ValidationError as e: name = self.pydantic_object.__name__ msg = f"Failed to parse {name} from completion {json_object}. Got: {e}" diff --git a/src/crewai/utilities/evaluators/task_evaluator.py b/src/crewai/utilities/evaluators/task_evaluator.py index 04983b07c..fdb42e125 100644 --- a/src/crewai/utilities/evaluators/task_evaluator.py +++ b/src/crewai/utilities/evaluators/task_evaluator.py @@ -66,11 +66,11 @@ class TaskEvaluator: "- Entities extracted from the task output, if any, their type, description, and relationships" ) - instructions = "I'm gonna convert this raw text into valid JSON." + instructions = "Convert all responses into valid JSON output." if not self._is_gpt(self.llm): model_schema = PydanticSchemaParser(model=TaskEvaluation).get_schema() - instructions = f"{instructions}\n\nThe json should have the following structure, with the following keys:\n{model_schema}" + instructions = f"{instructions}\n\nReturn only valid JSON with the following schema:\n```json\n{model_schema}\n```" converter = Converter( llm=self.llm, diff --git a/src/crewai/utilities/pydantic_schema_parser.py b/src/crewai/utilities/pydantic_schema_parser.py index aa6b18b0e..9d9cdabe8 100644 --- a/src/crewai/utilities/pydantic_schema_parser.py +++ b/src/crewai/utilities/pydantic_schema_parser.py @@ -16,11 +16,13 @@ class PydanticSchemaParser(BaseModel): return self._get_model_schema(self.model) def _get_model_schema(self, model, depth=0) -> str: - lines = [] + indent = " " * depth + lines = [f"{indent}{{"] for field_name, field in model.model_fields.items(): field_type_str = self._get_field_type(field, depth + 1) - lines.append(f"{' ' * 4 * depth}- {field_name}: {field_type_str}") - + lines.append(f"{indent} {field_name}: {field_type_str},") + lines[-1] = lines[-1].rstrip(",") # Remove trailing comma from last item + lines.append(f"{indent}}}") return "\n".join(lines) def _get_field_type(self, field, depth) -> str: @@ -35,6 +37,6 @@ class PydanticSchemaParser(BaseModel): else: return f"List[{list_item_type.__name__}]" elif issubclass(field_type, BaseModel): - return f"\n{self._get_model_schema(field_type, depth)}" + return self._get_model_schema(field_type, depth) else: return field_type.__name__