Fix Union type handling in Pydantic output generation

- Fixed generate_model_description to correctly handle Union types
- Union types without None are now properly formatted as Union[type1, type2]
- Union types with None are correctly wrapped in Optional[Union[...]]
- Added support for Python 3.10+ pipe syntax (int | str | None)
- Added comprehensive tests for Union type support
- Updated existing test expectations to match corrected behavior

This fixes issue #3735 where Union types were incorrectly wrapped in Optional
even when None was not part of the Union.

Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
Devin AI
2025-10-19 21:33:03 +00:00
parent 42f2b4d551
commit f76c55ffa7
3 changed files with 220 additions and 5 deletions

View File

@@ -2,11 +2,18 @@ from __future__ import annotations
import json
import re
import sys
from typing import TYPE_CHECKING, Any, Final, TypedDict, Union, get_args, get_origin
from pydantic import BaseModel, ValidationError
from typing_extensions import Unpack
if sys.version_info >= (3, 10):
import types
UnionTypes = (Union, types.UnionType)
else:
UnionTypes = (Union,)
from crewai.agents.agent_builder.utilities.base_output_converter import OutputConverter
from crewai.utilities.internal_instructor import InternalInstructor
from crewai.utilities.printer import Printer
@@ -428,12 +435,21 @@ def generate_model_description(model: type[BaseModel]) -> str:
origin = get_origin(field_type)
args = get_args(field_type)
if origin is Union or (origin is None and len(args) > 0):
if origin in UnionTypes or (origin is None and len(args) > 0):
# Handle both Union and the new '|' syntax
non_none_args = [arg for arg in args if arg is not type(None)]
if len(non_none_args) == 1:
return f"Optional[{describe_field(non_none_args[0])}]"
return f"Optional[Union[{', '.join(describe_field(arg) for arg in non_none_args)}]]"
has_none = type(None) in args
if has_none:
# It's an Optional type
if len(non_none_args) == 1:
return f"Optional[{describe_field(non_none_args[0])}]"
# Union with None and multiple other types
return f"Optional[Union[{', '.join(describe_field(arg) for arg in non_none_args)}]]"
else:
if len(non_none_args) == 1:
return describe_field(non_none_args[0])
return f"Union[{', '.join(describe_field(arg) for arg in args)}]"
if origin is list:
return f"List[{describe_field(args[0])}]"
if origin is dict:

199
tests/test_union_types.py Normal file
View File

