From af0ed9ad530c8d3c344151f329fdc9cb853d3e38 Mon Sep 17 00:00:00 2001 From: Jason Vertrees <1031738+inchoate@users.noreply.github.com> Date: Mon, 9 Dec 2024 15:41:19 -0600 Subject: [PATCH] fix: fixes pydantic parser to support nested objects. This fails before the PR and succeeds afterward: ```python from pydantic import BaseModel from typing import List, Optional from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser class InnerModel(BaseModel): inner_field: int class TestModel(BaseModel): simple_field: str list_field: List[int] optional_field: Optional[str] nested_model: InnerModel print(PydanticSchemaParser(model=InnerModel).get_schema()) # works print(PydanticSchemaParser(model=TestModel).get_schema()) # fails ``` **Note.** Because the `main` branch currently doesn't support nested schemas and the original code against which I made this PR months ago drifted, I made a judgement call on how to format the nested structurees. I chose the following, inferred from the current code: ```` >>> print(PydanticSchemaParser(model=InnerModel).get_schema()) { inner_field: int } >>> print(PydanticSchemaParser(model=TestModel).get_schema()) { simple_field: str, list_field: List[int], optional_field: Optional[str], nested_model: InnerModel { inner_field: int } } --- .../utilities/pydantic_schema_parser.py | 81 +++++++++++-------- .../utilities/test_pydantic_schema_parser.py | 73 +++++++++++++++++ 2 files changed, 122 insertions(+), 32 deletions(-) create mode 100644 tests/utilities/test_pydantic_schema_parser.py diff --git a/src/crewai/utilities/pydantic_schema_parser.py b/src/crewai/utilities/pydantic_schema_parser.py index 073280dd2..757ba3edd 100644 --- a/src/crewai/utilities/pydantic_schema_parser.py +++ b/src/crewai/utilities/pydantic_schema_parser.py @@ -1,5 +1,4 @@ -from typing import Type, get_args, get_origin, Union - +from typing import Type, get_args, get_origin, List, Union from pydantic import BaseModel @@ -13,37 +12,55 @@ class PydanticSchemaParser(BaseModel): :param model: The Pydantic model class to generate schema for. :return: String representation of the model schema. """ - return self._get_model_schema(self.model) + return "{\n" + self._get_model_schema(self.model) + "\n}" - def _get_model_schema(self, model, depth=0) -> str: - indent = " " * depth - lines = [f"{indent}{{"] - for field_name, field in model.model_fields.items(): - field_type_str = self._get_field_type(field, depth + 1) - lines.append(f"{indent} {field_name}: {field_type_str},") - lines[-1] = lines[-1].rstrip(",") # Remove trailing comma from last item - lines.append(f"{indent}}}") - return "\n".join(lines) + def _get_model_schema(self, model: Type[BaseModel], depth: int = 0) -> str: + indent = " " * 4 * depth + lines = [ + f"{indent} {field_name}: {self._get_field_type(field, depth + 1)}" + for field_name, field in model.model_fields.items() + ] + return ",\n".join(lines) - def _get_field_type(self, field, depth) -> str: + def _get_field_type(self, field, depth: int) -> str: field_type = field.annotation - if get_origin(field_type) is list: + origin = get_origin(field_type) + + if origin in {list, List}: list_item_type = get_args(field_type)[0] - if isinstance(list_item_type, type) and issubclass( - list_item_type, BaseModel - ): - nested_schema = self._get_model_schema(list_item_type, depth + 1) - return f"List[\n{nested_schema}\n{' ' * 4 * depth}]" - else: - return f"List[{list_item_type.__name__}]" - elif get_origin(field_type) is Union: - union_args = get_args(field_type) - if type(None) in union_args: - non_none_type = next(arg for arg in union_args if arg is not type(None)) - return f"Optional[{self._get_field_type(field.__class__(annotation=non_none_type), depth)}]" - else: - return f"Union[{', '.join(arg.__name__ for arg in union_args)}]" - elif isinstance(field_type, type) and issubclass(field_type, BaseModel): - return self._get_model_schema(field_type, depth) - else: - return getattr(field_type, "__name__", str(field_type)) + return self._format_list_type(list_item_type, depth) + + if origin is Union: + return self._format_union_type(field_type, depth) + + if isinstance(field_type, type) and issubclass(field_type, BaseModel): + nested_schema = self._get_model_schema(field_type, depth) + nested_indent = " " * 4 * (depth) + return f"{field_type.__name__}\n{nested_indent}{{\n{nested_schema}\n{nested_indent}}}" + + return field_type.__name__ + + def _format_list_type(self, list_item_type, depth: int) -> str: + if isinstance(list_item_type, type) and issubclass(list_item_type, BaseModel): + nested_schema = self._get_model_schema(list_item_type, depth + 1) + nested_indent = " " * 4 * (depth - 1) + return f"List[\n{nested_schema}\n{nested_indent}]" + return f"List[{list_item_type.__name__}]" + + def _format_union_type(self, field_type, depth: int) -> str: + args = get_args(field_type) + non_none_type = next(arg for arg in args if arg is not type(None)) + return f"Optional[{self._get_field_type_for_annotation(non_none_type, depth)}]" + + def _get_field_type_for_annotation(self, annotation, depth: int) -> str: + origin = get_origin(annotation) + if origin in {list, List}: + list_item_type = get_args(annotation)[0] + return self._format_list_type(list_item_type, depth) + + if isinstance(annotation, type) and issubclass(annotation, BaseModel): + nested_schema = self._get_model_schema(annotation, depth) + nested_indent = " " * 4 * (depth - 1) + return f"{annotation.__name__}[\n{nested_schema}\n{nested_indent}]" + + return annotation.__name__ diff --git a/tests/utilities/test_pydantic_schema_parser.py b/tests/utilities/test_pydantic_schema_parser.py new file mode 100644 index 000000000..f49a07595 --- /dev/null +++ b/tests/utilities/test_pydantic_schema_parser.py @@ -0,0 +1,73 @@ +import unittest +from pydantic import BaseModel +from typing import List, Optional + +from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser + + +# Define test models +class InnerModel(BaseModel): + inner_field: int + + +class OuterModel(BaseModel): + simple_field: str + list_field: List[int] + optional_field: Optional[str] + nested_model: InnerModel + + +# Test cases +class TestPydanticSchemaParser(unittest.TestCase): + def test_simple_model(self): + class SimpleModel(BaseModel): + field1: int + field2: str + + parser = PydanticSchemaParser(model=SimpleModel) + expected_schema = """ +{ + field1: int, + field2: str +}""".strip() + self.assertEqual(parser.get_schema(), expected_schema) + + def test_model_with_list(self): + class ListModel(BaseModel): + field1: List[int] + + parser = PydanticSchemaParser(model=ListModel) + expected_schema = """ +{ + field1: List[int] +}""".strip() + self.assertEqual(parser.get_schema(), expected_schema) + + def test_model_with_optional(self): + class OptionalModel(BaseModel): + field1: Optional[int] + + parser = PydanticSchemaParser(model=OptionalModel) + expected_schema = """ +{ + field1: Optional[int] +}""".strip() + self.assertEqual(parser.get_schema(), expected_schema) + + def test_nested_model(self): + parser = PydanticSchemaParser(model=OuterModel) + expected_schema = """ +{ + simple_field: str, + list_field: List[int], + optional_field: Optional[str], + nested_model: InnerModel + { + inner_field: int + } +}""".strip() + self.assertEqual(parser.get_schema(), expected_schema) + + +if __name__ == "__main__": + unittest.main()