mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
Fix pydantic_schema_parser to handle Python 3.10+ union syntax (types.UnionType)
- Add support for types.UnionType in addition to typing.Union - Fix AttributeError when processing str | None syntax - Add comprehensive tests for both union syntaxes - Handle types without __name__ attribute gracefully - Resolves issue #3074 Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -1,3 +1,4 @@
|
|||||||
|
import types
|
||||||
from typing import Dict, List, Type, Union, get_args, get_origin
|
from typing import Dict, List, Type, Union, get_args, get_origin
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -34,7 +35,7 @@ class PydanticSchemaParser(BaseModel):
|
|||||||
key_type, value_type = get_args(field_type)
|
key_type, value_type = get_args(field_type)
|
||||||
return f"Dict[{key_type.__name__}, {value_type.__name__}]"
|
return f"Dict[{key_type.__name__}, {value_type.__name__}]"
|
||||||
|
|
||||||
if origin is Union:
|
if origin is Union or (origin is None and len(get_args(field_type)) > 0):
|
||||||
return self._format_union_type(field_type, depth)
|
return self._format_union_type(field_type, depth)
|
||||||
|
|
||||||
if isinstance(field_type, type) and issubclass(field_type, BaseModel):
|
if isinstance(field_type, type) and issubclass(field_type, BaseModel):
|
||||||
@@ -42,7 +43,10 @@ class PydanticSchemaParser(BaseModel):
|
|||||||
nested_indent = " " * 4 * depth
|
nested_indent = " " * 4 * depth
|
||||||
return f"{field_type.__name__}\n{nested_indent}{{\n{nested_schema}\n{nested_indent}}}"
|
return f"{field_type.__name__}\n{nested_indent}{{\n{nested_schema}\n{nested_indent}}}"
|
||||||
|
|
||||||
return field_type.__name__
|
if hasattr(field_type, '__name__'):
|
||||||
|
return field_type.__name__
|
||||||
|
else:
|
||||||
|
return str(field_type)
|
||||||
|
|
||||||
def _format_list_type(self, list_item_type, depth: int) -> str:
|
def _format_list_type(self, list_item_type, depth: int) -> str:
|
||||||
if isinstance(list_item_type, type) and issubclass(list_item_type, BaseModel):
|
if isinstance(list_item_type, type) and issubclass(list_item_type, BaseModel):
|
||||||
@@ -83,10 +87,13 @@ class PydanticSchemaParser(BaseModel):
|
|||||||
if origin in {dict, Dict}:
|
if origin in {dict, Dict}:
|
||||||
key_type, value_type = get_args(annotation)
|
key_type, value_type = get_args(annotation)
|
||||||
return f"Dict[{key_type.__name__}, {value_type.__name__}]"
|
return f"Dict[{key_type.__name__}, {value_type.__name__}]"
|
||||||
if origin is Union:
|
if origin is Union or (origin is None and len(get_args(annotation)) > 0):
|
||||||
return self._format_union_type(annotation, depth)
|
return self._format_union_type(annotation, depth)
|
||||||
if isinstance(annotation, type) and issubclass(annotation, BaseModel):
|
if isinstance(annotation, type) and issubclass(annotation, BaseModel):
|
||||||
nested_schema = self._get_model_schema(annotation, depth)
|
nested_schema = self._get_model_schema(annotation, depth)
|
||||||
nested_indent = " " * 4 * depth
|
nested_indent = " " * 4 * depth
|
||||||
return f"{annotation.__name__}\n{nested_indent}{{\n{nested_schema}\n{nested_indent}}}"
|
return f"{annotation.__name__}\n{nested_indent}{{\n{nested_schema}\n{nested_indent}}}"
|
||||||
return annotation.__name__
|
if hasattr(annotation, '__name__'):
|
||||||
|
return annotation.__name__
|
||||||
|
else:
|
||||||
|
return str(annotation)
|
||||||
|
|||||||
@@ -92,3 +92,39 @@ def test_model_with_dict():
|
|||||||
dict_field: Dict[str, int]
|
dict_field: Dict[str, int]
|
||||||
}"""
|
}"""
|
||||||
assert schema.strip() == expected_schema.strip()
|
assert schema.strip() == expected_schema.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_with_python310_union_syntax():
|
||||||
|
class UnionTypeModel(BaseModel):
|
||||||
|
union_field: str | None
|
||||||
|
multi_union_field: int | str | None
|
||||||
|
non_optional_union: int | str
|
||||||
|
|
||||||
|
parser = PydanticSchemaParser(model=UnionTypeModel)
|
||||||
|
schema = parser.get_schema()
|
||||||
|
|
||||||
|
expected_schema = """{
|
||||||
|
union_field: Optional[str],
|
||||||
|
multi_union_field: Optional[Union[int, str]],
|
||||||
|
non_optional_union: Union[int, str]
|
||||||
|
}"""
|
||||||
|
assert schema.strip() == expected_schema.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def test_mixed_union_syntax():
|
||||||
|
class MixedUnionModel(BaseModel):
|
||||||
|
traditional_optional: Optional[str]
|
||||||
|
new_union_syntax: str | None
|
||||||
|
traditional_union: Union[int, str]
|
||||||
|
new_multi_union: int | str | float
|
||||||
|
|
||||||
|
parser = PydanticSchemaParser(model=MixedUnionModel)
|
||||||
|
schema = parser.get_schema()
|
||||||
|
|
||||||
|
expected_schema = """{
|
||||||
|
traditional_optional: Optional[str],
|
||||||
|
new_union_syntax: Optional[str],
|
||||||
|
traditional_union: Union[int, str],
|
||||||
|
new_multi_union: Union[int, str, float]
|
||||||
|
}"""
|
||||||
|
assert schema.strip() == expected_schema.strip()
|
||||||
|
|||||||
Reference in New Issue
Block a user