From 6495aff5286713fa5b488ffd714255b7cca80f76 Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Fri, 20 Mar 2026 15:18:50 -0400 Subject: [PATCH] refactor: replace Any-typed callback and model fields with serializable types --- lib/crewai/src/crewai/agent/core.py | 9 +- .../crewai/agents/agent_builder/base_agent.py | 12 +- lib/crewai/src/crewai/crew.py | 28 ++- lib/crewai/src/crewai/flow/flow.py | 31 ++- lib/crewai/src/crewai/flow/flow_config.py | 10 +- .../src/crewai/memory/unified_memory.py | 3 +- lib/crewai/src/crewai/task.py | 3 +- lib/crewai/src/crewai/types/callback.py | 152 +++++++++++ lib/crewai/tests/agents/test_agent.py | 3 + lib/crewai/tests/test_callback.py | 237 ++++++++++++++++++ lib/crewai/tests/test_project.py | 6 +- 11 files changed, 452 insertions(+), 42 deletions(-) create mode 100644 lib/crewai/src/crewai/types/callback.py create mode 100644 lib/crewai/tests/test_callback.py diff --git a/lib/crewai/src/crewai/agent/core.py b/lib/crewai/src/crewai/agent/core.py index 55eb807ef..3aa48137d 100644 --- a/lib/crewai/src/crewai/agent/core.py +++ b/lib/crewai/src/crewai/agent/core.py @@ -66,6 +66,7 @@ from crewai.mcp.tool_resolver import MCPToolResolver from crewai.rag.embeddings.types import EmbedderConfig from crewai.security.fingerprint import Fingerprint from crewai.tools.agent_tools.agent_tools import AgentTools +from crewai.types.callback import SerializableCallable from crewai.utilities.agent_utils import ( get_tool_names, is_inside_event_loop, @@ -143,7 +144,7 @@ class Agent(BaseAgent): default=None, description="Maximum execution time for an agent to execute a task", ) - step_callback: Any | None = Field( + step_callback: SerializableCallable | None = Field( default=None, description="Callback to be executed after each step of the agent execution.", ) @@ -151,10 +152,10 @@ class Agent(BaseAgent): default=True, description="Use system prompt for the agent.", ) - llm: str | InstanceOf[BaseLLM] | Any = Field( + llm: str | InstanceOf[BaseLLM] | 
None = Field( description="Language model that will run the agent.", default=None ) - function_calling_llm: str | InstanceOf[BaseLLM] | Any | None = Field( + function_calling_llm: str | InstanceOf[BaseLLM] | None = Field( description="Language model that will run the agent.", default=None ) system_template: str | None = Field( @@ -340,7 +341,7 @@ class Agent(BaseAgent): return ( hasattr(self.llm, "supports_function_calling") and callable(getattr(self.llm, "supports_function_calling", None)) - and self.llm.supports_function_calling() + and self.llm.supports_function_calling() # type: ignore[union-attr] and len(tools) > 0 ) diff --git a/lib/crewai/src/crewai/agents/agent_builder/base_agent.py b/lib/crewai/src/crewai/agents/agent_builder/base_agent.py index da32d9c1c..674b15fa8 100644 --- a/lib/crewai/src/crewai/agents/agent_builder/base_agent.py +++ b/lib/crewai/src/crewai/agents/agent_builder/base_agent.py @@ -1,7 +1,6 @@ from __future__ import annotations from abc import ABC, abstractmethod -from collections.abc import Callable from copy import copy as shallow_copy from hashlib import md5 import re @@ -12,6 +11,7 @@ from pydantic import ( UUID4, BaseModel, Field, + InstanceOf, PrivateAttr, field_validator, model_validator, @@ -26,10 +26,14 @@ from crewai.agents.tools_handler import ToolsHandler from crewai.knowledge.knowledge import Knowledge from crewai.knowledge.knowledge_config import KnowledgeConfig from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource +from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage from crewai.mcp.config import MCPServerConfig +from crewai.memory.memory_scope import MemoryScope, MemorySlice +from crewai.memory.unified_memory import Memory from crewai.rag.embeddings.types import EmbedderConfig from crewai.security.security_config import SecurityConfig from crewai.tools.base_tool import BaseTool, Tool +from crewai.types.callback import SerializableCallable from crewai.utilities.config 
import process_config from crewai.utilities.i18n import I18N, get_i18n from crewai.utilities.logger import Logger @@ -179,7 +183,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta): default=None, description="Knowledge sources for the agent.", ) - knowledge_storage: Any | None = Field( + knowledge_storage: InstanceOf[BaseKnowledgeStorage] | None = Field( default=None, description="Custom knowledge storage for the agent.", ) @@ -187,7 +191,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta): default_factory=SecurityConfig, description="Security configuration for the agent, including fingerprinting.", ) - callbacks: list[Callable[[Any], Any]] = Field( + callbacks: list[SerializableCallable] = Field( default_factory=list, description="Callbacks to be used for the agent" ) adapted_agent: bool = Field( @@ -205,7 +209,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta): default=None, description="List of MCP server references. Supports 'https://server.com/path' for external servers and bare slugs like 'notion' for connected MCP integrations. Use '#tool_name' suffix for specific tools.", ) - memory: Any = Field( + memory: bool | Memory | MemoryScope | MemorySlice | None = Field( default=None, description=( "Enable agent memory. 
Pass True for default Memory(), " diff --git a/lib/crewai/src/crewai/crew.py b/lib/crewai/src/crewai/crew.py index 61d1f52cf..c5156888c 100644 --- a/lib/crewai/src/crewai/crew.py +++ b/lib/crewai/src/crewai/crew.py @@ -35,6 +35,7 @@ from typing_extensions import Self if TYPE_CHECKING: from crewai_files import FileInput + from opentelemetry.trace import Span try: from crewai_files import get_supported_content_types @@ -83,6 +84,8 @@ from crewai.knowledge.knowledge import Knowledge from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.llm import LLM from crewai.llms.base_llm import BaseLLM +from crewai.memory.memory_scope import MemoryScope, MemorySlice +from crewai.memory.unified_memory import Memory from crewai.process import Process from crewai.rag.embeddings.types import EmbedderConfig from crewai.rag.types import SearchResult @@ -94,6 +97,7 @@ from crewai.tasks.task_output import TaskOutput from crewai.tools.agent_tools.agent_tools import AgentTools from crewai.tools.agent_tools.read_file_tool import ReadFileTool from crewai.tools.base_tool import BaseTool +from crewai.types.callback import SerializableCallable from crewai.types.streaming import CrewStreamingOutput from crewai.types.usage_metrics import UsageMetrics from crewai.utilities.constants import NOT_SPECIFIED, TRAINING_DATA_FILE @@ -166,12 +170,12 @@ class Crew(FlowTrackable, BaseModel): """ __hash__ = object.__hash__ - _execution_span: Any = PrivateAttr() + _execution_span: Span | None = PrivateAttr() _rpm_controller: RPMController = PrivateAttr() _logger: Logger = PrivateAttr() _file_handler: FileHandler = PrivateAttr() _cache_handler: InstanceOf[CacheHandler] = PrivateAttr(default_factory=CacheHandler) - _memory: Any = PrivateAttr(default=None) # Unified Memory | MemoryScope + _memory: Memory | MemoryScope | MemorySlice | None = PrivateAttr(default=None) _train: bool | None = PrivateAttr(default=False) _train_iteration: int | None = PrivateAttr() _inputs: dict[str, 
Any] | None = PrivateAttr(default=None) @@ -189,7 +193,7 @@ class Crew(FlowTrackable, BaseModel): agents: list[BaseAgent] = Field(default_factory=list) process: Process = Field(default=Process.sequential) verbose: bool = Field(default=False) - memory: bool | Any = Field( + memory: bool | Memory | MemoryScope | MemorySlice | None = Field( default=False, description=( "Enable crew memory. Pass True for default Memory(), " @@ -204,36 +208,34 @@ class Crew(FlowTrackable, BaseModel): default=None, description="Metrics for the LLM usage during all tasks execution.", ) - manager_llm: str | InstanceOf[BaseLLM] | Any | None = Field( + manager_llm: str | InstanceOf[BaseLLM] | None = Field( description="Language model that will run the agent.", default=None ) manager_agent: BaseAgent | None = Field( description="Custom agent that will be used as manager.", default=None ) - function_calling_llm: str | InstanceOf[LLM] | Any | None = Field( + function_calling_llm: str | InstanceOf[LLM] | None = Field( description="Language model that will run the agent.", default=None ) config: Json[dict[str, Any]] | dict[str, Any] | None = Field(default=None) id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True) share_crew: bool | None = Field(default=False) - step_callback: Any | None = Field( + step_callback: SerializableCallable | None = Field( default=None, description="Callback to be executed after each step for all agents execution.", ) - task_callback: Any | None = Field( + task_callback: SerializableCallable | None = Field( default=None, description="Callback to be executed after each task for all agents execution.", ) - before_kickoff_callbacks: list[ - Callable[[dict[str, Any] | None], dict[str, Any] | None] - ] = Field( + before_kickoff_callbacks: list[SerializableCallable] = Field( default_factory=list, description=( "List of callbacks to be executed before crew kickoff. " "It may be used to adjust inputs before the crew is executed." 
), ) - after_kickoff_callbacks: list[Callable[[CrewOutput], CrewOutput]] = Field( + after_kickoff_callbacks: list[SerializableCallable] = Field( default_factory=list, description=( "List of callbacks to be executed after crew kickoff. " @@ -349,7 +351,7 @@ class Crew(FlowTrackable, BaseModel): self._file_handler = FileHandler(self.output_log_file) self._rpm_controller = RPMController(max_rpm=self.max_rpm, logger=self._logger) if self.function_calling_llm and not isinstance(self.function_calling_llm, LLM): - self.function_calling_llm = create_llm(self.function_calling_llm) + self.function_calling_llm = create_llm(self.function_calling_llm) # type: ignore[assignment] return self @@ -363,7 +365,7 @@ class Crew(FlowTrackable, BaseModel): if self.embedder is not None: from crewai.rag.embeddings.factory import build_embedder - embedder = build_embedder(self.embedder) + embedder = build_embedder(self.embedder) # type: ignore[arg-type] self._memory = Memory(embedder=embedder) elif self.memory: # User passed a Memory / MemoryScope / MemorySlice instance diff --git a/lib/crewai/src/crewai/flow/flow.py b/lib/crewai/src/crewai/flow/flow.py index 71bd31915..99c5edab4 100644 --- a/lib/crewai/src/crewai/flow/flow.py +++ b/lib/crewai/src/crewai/flow/flow.py @@ -81,6 +81,7 @@ from crewai.flow.flow_wrappers import ( SimpleFlowCondition, StartMethod, ) +from crewai.flow.input_provider import InputProvider from crewai.flow.persistence.base import FlowPersistence from crewai.flow.types import ( FlowExecutionData, @@ -99,6 +100,8 @@ from crewai.flow.utils import ( is_flow_method_name, is_simple_flow_condition, ) +from crewai.memory.memory_scope import MemoryScope, MemorySlice +from crewai.memory.unified_memory import Memory if TYPE_CHECKING: @@ -501,7 +504,7 @@ class LockedListProxy(list, Generic[T]): # type: ignore[type-arg] def index( self, value: T, start: SupportsIndex = 0, stop: SupportsIndex | None = None - ) -> int: # type: ignore[override] + ) -> int: if stop is None: return 
self._list.index(value, start) return self._list.index(value, start, stop) @@ -520,13 +523,13 @@ class LockedListProxy(list, Generic[T]): # type: ignore[type-arg] def copy(self) -> list[T]: return self._list.copy() - def __add__(self, other: list[T]) -> list[T]: + def __add__(self, other: list[T]) -> list[T]: # type: ignore[override] return self._list + other def __radd__(self, other: list[T]) -> list[T]: return other + self._list - def __iadd__(self, other: Iterable[T]) -> LockedListProxy[T]: + def __iadd__(self, other: Iterable[T]) -> LockedListProxy[T]: # type: ignore[override] with self._lock: self._list += list(other) return self @@ -630,13 +633,13 @@ class LockedDictProxy(dict, Generic[T]): # type: ignore[type-arg] def copy(self) -> dict[str, T]: return self._dict.copy() - def __or__(self, other: dict[str, T]) -> dict[str, T]: + def __or__(self, other: dict[str, T]) -> dict[str, T]: # type: ignore[override] return self._dict | other - def __ror__(self, other: dict[str, T]) -> dict[str, T]: + def __ror__(self, other: dict[str, T]) -> dict[str, T]: # type: ignore[override] return other | self._dict - def __ior__(self, other: dict[str, T]) -> LockedDictProxy[T]: + def __ior__(self, other: dict[str, T]) -> LockedDictProxy[T]: # type: ignore[override] with self._lock: self._dict |= other return self @@ -822,10 +825,8 @@ class Flow(Generic[T], metaclass=FlowMeta): name: str | None = None tracing: bool | None = None stream: bool = False - memory: Any = ( - None # Memory | MemoryScope | MemorySlice | None; auto-created if not set - ) - input_provider: Any = None # InputProvider | None; per-flow override for self.ask() + memory: Memory | MemoryScope | MemorySlice | None = None + input_provider: InputProvider | None = None def __class_getitem__(cls: type[Flow[T]], item: type[T]) -> type[Flow[T]]: class _FlowGeneric(cls): # type: ignore @@ -904,8 +905,6 @@ class Flow(Generic[T], metaclass=FlowMeta): # Internal flows (RecallFlow, EncodingFlow) set _skip_auto_memory # to 
avoid creating a wasteful standalone Memory instance. if self.memory is None and not getattr(self, "_skip_auto_memory", False): - from crewai.memory.unified_memory import Memory - self.memory = Memory() # Register all flow-related methods @@ -951,10 +950,16 @@ class Flow(Generic[T], metaclass=FlowMeta): Raises: ValueError: If no memory is configured for this flow. + TypeError: If batch remember is attempted on a MemoryScope or MemorySlice. """ if self.memory is None: raise ValueError("No memory configured for this flow") if isinstance(content, list): + if not isinstance(self.memory, Memory): + raise TypeError( + "Batch remember requires a Memory instance, " + f"got {type(self.memory).__name__}" + ) return self.memory.remember_many(content, **kwargs) return self.memory.remember(content, **kwargs) @@ -2725,7 +2730,7 @@ class Flow(Generic[T], metaclass=FlowMeta): # ── User Input (self.ask) ──────────────────────────────────────── - def _resolve_input_provider(self) -> Any: + def _resolve_input_provider(self) -> InputProvider: """Resolve the input provider using the priority chain. Resolution order: diff --git a/lib/crewai/src/crewai/flow/flow_config.py b/lib/crewai/src/crewai/flow/flow_config.py index a4a6bfbe4..7cb838b42 100644 --- a/lib/crewai/src/crewai/flow/flow_config.py +++ b/lib/crewai/src/crewai/flow/flow_config.py @@ -6,7 +6,7 @@ customize Flow behavior at runtime. 
from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -32,17 +32,17 @@ class FlowConfig: self._input_provider: InputProvider | None = None @property - def hitl_provider(self) -> Any: + def hitl_provider(self) -> HumanFeedbackProvider | None: """Get the configured HITL provider.""" return self._hitl_provider @hitl_provider.setter - def hitl_provider(self, provider: Any) -> None: + def hitl_provider(self, provider: HumanFeedbackProvider | None) -> None: """Set the HITL provider.""" self._hitl_provider = provider @property - def input_provider(self) -> Any: + def input_provider(self) -> InputProvider | None: """Get the configured input provider for ``Flow.ask()``. Returns: @@ -52,7 +52,7 @@ class FlowConfig: return self._input_provider @input_provider.setter - def input_provider(self, provider: Any) -> None: + def input_provider(self, provider: InputProvider | None) -> None: """Set the input provider for ``Flow.ask()``. 
Args: diff --git a/lib/crewai/src/crewai/memory/unified_memory.py b/lib/crewai/src/crewai/memory/unified_memory.py index 2d367dcf8..74761c0bb 100644 --- a/lib/crewai/src/crewai/memory/unified_memory.py +++ b/lib/crewai/src/crewai/memory/unified_memory.py @@ -22,7 +22,6 @@ from crewai.events.types.memory_events import ( ) from crewai.llms.base_llm import BaseLLM from crewai.memory.analyze import extract_memories_from_content -from crewai.memory.recall_flow import RecallFlow from crewai.memory.storage.backend import StorageBackend from crewai.memory.types import ( MemoryConfig, @@ -620,6 +619,8 @@ class Memory(BaseModel): ) results.sort(key=lambda m: m.score, reverse=True) else: + from crewai.memory.recall_flow import RecallFlow + flow = RecallFlow( storage=self._storage, llm=self._llm, diff --git a/lib/crewai/src/crewai/task.py b/lib/crewai/src/crewai/task.py index 6977eb638..17fbac3d4 100644 --- a/lib/crewai/src/crewai/task.py +++ b/lib/crewai/src/crewai/task.py @@ -67,6 +67,7 @@ except ImportError: return [] +from crewai.types.callback import SerializableCallable from crewai.utilities.guardrail import ( process_guardrail, ) @@ -124,7 +125,7 @@ class Task(BaseModel): description="Configuration for the agent", default=None, ) - callback: Any | None = Field( + callback: SerializableCallable | None = Field( description="Callback to be executed after the task is completed.", default=None ) agent: BaseAgent | None = Field( diff --git a/lib/crewai/src/crewai/types/callback.py b/lib/crewai/src/crewai/types/callback.py new file mode 100644 index 000000000..2a8be235e --- /dev/null +++ b/lib/crewai/src/crewai/types/callback.py @@ -0,0 +1,152 @@ +"""Serializable callback type for Pydantic models. + +Provides a ``SerializableCallable`` type alias that enables full JSON +round-tripping of callback fields, e.g. ``"builtins.print"`` ↔ ``print``. 
+Lambdas and closures serialize to a dotted path but cannot be deserialized +back — use module-level named functions for checkpointable callbacks. +""" + +from __future__ import annotations + +from collections.abc import Callable +import importlib +import inspect +import os +from typing import Annotated, Any +import warnings + +from pydantic import BeforeValidator, WithJsonSchema +from pydantic.functional_serializers import PlainSerializer + + +def _is_non_roundtrippable(fn: object) -> bool: + """Return ``True`` if *fn* cannot survive a serialize/deserialize round-trip. + + Built-in functions, plain module-level functions, and classes produce + dotted paths that :func:`_resolve_dotted_path` can reliably resolve. + Bound methods, ``functools.partial`` objects, callable class instances, + lambdas, and closures all fail or silently change semantics during + round-tripping. + + Args: + fn: The object to check. + + Returns: + ``True`` if *fn* would not round-trip through JSON serialization. + """ + if inspect.isbuiltin(fn) or inspect.isclass(fn): + return False + if inspect.isfunction(fn): + qualname = getattr(fn, "__qualname__", "") + return qualname.endswith("<lambda>") or "<locals>" in qualname + return True + + +def string_to_callable(value: Any) -> Callable[..., Any]: + """Convert a dotted path string to the callable it references. + + If *value* is already callable it is returned as-is, with a warning if + it cannot survive JSON round-tripping. Otherwise, it is treated as + ``"module.qualname"`` and resolved via :func:`_resolve_dotted_path`. + + Args: + value: A callable or a dotted-path string e.g. ``"builtins.print"``. + + Returns: + The resolved callable. + + Raises: + ValueError: If *value* is not callable or a resolvable dotted-path string. + """ + if callable(value): + if _is_non_roundtrippable(value): + warnings.warn( + f"{type(value).__name__} callbacks cannot be serialized " + "and will prevent checkpointing. 
" + "Use a module-level named function instead.", + UserWarning, + stacklevel=2, + ) + return value # type: ignore[no-any-return] + if not isinstance(value, str): + raise ValueError( + f"Expected a callable or dotted-path string, got {type(value).__name__}" + ) + if "." not in value: + raise ValueError( + f"Invalid callback path {value!r}: expected 'module.name' format" + ) + if not os.environ.get("CREWAI_DESERIALIZE_CALLBACKS"): + raise ValueError( + f"Refusing to resolve callback path {value!r}: " + "set CREWAI_DESERIALIZE_CALLBACKS=1 to allow. " + "Only enable this for trusted checkpoint data." + ) + return _resolve_dotted_path(value) + + +def _resolve_dotted_path(path: str) -> Callable[..., Any]: + """Import a module and walk attribute lookups to resolve a dotted path. + + Handles multi-level qualified names like ``"module.ClassName.method"`` + by trying progressively shorter module paths and resolving the remainder + as chained attribute lookups. + + Args: + path: A dotted string e.g. ``"builtins.print"`` or + ``"mymodule.MyClass.my_method"``. + + Returns: + The resolved callable. + + Raises: + ValueError: If no valid module can be imported from the path. + """ + parts = path.split(".") + # Try importing progressively shorter prefixes as the module. + for i in range(len(parts), 0, -1): + module_path = ".".join(parts[:i]) + try: + obj: Any = importlib.import_module(module_path) + except (ImportError, TypeError, ValueError): + continue + # Walk the remaining attribute chain. + try: + for attr in parts[i:]: + obj = getattr(obj, attr) + except AttributeError: + continue + if callable(obj): + return obj # type: ignore[no-any-return] + raise ValueError(f"Cannot resolve callback {path!r}") + + +def callable_to_string(fn: Callable[..., Any]) -> str: + """Serialize a callable to its dotted-path string representation. + + Uses ``fn.__module__`` and ``fn.__qualname__`` to produce a string such + as ``"builtins.print"``. 
Lambdas and closures produce paths that contain + ``<locals>`` and cannot be round-tripped via :func:`string_to_callable`. + + Args: + fn: The callable to serialize. + + Returns: + A dotted string of the form ``"module.qualname"``. + """ + module = getattr(fn, "__module__", None) + qualname = getattr(fn, "__qualname__", None) + if module is None or qualname is None: + raise ValueError( + f"Cannot serialize {fn!r}: missing __module__ or __qualname__. " + "Use a module-level named function for checkpointable callbacks." + ) + return f"{module}.{qualname}" + + +SerializableCallable = Annotated[ + Callable[..., Any], + BeforeValidator(string_to_callable), + PlainSerializer(callable_to_string, return_type=str, when_used="json"), + WithJsonSchema({"type": "string"}), +] diff --git a/lib/crewai/tests/agents/test_agent.py b/lib/crewai/tests/agents/test_agent.py index a3aab28d6..d865ec541 100644 --- a/lib/crewai/tests/agents/test_agent.py +++ b/lib/crewai/tests/agents/test_agent.py @@ -1690,7 +1690,10 @@ def test_agent_with_knowledge_sources_works_with_copy(): with patch( "crewai.knowledge.storage.knowledge_storage.KnowledgeStorage" ) as mock_knowledge_storage: + from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage + mock_knowledge_storage_instance = mock_knowledge_storage.return_value + mock_knowledge_storage_instance.__class__ = BaseKnowledgeStorage agent.knowledge_storage = mock_knowledge_storage_instance agent_copy = agent.copy() diff --git a/lib/crewai/tests/test_callback.py b/lib/crewai/tests/test_callback.py new file mode 100644 index 000000000..417c74d98 --- /dev/null +++ b/lib/crewai/tests/test_callback.py @@ -0,0 +1,237 @@ +"""Tests for crewai.types.callback — SerializableCallable round-tripping.""" + +from __future__ import annotations + +import functools +import os +from typing import Any +import pytest +from pydantic import BaseModel, ValidationError + +from crewai.types.callback import ( + SerializableCallable, + _is_non_roundtrippable, + 
_resolve_dotted_path, + callable_to_string, + string_to_callable, +) + + +# ── Helpers ────────────────────────────────────────────────────────── + + +def module_level_function() -> str: + """Plain module-level function that should round-trip.""" + return "hello" + + +class _CallableInstance: + """Callable class instance — non-roundtrippable.""" + + def __call__(self) -> str: + return "instance" + + +class _HasMethod: + def method(self) -> str: + return "method" + + +class _Model(BaseModel): + cb: SerializableCallable | None = None + + +# ── _is_non_roundtrippable ─────────────────────────────────────────── + + +class TestIsNonRoundtrippable: + def test_builtin_is_roundtrippable(self) -> None: + assert _is_non_roundtrippable(print) is False + assert _is_non_roundtrippable(len) is False + + def test_class_is_roundtrippable(self) -> None: + assert _is_non_roundtrippable(dict) is False + assert _is_non_roundtrippable(_CallableInstance) is False + + def test_module_level_function_is_roundtrippable(self) -> None: + assert _is_non_roundtrippable(module_level_function) is False + + def test_lambda_is_non_roundtrippable(self) -> None: + assert _is_non_roundtrippable(lambda: None) is True + + def test_closure_is_non_roundtrippable(self) -> None: + x = 1 + + def closure() -> int: + return x + + assert _is_non_roundtrippable(closure) is True + + def test_bound_method_is_non_roundtrippable(self) -> None: + assert _is_non_roundtrippable(_HasMethod().method) is True + + def test_partial_is_non_roundtrippable(self) -> None: + assert _is_non_roundtrippable(functools.partial(print, "hi")) is True + + def test_callable_instance_is_non_roundtrippable(self) -> None: + assert _is_non_roundtrippable(_CallableInstance()) is True + + +# ── callable_to_string ─────────────────────────────────────────────── + + +class TestCallableToString: + def test_module_level_function(self) -> None: + result = callable_to_string(module_level_function) + assert result == 
f"{__name__}.module_level_function" + + def test_class(self) -> None: + result = callable_to_string(dict) + assert result == "builtins.dict" + + def test_builtin(self) -> None: + result = callable_to_string(print) + assert result == "builtins.print" + + def test_lambda_produces_locals_path(self) -> None: + fn = lambda: None # noqa: E731 + result = callable_to_string(fn) + assert "<locals>" in result + + def test_missing_qualname_raises(self) -> None: + obj = type("NoQual", (), {"__module__": "test"})() + obj.__qualname__ = None # type: ignore[assignment] + with pytest.raises(ValueError, match="missing __module__ or __qualname__"): + callable_to_string(obj) + + def test_missing_module_raises(self) -> None: + # Create an object where getattr(obj, "__module__", None) returns None + ns: dict[str, Any] = {"__qualname__": "x", "__module__": None} + obj = type("NoMod", (), ns)() + with pytest.raises(ValueError, match="missing __module__"): + callable_to_string(obj) + + +# ── string_to_callable ─────────────────────────────────────────────── + + +class TestStringToCallable: + def test_callable_passthrough(self) -> None: + assert string_to_callable(print) is print + + def test_roundtrippable_callable_no_warning(self, recwarn: pytest.WarningsChecker) -> None: + string_to_callable(module_level_function) + our_warnings = [ + w for w in recwarn if "cannot be serialized" in str(w.message) + ] + assert our_warnings == [] + + def test_non_roundtrippable_warns(self) -> None: + with pytest.warns(UserWarning, match="cannot be serialized"): + string_to_callable(functools.partial(print)) + + def test_non_callable_non_string_raises(self) -> None: + with pytest.raises(ValueError, match="Expected a callable"): + string_to_callable(42) + + def test_string_without_dot_raises(self) -> None: + with pytest.raises(ValueError, match="expected 'module.name' format"): + string_to_callable("nodots") + + def test_string_refused_without_env_var(self, monkeypatch: pytest.MonkeyPatch) -> None: + 
monkeypatch.delenv("CREWAI_DESERIALIZE_CALLBACKS", raising=False) + with pytest.raises(ValueError, match="Refusing to resolve"): + string_to_callable("builtins.print") + + def test_string_resolves_with_env_var(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("CREWAI_DESERIALIZE_CALLBACKS", "1") + result = string_to_callable("builtins.print") + assert result is print + + def test_string_resolves_multi_level_path(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("CREWAI_DESERIALIZE_CALLBACKS", "1") + result = string_to_callable("os.path.join") + assert result is os.path.join + + def test_unresolvable_path_raises(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("CREWAI_DESERIALIZE_CALLBACKS", "1") + with pytest.raises(ValueError, match="Cannot resolve"): + string_to_callable("nonexistent.module.func") + + +# ── _resolve_dotted_path ───────────────────────────────────────────── + + +class TestResolveDottedPath: + def test_builtin(self) -> None: + assert _resolve_dotted_path("builtins.print") is print + + def test_nested_module_attribute(self) -> None: + assert _resolve_dotted_path("os.path.join") is os.path.join + + def test_class_on_module(self) -> None: + from collections import OrderedDict + + assert _resolve_dotted_path("collections.OrderedDict") is OrderedDict + + def test_nonexistent_raises(self) -> None: + with pytest.raises(ValueError, match="Cannot resolve"): + _resolve_dotted_path("no.such.module.func") + + def test_non_callable_attribute_skipped(self) -> None: + # os.sep is a string, not callable — should not resolve + with pytest.raises(ValueError, match="Cannot resolve"): + _resolve_dotted_path("os.sep") + + +# ── Pydantic integration round-trip ────────────────────────────────── + + +class TestSerializableCallableRoundTrip: + def test_json_serialize_module_function(self) -> None: + m = _Model(cb=module_level_function) + data = m.model_dump(mode="json") + assert data["cb"] == 
f"{__name__}.module_level_function" + + def test_json_round_trip(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("CREWAI_DESERIALIZE_CALLBACKS", "1") + m = _Model(cb=print) + json_str = m.model_dump_json() + restored = _Model.model_validate_json(json_str) + assert restored.cb is print + + def test_json_round_trip_class(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("CREWAI_DESERIALIZE_CALLBACKS", "1") + m = _Model(cb=dict) + json_str = m.model_dump_json() + restored = _Model.model_validate_json(json_str) + assert restored.cb is dict + + def test_python_mode_preserves_callable(self) -> None: + m = _Model(cb=module_level_function) + data = m.model_dump(mode="python") + assert data["cb"] is module_level_function + + def test_none_field(self) -> None: + m = _Model(cb=None) + assert m.cb is None + data = m.model_dump(mode="json") + assert data["cb"] is None + + def test_validation_error_for_int(self) -> None: + with pytest.raises(ValidationError): + _Model(cb=42) # type: ignore[arg-type] + + def test_deserialization_refused_without_env( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.delenv("CREWAI_DESERIALIZE_CALLBACKS", raising=False) + with pytest.raises(ValidationError, match="Refusing to resolve"): + _Model.model_validate({"cb": "builtins.print"}) + + def test_json_schema_is_string(self) -> None: + schema = _Model.model_json_schema() + cb_schema = schema["properties"]["cb"] + # anyOf for Optional: one string, one null + types = {item.get("type") for item in cb_schema.get("anyOf", [cb_schema])} + assert "string" in types \ No newline at end of file diff --git a/lib/crewai/tests/test_project.py b/lib/crewai/tests/test_project.py index 4962ff08c..6334cb777 100644 --- a/lib/crewai/tests/test_project.py +++ b/lib/crewai/tests/test_project.py @@ -6,6 +6,7 @@ from crewai.agent import Agent from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.crew import Crew from crewai.llm import LLM +from 
crewai.llms.base_llm import BaseLLM from crewai.project import ( CrewBase, after_kickoff, @@ -371,9 +372,12 @@ def test_internal_crew_with_mcp(): mock_adapter = Mock() mock_adapter.tools = ToolCollection([simple_tool, another_simple_tool]) + mock_llm = Mock() + mock_llm.__class__ = BaseLLM + with ( patch("crewai_tools.MCPServerAdapter", return_value=mock_adapter) as adapter_mock, - patch("crewai.llm.LLM.__new__", return_value=Mock()), + patch("crewai.llm.LLM.__new__", return_value=mock_llm), ): crew = InternalCrewWithMCP() assert crew.reporting_analyst().tools == [simple_tool, another_simple_tool]