mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-30 23:02:50 +00:00
fix: allow classes in roundtrip check and use type-specific warning messages
This commit is contained in:
@@ -21,10 +21,11 @@ from pydantic.functional_serializers import PlainSerializer
|
||||
def _is_non_roundtrippable(fn: object) -> bool:
|
||||
"""Return ``True`` if *fn* cannot survive a serialize/deserialize round-trip.
|
||||
|
||||
Only plain module-level functions and built-in functions produce dotted
|
||||
paths that :func:`_resolve_dotted_path` can reliably resolve. Bound
|
||||
methods, ``functools.partial`` objects, callable class instances, lambdas,
|
||||
and closures all fail or silently change semantics during round-tripping.
|
||||
Built-in functions, plain module-level functions, and classes produce
|
||||
dotted paths that :func:`_resolve_dotted_path` can reliably resolve.
|
||||
Bound methods, ``functools.partial`` objects, callable class instances,
|
||||
lambdas, and closures all fail or silently change semantics during
|
||||
round-tripping.
|
||||
|
||||
Args:
|
||||
fn: The object to check.
|
||||
@@ -32,7 +33,7 @@ def _is_non_roundtrippable(fn: object) -> bool:
|
||||
Returns:
|
||||
``True`` if *fn* would not round-trip through JSON serialization.
|
||||
"""
|
||||
if inspect.isbuiltin(fn):
|
||||
if inspect.isbuiltin(fn) or inspect.isclass(fn):
|
||||
return False
|
||||
if inspect.isfunction(fn):
|
||||
qualname = getattr(fn, "__qualname__", "")
|
||||
@@ -44,9 +45,8 @@ def string_to_callable(value: Any) -> Callable[..., Any]:
|
||||
"""Convert a dotted path string to the callable it references.
|
||||
|
||||
If *value* is already callable it is returned as-is, with a warning if
|
||||
it is a lambda, closure, or nested function. Otherwise, it is treated
|
||||
as ``"module.qualname"`` and
|
||||
resolved via :func:`importlib.import_module`.
|
||||
it cannot survive JSON round-tripping. Otherwise, it is treated as
|
||||
``"module.qualname"`` and resolved via :func:`_resolve_dotted_path`.
|
||||
|
||||
Args:
|
||||
value: A callable or a dotted-path string e.g. ``"builtins.print"``.
|
||||
@@ -60,7 +60,7 @@ def string_to_callable(value: Any) -> Callable[..., Any]:
|
||||
if callable(value):
|
||||
if _is_non_roundtrippable(value):
|
||||
warnings.warn(
|
||||
"Lambdas, closures, and nested functions cannot be serialized "
|
||||
f"{type(value).__name__} callbacks cannot be serialized "
|
||||
"and will prevent checkpointing. "
|
||||
"Use a module-level named function instead.",
|
||||
UserWarning,
|
||||
|
||||
Reference in New Issue
Block a user