From f76c55ffa7099cc52bcf310178aac9dd49e35fc8 Mon Sep 17 00:00:00 2001
From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Date: Sun, 19 Oct 2025 21:33:03 +0000
Subject: [PATCH] Fix Union type handling in Pydantic output generation
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Fixed generate_model_description to correctly handle Union types
- Union types without None are now properly formatted as Union[type1, type2]
- Union types with None are correctly wrapped in Optional[Union[...]]
- Added support for Python 3.10+ pipe syntax (int | str | None)
- Added comprehensive tests for Union type support
- Updated existing test expectations to match corrected behavior

This fixes issue #3735 where Union types were incorrectly wrapped in
Optional even when None was not part of the Union.

Co-Authored-By: João
---
 src/crewai/utilities/converter.py |  24 +++-
 tests/test_union_types.py         | 199 ++++++++++++++++++++++++++++++
 tests/utilities/test_converter.py |   2 +-
 3 files changed, 220 insertions(+), 5 deletions(-)
 create mode 100644 tests/test_union_types.py

diff --git a/src/crewai/utilities/converter.py b/src/crewai/utilities/converter.py
index 07f6f7ea3..a650ab4c1 100644
--- a/src/crewai/utilities/converter.py
+++ b/src/crewai/utilities/converter.py
@@ -2,11 +2,18 @@ from __future__ import annotations
 
 import json
 import re
+import sys
 from typing import TYPE_CHECKING, Any, Final, TypedDict, Union, get_args, get_origin
 
 from pydantic import BaseModel, ValidationError
 from typing_extensions import Unpack
 
+if sys.version_info >= (3, 10):
+    import types
+    UnionTypes = (Union, types.UnionType)
+else:
+    UnionTypes = (Union,)
+
 from crewai.agents.agent_builder.utilities.base_output_converter import OutputConverter
 from crewai.utilities.internal_instructor import InternalInstructor
 from crewai.utilities.printer import Printer
@@ -428,12 +435,21 @@ def generate_model_description(model: type[BaseModel]) -> str:
         origin = get_origin(field_type)
         args = get_args(field_type)
 
-        if origin is Union or (origin is None and len(args) > 0):
+        if origin in UnionTypes or (origin is None and len(args) > 0):
             # Handle both Union and the new '|' syntax
             non_none_args = [arg for arg in args if arg is not type(None)]
-            if len(non_none_args) == 1:
-                return f"Optional[{describe_field(non_none_args[0])}]"
-            return f"Optional[Union[{', '.join(describe_field(arg) for arg in non_none_args)}]]"
+            has_none = type(None) in args
+
+            if has_none:
+                # It's an Optional type
+                if len(non_none_args) == 1:
+                    return f"Optional[{describe_field(non_none_args[0])}]"
+                # Union with None and multiple other types
+                return f"Optional[Union[{', '.join(describe_field(arg) for arg in non_none_args)}]]"
+            else:
+                if len(non_none_args) == 1:
+                    return describe_field(non_none_args[0])
+                return f"Union[{', '.join(describe_field(arg) for arg in args)}]"
         if origin is list:
             return f"List[{describe_field(args[0])}]"
         if origin is dict:
