diff --git a/src/crewai/utilities/pydantic_schema_parser.py b/src/crewai/utilities/pydantic_schema_parser.py index 2827d70aa..65b180337 100644 --- a/src/crewai/utilities/pydantic_schema_parser.py +++ b/src/crewai/utilities/pydantic_schema_parser.py @@ -1,3 +1,4 @@ +import types from typing import Dict, List, Type, Union, get_args, get_origin from pydantic import BaseModel @@ -34,7 +35,7 @@ class PydanticSchemaParser(BaseModel): key_type, value_type = get_args(field_type) return f"Dict[{key_type.__name__}, {value_type.__name__}]" - if origin is Union: + if origin is Union or (origin is None and len(get_args(field_type)) > 0): return self._format_union_type(field_type, depth) if isinstance(field_type, type) and issubclass(field_type, BaseModel): @@ -42,7 +43,10 @@ class PydanticSchemaParser(BaseModel): nested_indent = " " * 4 * depth return f"{field_type.__name__}\n{nested_indent}{{\n{nested_schema}\n{nested_indent}}}" - return field_type.__name__ + if hasattr(field_type, '__name__'): + return field_type.__name__ + else: + return str(field_type) def _format_list_type(self, list_item_type, depth: int) -> str: if isinstance(list_item_type, type) and issubclass(list_item_type, BaseModel): @@ -83,10 +87,13 @@ class PydanticSchemaParser(BaseModel): if origin in {dict, Dict}: key_type, value_type = get_args(annotation) return f"Dict[{key_type.__name__}, {value_type.__name__}]" - if origin is Union: + if origin is Union or (origin is None and len(get_args(annotation)) > 0): return self._format_union_type(annotation, depth) if isinstance(annotation, type) and issubclass(annotation, BaseModel): nested_schema = self._get_model_schema(annotation, depth) nested_indent = " " * 4 * depth return f"{annotation.__name__}\n{nested_indent}{{\n{nested_schema}\n{nested_indent}}}" - return annotation.__name__ + if hasattr(annotation, '__name__'): + return annotation.__name__ + else: + return str(annotation) diff --git a/tests/utilities/test_pydantic_schema_parser.py b/tests/utilities/test_pydantic_schema_parser.py index ee6d7e287..5ac8e6b58 100644 --- a/tests/utilities/test_pydantic_schema_parser.py +++ b/tests/utilities/test_pydantic_schema_parser.py @@ -92,3 +92,39 @@ def test_model_with_dict(): dict_field: Dict[str, int] }""" assert schema.strip() == expected_schema.strip() + + +def test_model_with_python310_union_syntax(): + class UnionTypeModel(BaseModel): + union_field: str | None + multi_union_field: int | str | None + non_optional_union: int | str + + parser = PydanticSchemaParser(model=UnionTypeModel) + schema = parser.get_schema() + + expected_schema = """{ + union_field: Optional[str], + multi_union_field: Optional[Union[int, str]], + non_optional_union: Union[int, str] +}""" + assert schema.strip() == expected_schema.strip() + + +def test_mixed_union_syntax(): + class MixedUnionModel(BaseModel): + traditional_optional: Optional[str] + new_union_syntax: str | None + traditional_union: Union[int, str] + new_multi_union: int | str | float + + parser = PydanticSchemaParser(model=MixedUnionModel) + schema = parser.get_schema() + + expected_schema = """{ + traditional_optional: Optional[str], + new_union_syntax: Optional[str], + traditional_union: Union[int, str], + new_multi_union: Union[int, str, float] +}""" + assert schema.strip() == expected_schema.strip()