fix: detect closures and nested functions as non-roundtrippable callbacks

This commit is contained in:
Greyson LaLonde
2026-03-19 19:45:05 -04:00
parent aadb2d6694
commit 0d6b2ef8b8

View File

@@ -18,29 +18,34 @@ from pydantic import BeforeValidator, WithJsonSchema
from pydantic.functional_serializers import PlainSerializer
def _is_lambda(fn: object) -> bool:
"""Return ``True`` if *fn* is a lambda expression.
def _is_non_roundtrippable(fn: object) -> bool:
"""Return ``True`` if *fn* cannot survive a serialize/deserialize round-trip.
Uses ``__qualname__`` ending with ``"<lambda>"`` for resilience against
``__name__`` being reassigned. ``inspect.isfunction`` gates the check
so non-function callables (classes, partials, etc.) are never flagged.
Detects lambdas via ``__qualname__`` ending with ``"<lambda>"`` and
closures or nested functions via ``"<locals>"`` appearing anywhere in
``__qualname__``. Both produce dotted paths that
:func:`string_to_callable` cannot resolve back to the original object.
``inspect.isfunction`` gates the check so non-function callables
like classes and partials are never flagged.
Args:
fn: The object to check.
Returns:
``True`` if *fn* is a lambda, ``False`` otherwise.
``True`` if *fn* is a lambda, closure, or nested function.
"""
return inspect.isfunction(fn) and getattr(fn, "__qualname__", "").endswith(
"<lambda>"
)
if not inspect.isfunction(fn):
return False
qualname = getattr(fn, "__qualname__", "")
return qualname.endswith("<lambda>") or "<locals>" in qualname
def string_to_callable(value: Any) -> Callable[..., Any]:
"""Convert a dotted path string to the callable it references.
If *value* is already callable it is returned as-is, with a warning if
it is a lambda. Otherwise, it is treated as ``"module.qualname"`` and
it is a lambda, closure, or nested function. Otherwise, it is treated
as ``"module.qualname"`` and
resolved via :func:`importlib.import_module`.
Args:
@@ -54,9 +59,10 @@ def string_to_callable(value: Any) -> Callable[..., Any]:
AttributeError: If the attribute cannot be found on the imported module.
"""
if callable(value):
if _is_lambda(value):
if _is_non_roundtrippable(value):
warnings.warn(
"Lambdas cannot be serialized and will prevent checkpointing. "
"Lambdas, closures, and nested functions cannot be serialized "
"and will prevent checkpointing. "
"Use a module-level named function instead.",
UserWarning,
stacklevel=2,