Fix pydantic_schema_parser to handle Python 3.10+ union syntax (types.UnionType)

- Add support for types.UnionType in addition to typing.Union
- Fix AttributeError when processing str | None syntax
- Add comprehensive tests for both union syntaxes
- Handle types without __name__ attribute gracefully
- Resolves issue #3074

Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
Devin AI
2025-06-26 16:26:03 +00:00
parent b09796cd3f
commit b4b6e0d803
2 changed files with 47 additions and 4 deletions

View File

@@ -1,3 +1,4 @@
import types
from typing import Dict, List, Type, Union, get_args, get_origin
from pydantic import BaseModel
@@ -34,7 +35,7 @@ class PydanticSchemaParser(BaseModel):
key_type, value_type = get_args(field_type)
return f"Dict[{key_type.__name__}, {value_type.__name__}]"
if origin is Union:
if origin is Union or (origin is None and len(get_args(field_type)) > 0):
return self._format_union_type(field_type, depth)
if isinstance(field_type, type) and issubclass(field_type, BaseModel):
@@ -42,7 +43,10 @@ class PydanticSchemaParser(BaseModel):
nested_indent = " " * 4 * depth
return f"{field_type.__name__}\n{nested_indent}{{\n{nested_schema}\n{nested_indent}}}"
return field_type.__name__
if hasattr(field_type, '__name__'):
return field_type.__name__
else:
return str(field_type)
def _format_list_type(self, list_item_type, depth: int) -> str:
if isinstance(list_item_type, type) and issubclass(list_item_type, BaseModel):
@@ -83,10 +87,13 @@ class PydanticSchemaParser(BaseModel):
if origin in {dict, Dict}:
key_type, value_type = get_args(annotation)
return f"Dict[{key_type.__name__}, {value_type.__name__}]"
if origin is Union:
if origin is Union or (origin is None and len(get_args(annotation)) > 0):
return self._format_union_type(annotation, depth)
if isinstance(annotation, type) and issubclass(annotation, BaseModel):
nested_schema = self._get_model_schema(annotation, depth)
nested_indent = " " * 4 * depth
return f"{annotation.__name__}\n{nested_indent}{{\n{nested_schema}\n{nested_indent}}}"
return annotation.__name__
if hasattr(annotation, '__name__'):
return annotation.__name__
else:
return str(annotation)

View File

@@ -92,3 +92,39 @@ def test_model_with_dict():
dict_field: Dict[str, int]
}"""
assert schema.strip() == expected_schema.strip()
def test_model_with_python310_union_syntax():
class UnionTypeModel(BaseModel):
union_field: str | None
multi_union_field: int | str | None
non_optional_union: int | str
parser = PydanticSchemaParser(model=UnionTypeModel)
schema = parser.get_schema()
expected_schema = """{
union_field: Optional[str],
multi_union_field: Optional[Union[int, str]],
non_optional_union: Union[int, str]
}"""
assert schema.strip() == expected_schema.strip()
def test_mixed_union_syntax():
class MixedUnionModel(BaseModel):
traditional_optional: Optional[str]
new_union_syntax: str | None
traditional_union: Union[int, str]
new_multi_union: int | str | float
parser = PydanticSchemaParser(model=MixedUnionModel)
schema = parser.get_schema()
expected_schema = """{
traditional_optional: Optional[str],
new_union_syntax: Optional[str],
traditional_union: Union[int, str],
new_multi_union: Union[int, str, float]
}"""
assert schema.strip() == expected_schema.strip()