@@ -0,0 +1,199 @@
"""Test Union type support in Pydantic outputs."""
import json
from typing import Union
from unittest.mock import Mock, patch
import pytest
from pydantic import BaseModel
from crewai.utilities.converter import (
convert_to_model,
generate_model_description,
)
from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser
class SuccessData(BaseModel):
"""Model for successful response."""
status: str
result: str
value: int
class ErrorMessage(BaseModel):
"""Model for error response."""
status: str
error: str
code: int
class ResponseWithUnion(BaseModel):
"""Model with Union type field."""
response: Union[SuccessData, ErrorMessage]
class DirectUnionModel(BaseModel):
"""Model with direct Union type."""
data: Union[str, int, dict]
class MultiUnionModel(BaseModel):
"""Model with multiple Union types."""
field1: Union[str, int]
field2: Union[SuccessData, ErrorMessage, None]
def test_convert_to_model_with_union_success_data():
"""Test converting JSON to a model with Union type (SuccessData variant)."""
result = json.dumps({
"response": {
"status": "success",
"result": "Operation completed",
"value": 42
}
})
output = convert_to_model(result, ResponseWithUnion, None, None)
assert isinstance(output, ResponseWithUnion)
assert isinstance(output.response, SuccessData)
assert output.response.status == "success"
assert output.response.result == "Operation completed"
assert output.response.value == 42
def test_convert_to_model_with_union_error_message():
"""Test converting JSON to a model with Union type (ErrorMessage variant)."""
result = json.dumps({
"response": {
"status": "error",
"error": "Something went wrong",
"code": 500
}
})
output = convert_to_model(result, ResponseWithUnion, None, None)
assert isinstance(output, ResponseWithUnion)
assert isinstance(output.response, ErrorMessage)
assert output.response.status == "error"
assert output.response.error == "Something went wrong"
assert output.response.code == 500
def test_convert_to_model_with_direct_union_string():
"""Test converting JSON to a model with direct Union type (string variant)."""
result = json.dumps({"data": "hello world"})
output = convert_to_model(result, DirectUnionModel, None, None)
assert isinstance(output, DirectUnionModel)
assert isinstance(output.data, str)
assert output.data == "hello world"
def test_convert_to_model_with_direct_union_int():
"""Test converting JSON to a model with direct Union type (int variant)."""
result = json.dumps({"data": 42})
output = convert_to_model(result, DirectUnionModel, None, None)
assert isinstance(output, DirectUnionModel)
assert isinstance(output.data, int)
assert output.data == 42
def test_convert_to_model_with_direct_union_dict():
"""Test converting JSON to a model with direct Union type (dict variant)."""
result = json.dumps({"data": {"key": "value", "number": 123}})
output = convert_to_model(result, DirectUnionModel, None, None)
assert isinstance(output, DirectUnionModel)
assert isinstance(output.data, dict)
assert output.data == {"key": "value", "number": 123}
def test_convert_to_model_with_multiple_unions():
"""Test converting JSON to a model with multiple Union type fields."""
result = json.dumps({
"field1": "text",
"field2": {
"status": "success",
"result": "Done",
"value": 100
}
})
output = convert_to_model(result, MultiUnionModel, None, None)
assert isinstance(output, MultiUnionModel)
assert isinstance(output.field1, str)
assert output.field1 == "text"
assert isinstance(output.field2, SuccessData)
assert output.field2.status == "success"
def test_convert_to_model_with_optional_union_none():
"""Test converting JSON to a model with optional Union type (None variant)."""
result = json.dumps({
"field1": 42,
"field2": None
})
output = convert_to_model(result, MultiUnionModel, None, None)
assert isinstance(output, MultiUnionModel)
assert isinstance(output.field1, int)
assert output.field1 == 42
assert output.field2 is None
def test_generate_model_description_with_union():
"""Test that generate_model_description handles Union types correctly."""
description = generate_model_description(ResponseWithUnion)
assert "Union" in description
assert "Optional" not in description
assert "status" in description
print(f"Generated description:\n{description}")
def test_generate_model_description_with_direct_union():
"""Test that generate_model_description handles direct Union types correctly."""
description = generate_model_description(DirectUnionModel)
assert "Union" in description
assert "Optional" not in description
assert "str" in description and "int" in description and "dict" in description
print(f"Generated description:\n{description}")
def test_pydantic_schema_parser_with_union():
"""Test that PydanticSchemaParser handles Union types correctly."""
parser = PydanticSchemaParser(model=ResponseWithUnion)
schema = parser.get_schema()
assert "Union" in schema or "SuccessData" in schema or "ErrorMessage" in schema
print(f"Generated schema:\n{schema}")
def test_pydantic_schema_parser_with_direct_union():
"""Test that PydanticSchemaParser handles direct Union types correctly."""
parser = PydanticSchemaParser(model=DirectUnionModel)
schema = parser.get_schema()
assert "Union" in schema or ("str" in schema and "int" in schema and "dict" in schema)
print(f"Generated schema:\n{schema}")
def test_pydantic_schema_parser_with_optional_union():
"""Test that PydanticSchemaParser handles Optional Union types correctly."""
parser = PydanticSchemaParser(model=MultiUnionModel)
schema = parser.get_schema()
assert "Union" in schema or "Optional" in schema
print(f"Generated schema:\n{schema}")
def test_generate_model_description_with_optional_union():
"""Test that generate_model_description correctly wraps Optional Union types."""
description = generate_model_description(MultiUnionModel)
assert "field1" in description
assert "field2" in description
assert "Optional" in description
print(f"Generated description:\n{description}")

View File

@@ -596,5 +596,5 @@ def test_generate_model_description_union_field():
field: int | str | None
description = generate_model_description(UnionModel)
expected_description = '{\n "field": int | str | None\n}'
expected_description = '{\n "field": Optional[Union[int, str]]\n}'
assert description == expected_description