Mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-01-09 08:08:32 +00:00)
Patch/non gpt model pydantic output (#1003)
* patching for non-GPT models
* removal of json_object tool name assignment
* fixed issue for smaller models caused by the instructions prompt
* fixes for Ollama llama3 models
* closing brackets
* removed unused code and other fixes
@@ -1,5 +1,5 @@
 import json
-from typing import Any, List, Type, Union
+from typing import Any, List, Type
 
 import regex
 from langchain.output_parsers import PydanticOutputParser
@@ -7,29 +7,24 @@ from langchain_core.exceptions import OutputParserException
 from langchain_core.outputs import Generation
 from langchain_core.pydantic_v1 import ValidationError
 from pydantic import BaseModel
-from pydantic.v1 import BaseModel as V1BaseModel
 
 
 class CrewPydanticOutputParser(PydanticOutputParser):
     """Parses the text into pydantic models"""
 
-    pydantic_object: Union[Type[BaseModel], Type[V1BaseModel]]
+    pydantic_object: Type[BaseModel]
 
-    def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
+    def parse_result(self, result: List[Generation]) -> Any:
         result[0].text = self._transform_in_valid_json(result[0].text)
 
         # Treating edge case of function calling llm returning the name instead of tool_name
         json_object = json.loads(result[0].text)
-        json_object["tool_name"] = (
-            json_object["name"]
-            if "tool_name" not in json_object
-            else json_object["tool_name"]
-        )
+        if "tool_name" not in json_object:
+            json_object["tool_name"] = json_object.get("name", "")
         result[0].text = json.dumps(json_object)
 
-        json_object = super().parse_result(result)
         try:
-            return self.pydantic_object.parse_obj(json_object)
+            return self.pydantic_object.model_validate(json_object)
         except ValidationError as e:
             name = self.pydantic_object.__name__
             msg = f"Failed to parse {name} from completion {json_object}. Got: {e}"
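The parser change above is easiest to see outside the diff. The following is a minimal, self-contained sketch (not the crewAI implementation; ToolCall and normalize_and_validate are made-up names) of the two behaviours the patch introduces: filling in tool_name from name when a non-GPT, function-calling model returns the wrong key, and validating with Pydantic v2's model_validate instead of the v1-style parse_obj.

import json

from pydantic import BaseModel, ValidationError


class ToolCall(BaseModel):
    # Hypothetical stand-in for the parser's pydantic_object.
    tool_name: str
    arguments: dict


def normalize_and_validate(raw: str) -> ToolCall:
    # Sketch of the patched flow, assuming the text is already valid JSON.
    json_object = json.loads(raw)
    # Some non-GPT function-calling models return "name" instead of "tool_name".
    if "tool_name" not in json_object:
        json_object["tool_name"] = json_object.get("name", "")
    try:
        # Pydantic v2 API; the old code used the deprecated v1-style parse_obj().
        return ToolCall.model_validate(json_object)
    except ValidationError as e:
        raise ValueError(f"Failed to parse ToolCall from completion {json_object}. Got: {e}")


print(normalize_and_validate('{"name": "search", "arguments": {"q": "crewai"}}'))
# tool_name='search' arguments={'q': 'crewai'}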
@@ -66,11 +66,11 @@ class TaskEvaluator:
             "- Entities extracted from the task output, if any, their type, description, and relationships"
         )
 
-        instructions = "I'm gonna convert this raw text into valid JSON."
+        instructions = "Convert all responses into valid JSON output."
 
         if not self._is_gpt(self.llm):
             model_schema = PydanticSchemaParser(model=TaskEvaluation).get_schema()
-            instructions = f"{instructions}\n\nThe json should have the following structure, with the following keys:\n{model_schema}"
+            instructions = f"{instructions}\n\nReturn only valid JSON with the following schema:\n```json\n{model_schema}\n```"
 
         converter = Converter(
             llm=self.llm,
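For non-GPT models the evaluator now appends the schema as a fenced JSON block rather than a loose "structure and keys" sentence. A rough sketch of the prompt assembly, where model_schema is a hand-written stand-in for PydanticSchemaParser(model=TaskEvaluation).get_schema() and is_gpt replaces the self._is_gpt(self.llm) check:

# Stand-in for PydanticSchemaParser(model=TaskEvaluation).get_schema(); the real schema is generated.
model_schema = "{\n    quality: float,\n    suggestions: List[str]\n}"

instructions = "Convert all responses into valid JSON output."
is_gpt = False  # e.g. an Ollama llama3 model rather than an OpenAI GPT model

if not is_gpt:
    instructions = (
        f"{instructions}\n\n"
        f"Return only valid JSON with the following schema:\n```json\n{model_schema}\n```"
    )

print(instructions)

Smaller models tend to follow an explicit fenced schema more reliably than the earlier free-form instruction, which appears to be the "instructions prompt" fix mentioned in the commit message.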
@@ -16,11 +16,13 @@ class PydanticSchemaParser(BaseModel):
         return self._get_model_schema(self.model)
 
     def _get_model_schema(self, model, depth=0) -> str:
-        lines = []
+        indent = "    " * depth
+        lines = [f"{indent}{{"]
         for field_name, field in model.model_fields.items():
             field_type_str = self._get_field_type(field, depth + 1)
-            lines.append(f"{' ' * 4 * depth}- {field_name}: {field_type_str}")
+            lines.append(f"{indent}    {field_name}: {field_type_str},")
+        lines[-1] = lines[-1].rstrip(",")  # Remove trailing comma from last item
+        lines.append(f"{indent}}}")
         return "\n".join(lines)
 
     def _get_field_type(self, field, depth) -> str:
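The rewritten _get_model_schema renders the schema as a brace-wrapped, JSON-like block instead of the old "- field: Type" bullet list. Below is a self-contained sketch of the same technique, not the crewAI class itself; Entity, the simplified TaskEvaluation, and the stripped-down type handling are illustrative only.

from pydantic import BaseModel


class Entity(BaseModel):
    # Made-up nested model used only for this illustration.
    name: str
    description: str


class TaskEvaluation(BaseModel):
    # Simplified stand-in for the real TaskEvaluation model.
    quality: float
    summary: str
    entity: Entity


def get_model_schema(model, depth: int = 0) -> str:
    # Same idea as the patched parser: brace-wrapped block, one "field: Type" per line,
    # trailing comma stripped from the last entry, recursion into nested models.
    indent = "    " * depth
    lines = [f"{indent}{{"]
    for field_name, field in model.model_fields.items():
        annotation = field.annotation
        if isinstance(annotation, type) and issubclass(annotation, BaseModel):
            type_str = "\n" + get_model_schema(annotation, depth + 1)
        else:
            type_str = annotation.__name__
        lines.append(f"{indent}    {field_name}: {type_str},")
    lines[-1] = lines[-1].rstrip(",")  # no trailing comma on the last field
    lines.append(f"{indent}}}")
    return "\n".join(lines)


print(get_model_schema(TaskEvaluation))

Printed for this stand-in model it yields an indented block with quality: float, summary: str, and a nested brace block for entity, which pairs naturally with the ```json fence added to the evaluator prompt above.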
@@ -35,6 +37,6 @@ class PydanticSchemaParser(BaseModel):
             else:
                 return f"List[{list_item_type.__name__}]"
         elif issubclass(field_type, BaseModel):
-            return f"\n{self._get_model_schema(field_type, depth)}"
+            return self._get_model_schema(field_type, depth)
         else:
             return field_type.__name__
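The last hunk drops the leading "\n" when a nested BaseModel field is rendered, so the nested block now follows the field name directly. A small sketch of the surrounding field-type dispatch, with a hypothetical helper name and a placeholder for the nested schema; the List and scalar branches mirror the unchanged context lines:

from typing import List, get_args, get_origin

from pydantic import BaseModel


class Entity(BaseModel):
    # Hypothetical nested model for illustration.
    name: str


def field_type_name(annotation, depth: int = 1) -> str:
    # Dispatch in the spirit of _get_field_type: List[...] fields, nested models, then scalars.
    if get_origin(annotation) is list:
        (item_type,) = get_args(annotation)
        return f"List[{item_type.__name__}]"
    elif isinstance(annotation, type) and issubclass(annotation, BaseModel):
        # After the patch there is no leading "\n" here, so the nested schema starts
        # immediately after "field_name: " in the rendered output.
        return f"<{annotation.__name__} schema rendered at depth {depth}>"
    else:
        return annotation.__name__


print(field_type_name(List[str]))  # List[str]
print(field_type_name(Entity))     # <Entity schema rendered at depth 1>
print(field_type_name(float))      # float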