From 395f6f1ae211fe2920410e6c87a7bb5dcbca289f Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Thu, 19 Mar 2026 19:17:24 -0400 Subject: [PATCH] refactor: replace Any-typed callback and model fields with serializable types --- lib/crewai/src/crewai/agent/core.py | 9 +- .../crewai/agents/agent_builder/base_agent.py | 8 +- lib/crewai/src/crewai/crew.py | 22 +++-- .../src/crewai/memory/unified_memory.py | 3 +- lib/crewai/src/crewai/task.py | 3 +- lib/crewai/src/crewai/types/callback.py | 90 +++++++++++++++++++ lib/crewai/tests/agents/test_agent.py | 3 + 7 files changed, 121 insertions(+), 17 deletions(-) create mode 100644 lib/crewai/src/crewai/types/callback.py diff --git a/lib/crewai/src/crewai/agent/core.py b/lib/crewai/src/crewai/agent/core.py index 55eb807ef..3aa48137d 100644 --- a/lib/crewai/src/crewai/agent/core.py +++ b/lib/crewai/src/crewai/agent/core.py @@ -66,6 +66,7 @@ from crewai.mcp.tool_resolver import MCPToolResolver from crewai.rag.embeddings.types import EmbedderConfig from crewai.security.fingerprint import Fingerprint from crewai.tools.agent_tools.agent_tools import AgentTools +from crewai.types.callback import SerializableCallable from crewai.utilities.agent_utils import ( get_tool_names, is_inside_event_loop, @@ -143,7 +144,7 @@ class Agent(BaseAgent): default=None, description="Maximum execution time for an agent to execute a task", ) - step_callback: Any | None = Field( + step_callback: SerializableCallable | None = Field( default=None, description="Callback to be executed after each step of the agent execution.", ) @@ -151,10 +152,10 @@ class Agent(BaseAgent): default=True, description="Use system prompt for the agent.", ) - llm: str | InstanceOf[BaseLLM] | Any = Field( + llm: str | InstanceOf[BaseLLM] | None = Field( description="Language model that will run the agent.", default=None ) - function_calling_llm: str | InstanceOf[BaseLLM] | Any | None = Field( + function_calling_llm: str | InstanceOf[BaseLLM] | None = Field( description="Language model that will run the agent.", default=None ) system_template: str | None = Field( @@ -340,7 +341,7 @@ class Agent(BaseAgent): return ( hasattr(self.llm, "supports_function_calling") and callable(getattr(self.llm, "supports_function_calling", None)) - and self.llm.supports_function_calling() + and self.llm.supports_function_calling() # type: ignore[union-attr] and len(tools) > 0 ) diff --git a/lib/crewai/src/crewai/agents/agent_builder/base_agent.py b/lib/crewai/src/crewai/agents/agent_builder/base_agent.py index da32d9c1c..e2e146e10 100644 --- a/lib/crewai/src/crewai/agents/agent_builder/base_agent.py +++ b/lib/crewai/src/crewai/agents/agent_builder/base_agent.py @@ -12,6 +12,7 @@ from pydantic import ( UUID4, BaseModel, Field, + InstanceOf, PrivateAttr, field_validator, model_validator, @@ -26,7 +27,10 @@ from crewai.agents.tools_handler import ToolsHandler from crewai.knowledge.knowledge import Knowledge from crewai.knowledge.knowledge_config import KnowledgeConfig from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource +from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage from crewai.mcp.config import MCPServerConfig +from crewai.memory.memory_scope import MemoryScope, MemorySlice +from crewai.memory.unified_memory import Memory from crewai.rag.embeddings.types import EmbedderConfig from crewai.security.security_config import SecurityConfig from crewai.tools.base_tool import BaseTool, Tool @@ -179,7 +183,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta): default=None, description="Knowledge sources for the agent.", ) - knowledge_storage: Any | None = Field( + knowledge_storage: InstanceOf[BaseKnowledgeStorage] | None = Field( default=None, description="Custom knowledge storage for the agent.", ) @@ -205,7 +209,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta): default=None, description="List of MCP server references. Supports 'https://server.com/path' for external servers and bare slugs like 'notion' for connected MCP integrations. Use '#tool_name' suffix for specific tools.", ) - memory: Any = Field( + memory: bool | Memory | MemoryScope | MemorySlice | None = Field( default=None, description=( "Enable agent memory. Pass True for default Memory(), " diff --git a/lib/crewai/src/crewai/crew.py b/lib/crewai/src/crewai/crew.py index 61d1f52cf..5e23f37d6 100644 --- a/lib/crewai/src/crewai/crew.py +++ b/lib/crewai/src/crewai/crew.py @@ -35,6 +35,7 @@ from typing_extensions import Self if TYPE_CHECKING: from crewai_files import FileInput + from opentelemetry.trace import Span try: from crewai_files import get_supported_content_types @@ -83,6 +84,8 @@ from crewai.knowledge.knowledge import Knowledge from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.llm import LLM from crewai.llms.base_llm import BaseLLM +from crewai.memory.memory_scope import MemoryScope, MemorySlice +from crewai.memory.unified_memory import Memory from crewai.process import Process from crewai.rag.embeddings.types import EmbedderConfig from crewai.rag.types import SearchResult @@ -94,6 +97,7 @@ from crewai.tasks.task_output import TaskOutput from crewai.tools.agent_tools.agent_tools import AgentTools from crewai.tools.agent_tools.read_file_tool import ReadFileTool from crewai.tools.base_tool import BaseTool +from crewai.types.callback import SerializableCallable from crewai.types.streaming import CrewStreamingOutput from crewai.types.usage_metrics import UsageMetrics from crewai.utilities.constants import NOT_SPECIFIED, TRAINING_DATA_FILE @@ -166,12 +170,12 @@ class Crew(FlowTrackable, BaseModel): """ __hash__ = object.__hash__ - _execution_span: Any = PrivateAttr() + _execution_span: Span | None = PrivateAttr() _rpm_controller: RPMController = PrivateAttr() _logger: Logger = PrivateAttr() _file_handler: FileHandler = PrivateAttr() _cache_handler: InstanceOf[CacheHandler] = PrivateAttr(default_factory=CacheHandler) - _memory: Any = PrivateAttr(default=None) # Unified Memory | MemoryScope + _memory: Memory | MemoryScope | MemorySlice | None = PrivateAttr(default=None) _train: bool | None = PrivateAttr(default=False) _train_iteration: int | None = PrivateAttr() _inputs: dict[str, Any] | None = PrivateAttr(default=None) @@ -189,7 +193,7 @@ class Crew(FlowTrackable, BaseModel): agents: list[BaseAgent] = Field(default_factory=list) process: Process = Field(default=Process.sequential) verbose: bool = Field(default=False) - memory: bool | Any = Field( + memory: bool | Memory | MemoryScope | MemorySlice | None = Field( default=False, description=( "Enable crew memory. Pass True for default Memory(), " @@ -204,23 +208,23 @@ class Crew(FlowTrackable, BaseModel): default=None, description="Metrics for the LLM usage during all tasks execution.", ) - manager_llm: str | InstanceOf[BaseLLM] | Any | None = Field( + manager_llm: str | InstanceOf[BaseLLM] | None = Field( description="Language model that will run the agent.", default=None ) manager_agent: BaseAgent | None = Field( description="Custom agent that will be used as manager.", default=None ) - function_calling_llm: str | InstanceOf[LLM] | Any | None = Field( + function_calling_llm: str | InstanceOf[LLM] | None = Field( description="Language model that will run the agent.", default=None ) config: Json[dict[str, Any]] | dict[str, Any] | None = Field(default=None) id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True) share_crew: bool | None = Field(default=False) - step_callback: Any | None = Field( + step_callback: SerializableCallable | None = Field( default=None, description="Callback to be executed after each step for all agents execution.", ) - task_callback: Any | None = Field( + task_callback: SerializableCallable | None = Field( default=None, description="Callback to be executed after each task for all agents execution.", ) @@ -349,7 +353,7 @@ class Crew(FlowTrackable, BaseModel): self._file_handler = FileHandler(self.output_log_file) self._rpm_controller = RPMController(max_rpm=self.max_rpm, logger=self._logger) if self.function_calling_llm and not isinstance(self.function_calling_llm, LLM): - self.function_calling_llm = create_llm(self.function_calling_llm) + self.function_calling_llm = create_llm(self.function_calling_llm) # type: ignore[assignment] return self @@ -363,7 +367,7 @@ class Crew(FlowTrackable, BaseModel): if self.embedder is not None: from crewai.rag.embeddings.factory import build_embedder - embedder = build_embedder(self.embedder) + embedder = build_embedder(self.embedder) # type: ignore[arg-type] self._memory = Memory(embedder=embedder) elif self.memory: # User passed a Memory / MemoryScope / MemorySlice instance diff --git a/lib/crewai/src/crewai/memory/unified_memory.py b/lib/crewai/src/crewai/memory/unified_memory.py index 2d367dcf8..74761c0bb 100644 --- a/lib/crewai/src/crewai/memory/unified_memory.py +++ b/lib/crewai/src/crewai/memory/unified_memory.py @@ -22,7 +22,6 @@ from crewai.events.types.memory_events import ( ) from crewai.llms.base_llm import BaseLLM from crewai.memory.analyze import extract_memories_from_content -from crewai.memory.recall_flow import RecallFlow from crewai.memory.storage.backend import StorageBackend from crewai.memory.types import ( MemoryConfig, @@ -620,6 +619,8 @@ class Memory(BaseModel): ) results.sort(key=lambda m: m.score, reverse=True) else: + from crewai.memory.recall_flow import RecallFlow + flow = RecallFlow( storage=self._storage, llm=self._llm, diff --git a/lib/crewai/src/crewai/task.py b/lib/crewai/src/crewai/task.py index 6977eb638..17fbac3d4 100644 --- a/lib/crewai/src/crewai/task.py +++ b/lib/crewai/src/crewai/task.py @@ -67,6 +67,7 @@ except ImportError: return [] +from crewai.types.callback import SerializableCallable from crewai.utilities.guardrail import ( process_guardrail, ) @@ -124,7 +125,7 @@ class Task(BaseModel): description="Configuration for the agent", default=None, ) - callback: Any | None = Field( + callback: SerializableCallable | None = Field( description="Callback to be executed after the task is completed.", default=None ) agent: BaseAgent | None = Field( diff --git a/lib/crewai/src/crewai/types/callback.py b/lib/crewai/src/crewai/types/callback.py new file mode 100644 index 000000000..162311c63 --- /dev/null +++ b/lib/crewai/src/crewai/types/callback.py @@ -0,0 +1,90 @@ +"""Serializable callback type for Pydantic models. + +Provides a ``SerializableCallable`` type alias that enables full JSON +round-tripping of callback fields, e.g. ``"builtins.print"`` ↔ ``print``. +Lambdas and closures serialize to a dotted path but cannot be deserialized +back — use module-level named functions for checkpointable callbacks. +""" + +from __future__ import annotations + +from collections.abc import Callable +import importlib +import inspect +from typing import Annotated, Any +import warnings + +from pydantic import BeforeValidator, WithJsonSchema +from pydantic.functional_serializers import PlainSerializer + + +def _is_lambda(fn: object) -> bool: + """Return ``True`` if *fn* is a lambda expression. + + Uses ``__qualname__`` ending with ``""`` for resilience against + ``__name__`` being reassigned. ``inspect.isfunction`` gates the check + so non-function callables (classes, partials, etc.) are never flagged. + + Args: + fn: The object to check. + + Returns: + ``True`` if *fn* is a lambda, ``False`` otherwise. + """ + return inspect.isfunction(fn) and getattr(fn, "__qualname__", "").endswith( + "" + ) + + +def string_to_callable(value: Any) -> Callable[..., Any]: + """Convert a dotted path string to the callable it references. + + If *value* is already callable it is returned as-is, with a warning if + it is a lambda. Otherwise, it is treated as ``"module.qualname"`` and + resolved via :func:`importlib.import_module`. + + Args: + value: A callable or a dotted-path string e.g. ``"builtins.print"``. + + Returns: + The resolved callable. + + Raises: + ModuleNotFoundError: If the module portion of the path cannot be imported. + AttributeError: If the attribute cannot be found on the imported module. + """ + if callable(value): + if _is_lambda(value): + warnings.warn( + "Lambdas cannot be serialized and will prevent checkpointing. " + "Use a module-level named function instead.", + UserWarning, + stacklevel=2, + ) + return value # type: ignore[no-any-return] + module, func = value.rsplit(".", 1) + return getattr(importlib.import_module(module), func) # type: ignore[no-any-return] + + +def callable_to_string(fn: Callable[..., Any]) -> str: + """Serialize a callable to its dotted-path string representation. + + Uses ``fn.__module__`` and ``fn.__qualname__`` to produce a string such + as ``"builtins.print"``. Lambdas and closures produce paths that contain + ```` and cannot be round-tripped via :func:`string_to_callable`. + + Args: + fn: The callable to serialize. + + Returns: + A dotted string of the form ``"module.qualname"``. + """ + return f"{fn.__module__}.{fn.__qualname__}" + + +SerializableCallable = Annotated[ + Callable[..., Any], + BeforeValidator(string_to_callable), + PlainSerializer(callable_to_string, return_type=str), + WithJsonSchema({"type": "string"}), +] diff --git a/lib/crewai/tests/agents/test_agent.py b/lib/crewai/tests/agents/test_agent.py index a3aab28d6..d865ec541 100644 --- a/lib/crewai/tests/agents/test_agent.py +++ b/lib/crewai/tests/agents/test_agent.py @@ -1690,7 +1690,10 @@ def test_agent_with_knowledge_sources_works_with_copy(): with patch( "crewai.knowledge.storage.knowledge_storage.KnowledgeStorage" ) as mock_knowledge_storage: + from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage + mock_knowledge_storage_instance = mock_knowledge_storage.return_value + mock_knowledge_storage_instance.__class__ = BaseKnowledgeStorage agent.knowledge_storage = mock_knowledge_storage_instance agent_copy = agent.copy()