mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 08:38:30 +00:00
Fix union issue that Daniel was running into (#1910)
This commit is contained in:
committed by
GitHub
parent
3fecde49b6
commit
30d027158a
@@ -241,9 +241,13 @@ def generate_model_description(model: Type[BaseModel]) -> str:
|
|||||||
origin = get_origin(field_type)
|
origin = get_origin(field_type)
|
||||||
args = get_args(field_type)
|
args = get_args(field_type)
|
||||||
|
|
||||||
if origin is Union and type(None) in args:
|
if origin is Union or (origin is None and len(args) > 0):
|
||||||
|
# Handle both Union and the new '|' syntax
|
||||||
non_none_args = [arg for arg in args if arg is not type(None)]
|
non_none_args = [arg for arg in args if arg is not type(None)]
|
||||||
|
if len(non_none_args) == 1:
|
||||||
return f"Optional[{describe_field(non_none_args[0])}]"
|
return f"Optional[{describe_field(non_none_args[0])}]"
|
||||||
|
else:
|
||||||
|
return f"Optional[Union[{', '.join(describe_field(arg) for arg in non_none_args)}]]"
|
||||||
elif origin is list:
|
elif origin is list:
|
||||||
return f"List[{describe_field(args[0])}]"
|
return f"List[{describe_field(args[0])}]"
|
||||||
elif origin is dict:
|
elif origin is dict:
|
||||||
@@ -252,8 +256,10 @@ def generate_model_description(model: Type[BaseModel]) -> str:
|
|||||||
return f"Dict[{key_type}, {value_type}]"
|
return f"Dict[{key_type}, {value_type}]"
|
||||||
elif isinstance(field_type, type) and issubclass(field_type, BaseModel):
|
elif isinstance(field_type, type) and issubclass(field_type, BaseModel):
|
||||||
return generate_model_description(field_type)
|
return generate_model_description(field_type)
|
||||||
else:
|
elif hasattr(field_type, "__name__"):
|
||||||
return field_type.__name__
|
return field_type.__name__
|
||||||
|
else:
|
||||||
|
return str(field_type)
|
||||||
|
|
||||||
fields = model.__annotations__
|
fields = model.__annotations__
|
||||||
field_descriptions = [
|
field_descriptions = [
|
||||||
|
|||||||
@@ -588,3 +588,12 @@ def test_converter_with_function_calling():
|
|||||||
assert output.name == "Eve"
|
assert output.name == "Eve"
|
||||||
assert output.age == 35
|
assert output.age == 35
|
||||||
instructor.to_pydantic.assert_called_once()
|
instructor.to_pydantic.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_model_description_union_field():
|
||||||
|
class UnionModel(BaseModel):
|
||||||
|
field: int | str | None
|
||||||
|
|
||||||
|
description = generate_model_description(UnionModel)
|
||||||
|
expected_description = '{\n "field": int | str | None\n}'
|
||||||
|
assert description == expected_description
|
||||||
|
|||||||
Reference in New Issue
Block a user