fix: convert remaining callback lists to SerializableCallable and add tests

Address PR review: convert before_kickoff_callbacks, after_kickoff_callbacks
(crew.py) and callbacks (base_agent.py) from plain Callable to
SerializableCallable for consistent JSON round-trip support. Add 36 unit
tests for callback.py covering serialization, deserialization, env-var
gating, round-trip detection, and full Pydantic integration.
This commit is contained in:
Greyson LaLonde
2026-03-20 13:29:29 -04:00
parent 0a3d2f596e
commit 9929aea830
3 changed files with 241 additions and 6 deletions

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Callable
from copy import copy as shallow_copy
from hashlib import md5
import re
@@ -34,6 +33,7 @@ from crewai.memory.unified_memory import Memory
from crewai.rag.embeddings.types import EmbedderConfig
from crewai.security.security_config import SecurityConfig
from crewai.tools.base_tool import BaseTool, Tool
from crewai.types.callback import SerializableCallable
from crewai.utilities.config import process_config
from crewai.utilities.i18n import I18N, get_i18n
from crewai.utilities.logger import Logger
@@ -191,7 +191,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
default_factory=SecurityConfig,
description="Security configuration for the agent, including fingerprinting.",
)
callbacks: list[Callable[[Any], Any]] = Field(
callbacks: list[SerializableCallable] = Field(
default_factory=list, description="Callbacks to be used for the agent"
)
adapted_agent: bool = Field(

View File

@@ -228,16 +228,14 @@ class Crew(FlowTrackable, BaseModel):
default=None,
description="Callback to be executed after each task for all agents execution.",
)
before_kickoff_callbacks: list[
Callable[[dict[str, Any] | None], dict[str, Any] | None]
] = Field(
before_kickoff_callbacks: list[SerializableCallable] = Field(
default_factory=list,
description=(
"List of callbacks to be executed before crew kickoff. "
"It may be used to adjust inputs before the crew is executed."
),
)
after_kickoff_callbacks: list[Callable[[CrewOutput], CrewOutput]] = Field(
after_kickoff_callbacks: list[SerializableCallable] = Field(
default_factory=list,
description=(
"List of callbacks to be executed after crew kickoff. "

View File

@@ -0,0 +1,237 @@
"""Tests for crewai.types.callback — SerializableCallable round-tripping."""
from __future__ import annotations
import functools
import os
from typing import Any
import pytest
from pydantic import BaseModel, ValidationError
from crewai.types.callback import (
SerializableCallable,
_is_non_roundtrippable,
_resolve_dotted_path,
callable_to_string,
string_to_callable,
)
# ── Helpers ──────────────────────────────────────────────────────────
def module_level_function() -> str:
"""Plain module-level function that should round-trip."""
return "hello"
class _CallableInstance:
"""Callable class instance — non-roundtrippable."""
def __call__(self) -> str:
return "instance"
class _HasMethod:
def method(self) -> str:
return "method"
class _Model(BaseModel):
cb: SerializableCallable | None = None
# ── _is_non_roundtrippable ───────────────────────────────────────────
class TestIsNonRoundtrippable:
def test_builtin_is_roundtrippable(self) -> None:
assert _is_non_roundtrippable(print) is False
assert _is_non_roundtrippable(len) is False
def test_class_is_roundtrippable(self) -> None:
assert _is_non_roundtrippable(dict) is False
assert _is_non_roundtrippable(_CallableInstance) is False
def test_module_level_function_is_roundtrippable(self) -> None:
assert _is_non_roundtrippable(module_level_function) is False
def test_lambda_is_non_roundtrippable(self) -> None:
assert _is_non_roundtrippable(lambda: None) is True
def test_closure_is_non_roundtrippable(self) -> None:
x = 1
def closure() -> int:
return x
assert _is_non_roundtrippable(closure) is True
def test_bound_method_is_non_roundtrippable(self) -> None:
assert _is_non_roundtrippable(_HasMethod().method) is True
def test_partial_is_non_roundtrippable(self) -> None:
assert _is_non_roundtrippable(functools.partial(print, "hi")) is True
def test_callable_instance_is_non_roundtrippable(self) -> None:
assert _is_non_roundtrippable(_CallableInstance()) is True
# ── callable_to_string ───────────────────────────────────────────────
class TestCallableToString:
def test_module_level_function(self) -> None:
result = callable_to_string(module_level_function)
assert result == f"{__name__}.module_level_function"
def test_class(self) -> None:
result = callable_to_string(dict)
assert result == "builtins.dict"
def test_builtin(self) -> None:
result = callable_to_string(print)
assert result == "builtins.print"
def test_lambda_produces_locals_path(self) -> None:
fn = lambda: None # noqa: E731
result = callable_to_string(fn)
assert "<lambda>" in result
def test_missing_qualname_raises(self) -> None:
obj = type("NoQual", (), {"__module__": "test"})()
obj.__qualname__ = None # type: ignore[assignment]
with pytest.raises(ValueError, match="missing __module__ or __qualname__"):
callable_to_string(obj)
def test_missing_module_raises(self) -> None:
# Create an object where getattr(obj, "__module__", None) returns None
ns: dict[str, Any] = {"__qualname__": "x", "__module__": None}
obj = type("NoMod", (), ns)()
with pytest.raises(ValueError, match="missing __module__"):
callable_to_string(obj)
# ── string_to_callable ───────────────────────────────────────────────
class TestStringToCallable:
def test_callable_passthrough(self) -> None:
assert string_to_callable(print) is print
def test_roundtrippable_callable_no_warning(self, recwarn: pytest.WarningsChecker) -> None:
string_to_callable(module_level_function)
our_warnings = [
w for w in recwarn if "cannot be serialized" in str(w.message)
]
assert our_warnings == []
def test_non_roundtrippable_warns(self) -> None:
with pytest.warns(UserWarning, match="cannot be serialized"):
string_to_callable(functools.partial(print))
def test_non_callable_non_string_raises(self) -> None:
with pytest.raises(ValueError, match="Expected a callable"):
string_to_callable(42)
def test_string_without_dot_raises(self) -> None:
with pytest.raises(ValueError, match="expected 'module.name' format"):
string_to_callable("nodots")
def test_string_refused_without_env_var(self, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.delenv("CREWAI_DESERIALIZE_CALLBACKS", raising=False)
with pytest.raises(ValueError, match="Refusing to resolve"):
string_to_callable("builtins.print")
def test_string_resolves_with_env_var(self, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("CREWAI_DESERIALIZE_CALLBACKS", "1")
result = string_to_callable("builtins.print")
assert result is print
def test_string_resolves_multi_level_path(self, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("CREWAI_DESERIALIZE_CALLBACKS", "1")
result = string_to_callable("os.path.join")
assert result is os.path.join
def test_unresolvable_path_raises(self, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("CREWAI_DESERIALIZE_CALLBACKS", "1")
with pytest.raises(ValueError, match="Cannot resolve"):
string_to_callable("nonexistent.module.func")
# ── _resolve_dotted_path ─────────────────────────────────────────────
class TestResolveDottedPath:
def test_builtin(self) -> None:
assert _resolve_dotted_path("builtins.print") is print
def test_nested_module_attribute(self) -> None:
assert _resolve_dotted_path("os.path.join") is os.path.join
def test_class_on_module(self) -> None:
from collections import OrderedDict
assert _resolve_dotted_path("collections.OrderedDict") is OrderedDict
def test_nonexistent_raises(self) -> None:
with pytest.raises(ValueError, match="Cannot resolve"):
_resolve_dotted_path("no.such.module.func")
def test_non_callable_attribute_skipped(self) -> None:
# os.sep is a string, not callable — should not resolve
with pytest.raises(ValueError, match="Cannot resolve"):
_resolve_dotted_path("os.sep")
# ── Pydantic integration round-trip ──────────────────────────────────
class TestSerializableCallableRoundTrip:
def test_json_serialize_module_function(self) -> None:
m = _Model(cb=module_level_function)
data = m.model_dump(mode="json")
assert data["cb"] == f"{__name__}.module_level_function"
def test_json_round_trip(self, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("CREWAI_DESERIALIZE_CALLBACKS", "1")
m = _Model(cb=print)
json_str = m.model_dump_json()
restored = _Model.model_validate_json(json_str)
assert restored.cb is print
def test_json_round_trip_class(self, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("CREWAI_DESERIALIZE_CALLBACKS", "1")
m = _Model(cb=dict)
json_str = m.model_dump_json()
restored = _Model.model_validate_json(json_str)
assert restored.cb is dict
def test_python_mode_preserves_callable(self) -> None:
m = _Model(cb=module_level_function)
data = m.model_dump(mode="python")
assert data["cb"] is module_level_function
def test_none_field(self) -> None:
m = _Model(cb=None)
assert m.cb is None
data = m.model_dump(mode="json")
assert data["cb"] is None
def test_validation_error_for_int(self) -> None:
with pytest.raises(ValidationError):
_Model(cb=42) # type: ignore[arg-type]
def test_deserialization_refused_without_env(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
monkeypatch.delenv("CREWAI_DESERIALIZE_CALLBACKS", raising=False)
with pytest.raises(ValidationError, match="Refusing to resolve"):
_Model.model_validate({"cb": "builtins.print"})
def test_json_schema_is_string(self) -> None:
schema = _Model.model_json_schema()
cb_schema = schema["properties"]["cb"]
# anyOf for Optional: one string, one null
types = {item.get("type") for item in cb_schema.get("anyOf", [cb_schema])}
assert "string" in types