mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-05 01:02:37 +00:00
fix: fixes pydantic parser to support nested objects.
This fails before the PR and succeeds afterward:
```python
from pydantic import BaseModel
from typing import List, Optional
from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser
class InnerModel(BaseModel):
inner_field: int
class TestModel(BaseModel):
simple_field: str
list_field: List[int]
optional_field: Optional[str]
nested_model: InnerModel
print(PydanticSchemaParser(model=InnerModel).get_schema()) # works
print(PydanticSchemaParser(model=TestModel).get_schema()) # fails
```
**Note.** Because the `main` branch currently doesn't support nested
schemas and the original code against which I made this PR months ago
drifted, I made a judgement call on how to format the nested
structurees. I chose the following, inferred from the current code:
````
>>> print(PydanticSchemaParser(model=InnerModel).get_schema())
{
inner_field: int
}
>>> print(PydanticSchemaParser(model=TestModel).get_schema())
{
simple_field: str,
list_field: List[int],
optional_field: Optional[str],
nested_model: InnerModel
{
inner_field: int
}
}
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
from typing import Type, get_args, get_origin, Union
|
||||
|
||||
from typing import Type, get_args, get_origin, List, Union
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@@ -13,37 +12,55 @@ class PydanticSchemaParser(BaseModel):
|
||||
:param model: The Pydantic model class to generate schema for.
|
||||
:return: String representation of the model schema.
|
||||
"""
|
||||
return self._get_model_schema(self.model)
|
||||
return "{\n" + self._get_model_schema(self.model) + "\n}"
|
||||
|
||||
def _get_model_schema(self, model, depth=0) -> str:
|
||||
indent = " " * depth
|
||||
lines = [f"{indent}{{"]
|
||||
for field_name, field in model.model_fields.items():
|
||||
field_type_str = self._get_field_type(field, depth + 1)
|
||||
lines.append(f"{indent} {field_name}: {field_type_str},")
|
||||
lines[-1] = lines[-1].rstrip(",") # Remove trailing comma from last item
|
||||
lines.append(f"{indent}}}")
|
||||
return "\n".join(lines)
|
||||
def _get_model_schema(self, model: Type[BaseModel], depth: int = 0) -> str:
|
||||
indent = " " * 4 * depth
|
||||
lines = [
|
||||
f"{indent} {field_name}: {self._get_field_type(field, depth + 1)}"
|
||||
for field_name, field in model.model_fields.items()
|
||||
]
|
||||
return ",\n".join(lines)
|
||||
|
||||
def _get_field_type(self, field, depth) -> str:
|
||||
def _get_field_type(self, field, depth: int) -> str:
|
||||
field_type = field.annotation
|
||||
if get_origin(field_type) is list:
|
||||
origin = get_origin(field_type)
|
||||
|
||||
if origin in {list, List}:
|
||||
list_item_type = get_args(field_type)[0]
|
||||
if isinstance(list_item_type, type) and issubclass(
|
||||
list_item_type, BaseModel
|
||||
):
|
||||
nested_schema = self._get_model_schema(list_item_type, depth + 1)
|
||||
return f"List[\n{nested_schema}\n{' ' * 4 * depth}]"
|
||||
else:
|
||||
return f"List[{list_item_type.__name__}]"
|
||||
elif get_origin(field_type) is Union:
|
||||
union_args = get_args(field_type)
|
||||
if type(None) in union_args:
|
||||
non_none_type = next(arg for arg in union_args if arg is not type(None))
|
||||
return f"Optional[{self._get_field_type(field.__class__(annotation=non_none_type), depth)}]"
|
||||
else:
|
||||
return f"Union[{', '.join(arg.__name__ for arg in union_args)}]"
|
||||
elif isinstance(field_type, type) and issubclass(field_type, BaseModel):
|
||||
return self._get_model_schema(field_type, depth)
|
||||
else:
|
||||
return getattr(field_type, "__name__", str(field_type))
|
||||
return self._format_list_type(list_item_type, depth)
|
||||
|
||||
if origin is Union:
|
||||
return self._format_union_type(field_type, depth)
|
||||
|
||||
if isinstance(field_type, type) and issubclass(field_type, BaseModel):
|
||||
nested_schema = self._get_model_schema(field_type, depth)
|
||||
nested_indent = " " * 4 * (depth)
|
||||
return f"{field_type.__name__}\n{nested_indent}{{\n{nested_schema}\n{nested_indent}}}"
|
||||
|
||||
return field_type.__name__
|
||||
|
||||
def _format_list_type(self, list_item_type, depth: int) -> str:
|
||||
if isinstance(list_item_type, type) and issubclass(list_item_type, BaseModel):
|
||||
nested_schema = self._get_model_schema(list_item_type, depth + 1)
|
||||
nested_indent = " " * 4 * (depth - 1)
|
||||
return f"List[\n{nested_schema}\n{nested_indent}]"
|
||||
return f"List[{list_item_type.__name__}]"
|
||||
|
||||
def _format_union_type(self, field_type, depth: int) -> str:
|
||||
args = get_args(field_type)
|
||||
non_none_type = next(arg for arg in args if arg is not type(None))
|
||||
return f"Optional[{self._get_field_type_for_annotation(non_none_type, depth)}]"
|
||||
|
||||
def _get_field_type_for_annotation(self, annotation, depth: int) -> str:
|
||||
origin = get_origin(annotation)
|
||||
if origin in {list, List}:
|
||||
list_item_type = get_args(annotation)[0]
|
||||
return self._format_list_type(list_item_type, depth)
|
||||
|
||||
if isinstance(annotation, type) and issubclass(annotation, BaseModel):
|
||||
nested_schema = self._get_model_schema(annotation, depth)
|
||||
nested_indent = " " * 4 * (depth - 1)
|
||||
return f"{annotation.__name__}[\n{nested_schema}\n{nested_indent}]"
|
||||
|
||||
return annotation.__name__
|
||||
|
||||
73
tests/utilities/test_pydantic_schema_parser.py
Normal file
73
tests/utilities/test_pydantic_schema_parser.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import unittest
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional
|
||||
|
||||
from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser
|
||||
|
||||
|
||||
# Define test models
|
||||
class InnerModel(BaseModel):
|
||||
inner_field: int
|
||||
|
||||
|
||||
class OuterModel(BaseModel):
|
||||
simple_field: str
|
||||
list_field: List[int]
|
||||
optional_field: Optional[str]
|
||||
nested_model: InnerModel
|
||||
|
||||
|
||||
# Test cases
|
||||
class TestPydanticSchemaParser(unittest.TestCase):
|
||||
def test_simple_model(self):
|
||||
class SimpleModel(BaseModel):
|
||||
field1: int
|
||||
field2: str
|
||||
|
||||
parser = PydanticSchemaParser(model=SimpleModel)
|
||||
expected_schema = """
|
||||
{
|
||||
field1: int,
|
||||
field2: str
|
||||
}""".strip()
|
||||
self.assertEqual(parser.get_schema(), expected_schema)
|
||||
|
||||
def test_model_with_list(self):
|
||||
class ListModel(BaseModel):
|
||||
field1: List[int]
|
||||
|
||||
parser = PydanticSchemaParser(model=ListModel)
|
||||
expected_schema = """
|
||||
{
|
||||
field1: List[int]
|
||||
}""".strip()
|
||||
self.assertEqual(parser.get_schema(), expected_schema)
|
||||
|
||||
def test_model_with_optional(self):
|
||||
class OptionalModel(BaseModel):
|
||||
field1: Optional[int]
|
||||
|
||||
parser = PydanticSchemaParser(model=OptionalModel)
|
||||
expected_schema = """
|
||||
{
|
||||
field1: Optional[int]
|
||||
}""".strip()
|
||||
self.assertEqual(parser.get_schema(), expected_schema)
|
||||
|
||||
def test_nested_model(self):
|
||||
parser = PydanticSchemaParser(model=OuterModel)
|
||||
expected_schema = """
|
||||
{
|
||||
simple_field: str,
|
||||
list_field: List[int],
|
||||
optional_field: Optional[str],
|
||||
nested_model: InnerModel
|
||||
{
|
||||
inner_field: int
|
||||
}
|
||||
}""".strip()
|
||||
self.assertEqual(parser.get_schema(), expected_schema)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user