diff --git a/lib/crewai/src/crewai/agents/agent_builder/base_agent.py b/lib/crewai/src/crewai/agents/agent_builder/base_agent.py index e2e146e10..674b15fa8 100644 --- a/lib/crewai/src/crewai/agents/agent_builder/base_agent.py +++ b/lib/crewai/src/crewai/agents/agent_builder/base_agent.py @@ -1,7 +1,6 @@ from __future__ import annotations from abc import ABC, abstractmethod -from collections.abc import Callable from copy import copy as shallow_copy from hashlib import md5 import re @@ -34,6 +33,7 @@ from crewai.memory.unified_memory import Memory from crewai.rag.embeddings.types import EmbedderConfig from crewai.security.security_config import SecurityConfig from crewai.tools.base_tool import BaseTool, Tool +from crewai.types.callback import SerializableCallable from crewai.utilities.config import process_config from crewai.utilities.i18n import I18N, get_i18n from crewai.utilities.logger import Logger @@ -191,7 +191,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta): default_factory=SecurityConfig, description="Security configuration for the agent, including fingerprinting.", ) - callbacks: list[Callable[[Any], Any]] = Field( + callbacks: list[SerializableCallable] = Field( default_factory=list, description="Callbacks to be used for the agent" ) adapted_agent: bool = Field( diff --git a/lib/crewai/src/crewai/crew.py b/lib/crewai/src/crewai/crew.py index 5e23f37d6..c5156888c 100644 --- a/lib/crewai/src/crewai/crew.py +++ b/lib/crewai/src/crewai/crew.py @@ -228,16 +228,14 @@ class Crew(FlowTrackable, BaseModel): default=None, description="Callback to be executed after each task for all agents execution.", ) - before_kickoff_callbacks: list[ - Callable[[dict[str, Any] | None], dict[str, Any] | None] - ] = Field( + before_kickoff_callbacks: list[SerializableCallable] = Field( default_factory=list, description=( "List of callbacks to be executed before crew kickoff. " "It may be used to adjust inputs before the crew is executed." ), ) - after_kickoff_callbacks: list[Callable[[CrewOutput], CrewOutput]] = Field( + after_kickoff_callbacks: list[SerializableCallable] = Field( default_factory=list, description=( "List of callbacks to be executed after crew kickoff. " diff --git a/lib/crewai/tests/test_callback.py b/lib/crewai/tests/test_callback.py new file mode 100644 index 000000000..417c74d98 --- /dev/null +++ b/lib/crewai/tests/test_callback.py @@ -0,0 +1,237 @@ +"""Tests for crewai.types.callback — SerializableCallable round-tripping.""" + +from __future__ import annotations + +import functools +import os +from typing import Any +import pytest +from pydantic import BaseModel, ValidationError + +from crewai.types.callback import ( + SerializableCallable, + _is_non_roundtrippable, + _resolve_dotted_path, + callable_to_string, + string_to_callable, +) + + +# ── Helpers ────────────────────────────────────────────────────────── + + +def module_level_function() -> str: + """Plain module-level function that should round-trip.""" + return "hello" + + +class _CallableInstance: + """Callable class instance — non-roundtrippable.""" + + def __call__(self) -> str: + return "instance" + + +class _HasMethod: + def method(self) -> str: + return "method" + + +class _Model(BaseModel): + cb: SerializableCallable | None = None + + +# ── _is_non_roundtrippable ─────────────────────────────────────────── + + +class TestIsNonRoundtrippable: + def test_builtin_is_roundtrippable(self) -> None: + assert _is_non_roundtrippable(print) is False + assert _is_non_roundtrippable(len) is False + + def test_class_is_roundtrippable(self) -> None: + assert _is_non_roundtrippable(dict) is False + assert _is_non_roundtrippable(_CallableInstance) is False + + def test_module_level_function_is_roundtrippable(self) -> None: + assert _is_non_roundtrippable(module_level_function) is False + + def test_lambda_is_non_roundtrippable(self) -> None: + assert _is_non_roundtrippable(lambda: None) is True + + def test_closure_is_non_roundtrippable(self) -> None: + x = 1 + + def closure() -> int: + return x + + assert _is_non_roundtrippable(closure) is True + + def test_bound_method_is_non_roundtrippable(self) -> None: + assert _is_non_roundtrippable(_HasMethod().method) is True + + def test_partial_is_non_roundtrippable(self) -> None: + assert _is_non_roundtrippable(functools.partial(print, "hi")) is True + + def test_callable_instance_is_non_roundtrippable(self) -> None: + assert _is_non_roundtrippable(_CallableInstance()) is True + + +# ── callable_to_string ─────────────────────────────────────────────── + + +class TestCallableToString: + def test_module_level_function(self) -> None: + result = callable_to_string(module_level_function) + assert result == f"{__name__}.module_level_function" + + def test_class(self) -> None: + result = callable_to_string(dict) + assert result == "builtins.dict" + + def test_builtin(self) -> None: + result = callable_to_string(print) + assert result == "builtins.print" + + def test_lambda_produces_locals_path(self) -> None: + fn = lambda: None # noqa: E731 + result = callable_to_string(fn) + assert "" in result + + def test_missing_qualname_raises(self) -> None: + obj = type("NoQual", (), {"__module__": "test"})() + obj.__qualname__ = None # type: ignore[assignment] + with pytest.raises(ValueError, match="missing __module__ or __qualname__"): + callable_to_string(obj) + + def test_missing_module_raises(self) -> None: + # Create an object where getattr(obj, "__module__", None) returns None + ns: dict[str, Any] = {"__qualname__": "x", "__module__": None} + obj = type("NoMod", (), ns)() + with pytest.raises(ValueError, match="missing __module__"): + callable_to_string(obj) + + +# ── string_to_callable ─────────────────────────────────────────────── + + +class TestStringToCallable: + def test_callable_passthrough(self) -> None: + assert string_to_callable(print) is print + + def test_roundtrippable_callable_no_warning(self, recwarn: pytest.WarningsChecker) -> None: + string_to_callable(module_level_function) + our_warnings = [ + w for w in recwarn if "cannot be serialized" in str(w.message) + ] + assert our_warnings == [] + + def test_non_roundtrippable_warns(self) -> None: + with pytest.warns(UserWarning, match="cannot be serialized"): + string_to_callable(functools.partial(print)) + + def test_non_callable_non_string_raises(self) -> None: + with pytest.raises(ValueError, match="Expected a callable"): + string_to_callable(42) + + def test_string_without_dot_raises(self) -> None: + with pytest.raises(ValueError, match="expected 'module.name' format"): + string_to_callable("nodots") + + def test_string_refused_without_env_var(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("CREWAI_DESERIALIZE_CALLBACKS", raising=False) + with pytest.raises(ValueError, match="Refusing to resolve"): + string_to_callable("builtins.print") + + def test_string_resolves_with_env_var(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("CREWAI_DESERIALIZE_CALLBACKS", "1") + result = string_to_callable("builtins.print") + assert result is print + + def test_string_resolves_multi_level_path(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("CREWAI_DESERIALIZE_CALLBACKS", "1") + result = string_to_callable("os.path.join") + assert result is os.path.join + + def test_unresolvable_path_raises(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("CREWAI_DESERIALIZE_CALLBACKS", "1") + with pytest.raises(ValueError, match="Cannot resolve"): + string_to_callable("nonexistent.module.func") + + +# ── _resolve_dotted_path ───────────────────────────────────────────── + + +class TestResolveDottedPath: + def test_builtin(self) -> None: + assert _resolve_dotted_path("builtins.print") is print + + def test_nested_module_attribute(self) -> None: + assert _resolve_dotted_path("os.path.join") is os.path.join + + def test_class_on_module(self) -> None: + from collections import OrderedDict + + assert _resolve_dotted_path("collections.OrderedDict") is OrderedDict + + def test_nonexistent_raises(self) -> None: + with pytest.raises(ValueError, match="Cannot resolve"): + _resolve_dotted_path("no.such.module.func") + + def test_non_callable_attribute_skipped(self) -> None: + # os.sep is a string, not callable — should not resolve + with pytest.raises(ValueError, match="Cannot resolve"): + _resolve_dotted_path("os.sep") + + +# ── Pydantic integration round-trip ────────────────────────────────── + + +class TestSerializableCallableRoundTrip: + def test_json_serialize_module_function(self) -> None: + m = _Model(cb=module_level_function) + data = m.model_dump(mode="json") + assert data["cb"] == f"{__name__}.module_level_function" + + def test_json_round_trip(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("CREWAI_DESERIALIZE_CALLBACKS", "1") + m = _Model(cb=print) + json_str = m.model_dump_json() + restored = _Model.model_validate_json(json_str) + assert restored.cb is print + + def test_json_round_trip_class(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("CREWAI_DESERIALIZE_CALLBACKS", "1") + m = _Model(cb=dict) + json_str = m.model_dump_json() + restored = _Model.model_validate_json(json_str) + assert restored.cb is dict + + def test_python_mode_preserves_callable(self) -> None: + m = _Model(cb=module_level_function) + data = m.model_dump(mode="python") + assert data["cb"] is module_level_function + + def test_none_field(self) -> None: + m = _Model(cb=None) + assert m.cb is None + data = m.model_dump(mode="json") + assert data["cb"] is None + + def test_validation_error_for_int(self) -> None: + with pytest.raises(ValidationError): + _Model(cb=42) # type: ignore[arg-type] + + def test_deserialization_refused_without_env( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.delenv("CREWAI_DESERIALIZE_CALLBACKS", raising=False) + with pytest.raises(ValidationError, match="Refusing to resolve"): + _Model.model_validate({"cb": "builtins.print"}) + + def test_json_schema_is_string(self) -> None: + schema = _Model.model_json_schema() + cb_schema = schema["properties"]["cb"] + # anyOf for Optional: one string, one null + types = {item.get("type") for item in cb_schema.get("anyOf", [cb_schema])} + assert "string" in types \ No newline at end of file