diff --git a/tests/test_union_types.py b/tests/test_union_types.py
new file mode 100644
index 000000000..7bbac8730
--- /dev/null
+++ b/tests/test_union_types.py
@@ -0,0 +1,199 @@
+"""Test Union type support in Pydantic outputs."""
+import json
+from typing import Union
+from unittest.mock import Mock, patch
+
+import pytest
+from pydantic import BaseModel
+
+from crewai.utilities.converter import (
+    convert_to_model,
+    generate_model_description,
+)
+from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser
+
+
+class SuccessData(BaseModel):
+    """Model for successful response."""
+    status: str
+    result: str
+    value: int
+
+
+class ErrorMessage(BaseModel):
+    """Model for error response."""
+    status: str
+    error: str
+    code: int
+
+
+class ResponseWithUnion(BaseModel):
+    """Model with Union type field."""
+    response: Union[SuccessData, ErrorMessage]
+
+
+class DirectUnionModel(BaseModel):
+    """Model with direct Union type."""
+    data: Union[str, int, dict]
+
+
+class MultiUnionModel(BaseModel):
+    """Model with multiple Union types."""
+    field1: Union[str, int]
+    field2: Union[SuccessData, ErrorMessage, None]
+
+
+def test_convert_to_model_with_union_success_data():
+    """Test converting JSON to a model with Union type (SuccessData variant)."""
+    result = json.dumps({
+        "response": {
+            "status": "success",
+            "result": "Operation completed",
+            "value": 42
+        }
+    })
+
+    output = convert_to_model(result, ResponseWithUnion, None, None)
+    assert isinstance(output, ResponseWithUnion)
+    assert isinstance(output.response, SuccessData)
+    assert output.response.status == "success"
+    assert output.response.result == "Operation completed"
+    assert output.response.value == 42
+
+
+def test_convert_to_model_with_union_error_message():
+    """Test converting JSON to a model with Union type (ErrorMessage variant)."""
+    result = json.dumps({
+        "response": {
+            "status": "error",
+            "error": "Something went wrong",
+            "code": 500
+        }
+    })
+
+    output = convert_to_model(result, ResponseWithUnion, None, None)
+    assert isinstance(output, ResponseWithUnion)
+    assert isinstance(output.response, ErrorMessage)
+    assert output.response.status == "error"
+    assert output.response.error == "Something went wrong"
+    assert output.response.code == 500
+
+
+def test_convert_to_model_with_direct_union_string():
+    """Test converting JSON to a model with direct Union type (string variant)."""
+    result = json.dumps({"data": "hello world"})
+
+    output = convert_to_model(result, DirectUnionModel, None, None)
+    assert isinstance(output, DirectUnionModel)
+    assert isinstance(output.data, str)
+    assert output.data == "hello world"
+
+
+def test_convert_to_model_with_direct_union_int():
+    """Test converting JSON to a model with direct Union type (int variant)."""
+    result = json.dumps({"data": 42})
+
+    output = convert_to_model(result, DirectUnionModel, None, None)
+    assert isinstance(output, DirectUnionModel)
+    assert isinstance(output.data, int)
+    assert output.data == 42
+
+
+def test_convert_to_model_with_direct_union_dict():
+    """Test converting JSON to a model with direct Union type (dict variant)."""
+    result = json.dumps({"data": {"key": "value", "number": 123}})
+
+    output = convert_to_model(result, DirectUnionModel, None, None)
+    assert isinstance(output, DirectUnionModel)
+    assert isinstance(output.data, dict)
+    assert output.data == {"key": "value", "number": 123}
+
+
+def test_convert_to_model_with_multiple_unions():
+    """Test converting JSON to a model with multiple Union type fields."""
+    result = json.dumps({
+        "field1": "text",
+        "field2": {
+            "status": "success",
+            "result": "Done",
+            "value": 100
+        }
+    })
+
+    output = convert_to_model(result, MultiUnionModel, None, None)
+    assert isinstance(output, MultiUnionModel)
+    assert isinstance(output.field1, str)
+    assert output.field1 == "text"
+    assert isinstance(output.field2, SuccessData)
+    assert output.field2.status == "success"
+
+
+def test_convert_to_model_with_optional_union_none():
+    """Test converting JSON to a model with optional Union type (None variant)."""
+    result = json.dumps({
+        "field1": 42,
+        "field2": None
+    })
+
+    output = convert_to_model(result, MultiUnionModel, None, None)
+    assert isinstance(output, MultiUnionModel)
+    assert isinstance(output.field1, int)
+    assert output.field1 == 42
+    assert output.field2 is None
+
+
+def test_generate_model_description_with_union():
+    """Test that generate_model_description handles Union types correctly."""
+    description = generate_model_description(ResponseWithUnion)
+
+    assert "Union" in description
+    assert "Optional" not in description
+    assert "status" in description
+    print(f"Generated description:\n{description}")
+
+
+def test_generate_model_description_with_direct_union():
+    """Test that generate_model_description handles direct Union types correctly."""
+    description = generate_model_description(DirectUnionModel)
+
+    assert "Union" in description
+    assert "Optional" not in description
+    assert "str" in description and "int" in description and "dict" in description
+    print(f"Generated description:\n{description}")
+
+
+def test_pydantic_schema_parser_with_union():
+    """Test that PydanticSchemaParser handles Union types correctly."""
+    parser = PydanticSchemaParser(model=ResponseWithUnion)
+    schema = parser.get_schema()
+
+    assert "Union" in schema or "SuccessData" in schema or "ErrorMessage" in schema
+    print(f"Generated schema:\n{schema}")
+
+
+def test_pydantic_schema_parser_with_direct_union():
+    """Test that PydanticSchemaParser handles direct Union types correctly."""
+    parser = PydanticSchemaParser(model=DirectUnionModel)
+    schema = parser.get_schema()
+
+    assert "Union" in schema or ("str" in schema and "int" in schema and "dict" in schema)
+    print(f"Generated schema:\n{schema}")
+
+
+def test_pydantic_schema_parser_with_optional_union():
+    """Test that PydanticSchemaParser handles Optional Union types correctly."""
+    parser = PydanticSchemaParser(model=MultiUnionModel)
+    schema = parser.get_schema()
+
+    assert "Union" in schema or "Optional" in schema
+    print(f"Generated schema:\n{schema}")
+
+
+def test_generate_model_description_with_optional_union():
+    """Test that generate_model_description correctly wraps Optional Union types."""
+    description = generate_model_description(MultiUnionModel)
+
+    assert "field1" in description
+    assert "field2" in description
+    assert "Optional" in description
+    print(f"Generated description:\n{description}")
diff --git a/tests/utilities/test_converter.py b/tests/utilities/test_converter.py
index 7ebc52bed..e425f728d 100644
--- a/tests/utilities/test_converter.py
+++ b/tests/utilities/test_converter.py
@@ -596,5 +596,5 @@ def test_generate_model_description_union_field():
         field: int | str | None
 
     description = generate_model_description(UnionModel)
-    expected_description = '{\n  "field": int | str | None\n}'
+    expected_description = '{\n  "field": Optional[Union[int, str]]\n}'
     assert description == expected_description