mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-02 05:38:12 +00:00
refactor: replace Any-typed callback and model fields with serializable types
This commit is contained in:
@@ -66,6 +66,7 @@ from crewai.mcp.tool_resolver import MCPToolResolver
|
||||
from crewai.rag.embeddings.types import EmbedderConfig
|
||||
from crewai.security.fingerprint import Fingerprint
|
||||
from crewai.tools.agent_tools.agent_tools import AgentTools
|
||||
from crewai.types.callback import SerializableCallable
|
||||
from crewai.utilities.agent_utils import (
|
||||
get_tool_names,
|
||||
is_inside_event_loop,
|
||||
@@ -143,7 +144,7 @@ class Agent(BaseAgent):
|
||||
default=None,
|
||||
description="Maximum execution time for an agent to execute a task",
|
||||
)
|
||||
step_callback: Any | None = Field(
|
||||
step_callback: SerializableCallable | None = Field(
|
||||
default=None,
|
||||
description="Callback to be executed after each step of the agent execution.",
|
||||
)
|
||||
@@ -151,10 +152,10 @@ class Agent(BaseAgent):
|
||||
default=True,
|
||||
description="Use system prompt for the agent.",
|
||||
)
|
||||
llm: str | InstanceOf[BaseLLM] | Any = Field(
|
||||
llm: str | InstanceOf[BaseLLM] | None = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
function_calling_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
|
||||
function_calling_llm: str | InstanceOf[BaseLLM] | None = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
system_template: str | None = Field(
|
||||
@@ -340,7 +341,7 @@ class Agent(BaseAgent):
|
||||
return (
|
||||
hasattr(self.llm, "supports_function_calling")
|
||||
and callable(getattr(self.llm, "supports_function_calling", None))
|
||||
and self.llm.supports_function_calling()
|
||||
and self.llm.supports_function_calling() # type: ignore[union-attr]
|
||||
and len(tools) > 0
|
||||
)
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ from pydantic import (
|
||||
UUID4,
|
||||
BaseModel,
|
||||
Field,
|
||||
InstanceOf,
|
||||
PrivateAttr,
|
||||
field_validator,
|
||||
model_validator,
|
||||
@@ -26,7 +27,10 @@ from crewai.agents.tools_handler import ToolsHandler
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.knowledge_config import KnowledgeConfig
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
|
||||
from crewai.mcp.config import MCPServerConfig
|
||||
from crewai.memory.memory_scope import MemoryScope, MemorySlice
|
||||
from crewai.memory.unified_memory import Memory
|
||||
from crewai.rag.embeddings.types import EmbedderConfig
|
||||
from crewai.security.security_config import SecurityConfig
|
||||
from crewai.tools.base_tool import BaseTool, Tool
|
||||
@@ -179,7 +183,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
default=None,
|
||||
description="Knowledge sources for the agent.",
|
||||
)
|
||||
knowledge_storage: Any | None = Field(
|
||||
knowledge_storage: InstanceOf[BaseKnowledgeStorage] | None = Field(
|
||||
default=None,
|
||||
description="Custom knowledge storage for the agent.",
|
||||
)
|
||||
@@ -205,7 +209,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
default=None,
|
||||
description="List of MCP server references. Supports 'https://server.com/path' for external servers and bare slugs like 'notion' for connected MCP integrations. Use '#tool_name' suffix for specific tools.",
|
||||
)
|
||||
memory: Any = Field(
|
||||
memory: bool | Memory | MemoryScope | MemorySlice | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Enable agent memory. Pass True for default Memory(), "
|
||||
|
||||
@@ -35,6 +35,7 @@ from typing_extensions import Self
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai_files import FileInput
|
||||
from opentelemetry.trace import Span
|
||||
|
||||
try:
|
||||
from crewai_files import get_supported_content_types
|
||||
@@ -83,6 +84,8 @@ from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.llm import LLM
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.memory.memory_scope import MemoryScope, MemorySlice
|
||||
from crewai.memory.unified_memory import Memory
|
||||
from crewai.process import Process
|
||||
from crewai.rag.embeddings.types import EmbedderConfig
|
||||
from crewai.rag.types import SearchResult
|
||||
@@ -94,6 +97,7 @@ from crewai.tasks.task_output import TaskOutput
|
||||
from crewai.tools.agent_tools.agent_tools import AgentTools
|
||||
from crewai.tools.agent_tools.read_file_tool import ReadFileTool
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.types.callback import SerializableCallable
|
||||
from crewai.types.streaming import CrewStreamingOutput
|
||||
from crewai.types.usage_metrics import UsageMetrics
|
||||
from crewai.utilities.constants import NOT_SPECIFIED, TRAINING_DATA_FILE
|
||||
@@ -166,12 +170,12 @@ class Crew(FlowTrackable, BaseModel):
|
||||
"""
|
||||
|
||||
__hash__ = object.__hash__
|
||||
_execution_span: Any = PrivateAttr()
|
||||
_execution_span: Span | None = PrivateAttr()
|
||||
_rpm_controller: RPMController = PrivateAttr()
|
||||
_logger: Logger = PrivateAttr()
|
||||
_file_handler: FileHandler = PrivateAttr()
|
||||
_cache_handler: InstanceOf[CacheHandler] = PrivateAttr(default_factory=CacheHandler)
|
||||
_memory: Any = PrivateAttr(default=None) # Unified Memory | MemoryScope
|
||||
_memory: Memory | MemoryScope | MemorySlice | None = PrivateAttr(default=None)
|
||||
_train: bool | None = PrivateAttr(default=False)
|
||||
_train_iteration: int | None = PrivateAttr()
|
||||
_inputs: dict[str, Any] | None = PrivateAttr(default=None)
|
||||
@@ -189,7 +193,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
agents: list[BaseAgent] = Field(default_factory=list)
|
||||
process: Process = Field(default=Process.sequential)
|
||||
verbose: bool = Field(default=False)
|
||||
memory: bool | Any = Field(
|
||||
memory: bool | Memory | MemoryScope | MemorySlice | None = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"Enable crew memory. Pass True for default Memory(), "
|
||||
@@ -204,23 +208,23 @@ class Crew(FlowTrackable, BaseModel):
|
||||
default=None,
|
||||
description="Metrics for the LLM usage during all tasks execution.",
|
||||
)
|
||||
manager_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
|
||||
manager_llm: str | InstanceOf[BaseLLM] | None = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
manager_agent: BaseAgent | None = Field(
|
||||
description="Custom agent that will be used as manager.", default=None
|
||||
)
|
||||
function_calling_llm: str | InstanceOf[LLM] | Any | None = Field(
|
||||
function_calling_llm: str | InstanceOf[LLM] | None = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
config: Json[dict[str, Any]] | dict[str, Any] | None = Field(default=None)
|
||||
id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
|
||||
share_crew: bool | None = Field(default=False)
|
||||
step_callback: Any | None = Field(
|
||||
step_callback: SerializableCallable | None = Field(
|
||||
default=None,
|
||||
description="Callback to be executed after each step for all agents execution.",
|
||||
)
|
||||
task_callback: Any | None = Field(
|
||||
task_callback: SerializableCallable | None = Field(
|
||||
default=None,
|
||||
description="Callback to be executed after each task for all agents execution.",
|
||||
)
|
||||
@@ -349,7 +353,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
self._file_handler = FileHandler(self.output_log_file)
|
||||
self._rpm_controller = RPMController(max_rpm=self.max_rpm, logger=self._logger)
|
||||
if self.function_calling_llm and not isinstance(self.function_calling_llm, LLM):
|
||||
self.function_calling_llm = create_llm(self.function_calling_llm)
|
||||
self.function_calling_llm = create_llm(self.function_calling_llm) # type: ignore[assignment]
|
||||
|
||||
return self
|
||||
|
||||
@@ -363,7 +367,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
if self.embedder is not None:
|
||||
from crewai.rag.embeddings.factory import build_embedder
|
||||
|
||||
embedder = build_embedder(self.embedder)
|
||||
embedder = build_embedder(self.embedder) # type: ignore[arg-type]
|
||||
self._memory = Memory(embedder=embedder)
|
||||
elif self.memory:
|
||||
# User passed a Memory / MemoryScope / MemorySlice instance
|
||||
|
||||
@@ -22,7 +22,6 @@ from crewai.events.types.memory_events import (
|
||||
)
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.memory.analyze import extract_memories_from_content
|
||||
from crewai.memory.recall_flow import RecallFlow
|
||||
from crewai.memory.storage.backend import StorageBackend
|
||||
from crewai.memory.types import (
|
||||
MemoryConfig,
|
||||
@@ -620,6 +619,8 @@ class Memory(BaseModel):
|
||||
)
|
||||
results.sort(key=lambda m: m.score, reverse=True)
|
||||
else:
|
||||
from crewai.memory.recall_flow import RecallFlow
|
||||
|
||||
flow = RecallFlow(
|
||||
storage=self._storage,
|
||||
llm=self._llm,
|
||||
|
||||
@@ -67,6 +67,7 @@ except ImportError:
|
||||
return []
|
||||
|
||||
|
||||
from crewai.types.callback import SerializableCallable
|
||||
from crewai.utilities.guardrail import (
|
||||
process_guardrail,
|
||||
)
|
||||
@@ -124,7 +125,7 @@ class Task(BaseModel):
|
||||
description="Configuration for the agent",
|
||||
default=None,
|
||||
)
|
||||
callback: Any | None = Field(
|
||||
callback: SerializableCallable | None = Field(
|
||||
description="Callback to be executed after the task is completed.", default=None
|
||||
)
|
||||
agent: BaseAgent | None = Field(
|
||||
|
||||
90
lib/crewai/src/crewai/types/callback.py
Normal file
90
lib/crewai/src/crewai/types/callback.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""Serializable callback type for Pydantic models.
|
||||
|
||||
Provides a ``SerializableCallable`` type alias that enables full JSON
|
||||
round-tripping of callback fields, e.g. ``"builtins.print"`` ↔ ``print``.
|
||||
Lambdas and closures serialize to a dotted path but cannot be deserialized
|
||||
back — use module-level named functions for checkpointable callbacks.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
import importlib
|
||||
import inspect
|
||||
from typing import Annotated, Any
|
||||
import warnings
|
||||
|
||||
from pydantic import BeforeValidator, WithJsonSchema
|
||||
from pydantic.functional_serializers import PlainSerializer
|
||||
|
||||
|
||||
def _is_lambda(fn: object) -> bool:
|
||||
"""Return ``True`` if *fn* is a lambda expression.
|
||||
|
||||
Uses ``__qualname__`` ending with ``"<lambda>"`` for resilience against
|
||||
``__name__`` being reassigned. ``inspect.isfunction`` gates the check
|
||||
so non-function callables (classes, partials, etc.) are never flagged.
|
||||
|
||||
Args:
|
||||
fn: The object to check.
|
||||
|
||||
Returns:
|
||||
``True`` if *fn* is a lambda, ``False`` otherwise.
|
||||
"""
|
||||
return inspect.isfunction(fn) and getattr(fn, "__qualname__", "").endswith(
|
||||
"<lambda>"
|
||||
)
|
||||
|
||||
|
||||
def string_to_callable(value: Any) -> Callable[..., Any]:
|
||||
"""Convert a dotted path string to the callable it references.
|
||||
|
||||
If *value* is already callable it is returned as-is, with a warning if
|
||||
it is a lambda. Otherwise, it is treated as ``"module.qualname"`` and
|
||||
resolved via :func:`importlib.import_module`.
|
||||
|
||||
Args:
|
||||
value: A callable or a dotted-path string e.g. ``"builtins.print"``.
|
||||
|
||||
Returns:
|
||||
The resolved callable.
|
||||
|
||||
Raises:
|
||||
ModuleNotFoundError: If the module portion of the path cannot be imported.
|
||||
AttributeError: If the attribute cannot be found on the imported module.
|
||||
"""
|
||||
if callable(value):
|
||||
if _is_lambda(value):
|
||||
warnings.warn(
|
||||
"Lambdas cannot be serialized and will prevent checkpointing. "
|
||||
"Use a module-level named function instead.",
|
||||
UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return value # type: ignore[no-any-return]
|
||||
module, func = value.rsplit(".", 1)
|
||||
return getattr(importlib.import_module(module), func) # type: ignore[no-any-return]
|
||||
|
||||
|
||||
def callable_to_string(fn: Callable[..., Any]) -> str:
|
||||
"""Serialize a callable to its dotted-path string representation.
|
||||
|
||||
Uses ``fn.__module__`` and ``fn.__qualname__`` to produce a string such
|
||||
as ``"builtins.print"``. Lambdas and closures produce paths that contain
|
||||
``<locals>`` and cannot be round-tripped via :func:`string_to_callable`.
|
||||
|
||||
Args:
|
||||
fn: The callable to serialize.
|
||||
|
||||
Returns:
|
||||
A dotted string of the form ``"module.qualname"``.
|
||||
"""
|
||||
return f"{fn.__module__}.{fn.__qualname__}"
|
||||
|
||||
|
||||
SerializableCallable = Annotated[
|
||||
Callable[..., Any],
|
||||
BeforeValidator(string_to_callable),
|
||||
PlainSerializer(callable_to_string, return_type=str),
|
||||
WithJsonSchema({"type": "string"}),
|
||||
]
|
||||
@@ -1690,7 +1690,10 @@ def test_agent_with_knowledge_sources_works_with_copy():
|
||||
with patch(
|
||||
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
|
||||
) as mock_knowledge_storage:
|
||||
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
|
||||
|
||||
mock_knowledge_storage_instance = mock_knowledge_storage.return_value
|
||||
mock_knowledge_storage_instance.__class__ = BaseKnowledgeStorage
|
||||
agent.knowledge_storage = mock_knowledge_storage_instance
|
||||
|
||||
agent_copy = agent.copy()
|
||||
|
||||
Reference in New Issue
Block a user