mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-07 02:02:35 +00:00
fix: serialize guardrail callable fields for checkpointing
Task fields `guardrail` and `guardrails` store callable references, which caused a PydanticSerializationError when RuntimeState serialized entities during checkpointing. Add PlainSerializer annotations that convert callables to their dotted-path strings via callable_to_string, matching the existing pattern used for callback fields.

Fixes #5620

Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -73,7 +73,7 @@ except ImportError:
|
||||
return []
|
||||
|
||||
|
||||
from crewai.types.callback import SerializableCallable
|
||||
from crewai.types.callback import SerializableCallable, callable_to_string
|
||||
from crewai.utilities.guardrail import (
|
||||
process_guardrail,
|
||||
)
|
||||
@@ -87,6 +87,36 @@ from crewai.utilities.printer import PRINTER
|
||||
from crewai.utilities.string_utils import interpolate_only
|
||||
|
||||
|
||||
def _serialize_guardrail_item(v: Any) -> str | None:
|
||||
"""Serialize a single guardrail value for JSON checkpointing.
|
||||
|
||||
Callables are converted to their dotted-path string via
|
||||
:func:`callable_to_string`. Strings (LLM guardrail descriptions)
|
||||
are returned as-is.
|
||||
"""
|
||||
if v is None:
|
||||
return None
|
||||
if callable(v):
|
||||
return callable_to_string(v)
|
||||
if isinstance(v, str):
|
||||
return v
|
||||
return None
|
||||
|
||||
|
||||
def _serialize_guardrail(v: Any) -> str | None:
    """Serialize the value of the ``guardrail`` field (its PlainSerializer).

    The field holds a single guardrail, so the work is delegated entirely
    to :func:`_serialize_guardrail_item`.
    """
    serialized = _serialize_guardrail_item(v)
    return serialized
|
||||
|
||||
|
||||
def _serialize_guardrails(v: Any) -> list[str] | str | None:
|
||||
"""PlainSerializer for the ``guardrails`` field."""
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, (list, tuple)):
|
||||
return [_serialize_guardrail_item(item) for item in v]
|
||||
return _serialize_guardrail_item(v)
|
||||
|
||||
|
||||
def _serialize_model_class(v: type[BaseModel] | None) -> dict[str, Any] | None:
    """Return the JSON schema of a Pydantic model class, or ``None`` when no class is given."""
    if not v:
        return None
    return v.model_json_schema()
|
||||
@@ -235,11 +265,19 @@ class Task(BaseModel):
|
||||
default=None,
|
||||
)
|
||||
processed_by_agents: set[str] = Field(default_factory=set)
|
||||
guardrail: GuardrailType | None = Field(
|
||||
guardrail: Annotated[
|
||||
GuardrailType | None,
|
||||
PlainSerializer(_serialize_guardrail, return_type=str | None, when_used="json"),
|
||||
] = Field(
|
||||
default=None,
|
||||
description="Function or string description of a guardrail to validate task output before proceeding to next task",
|
||||
)
|
||||
guardrails: GuardrailsType | None = Field(
|
||||
guardrails: Annotated[
|
||||
GuardrailsType | None,
|
||||
PlainSerializer(
|
||||
_serialize_guardrails, return_type=list | str | None, when_used="json"
|
||||
),
|
||||
] = Field(
|
||||
default=None,
|
||||
description="List of guardrails to validate task output before proceeding to next task. Also supports a single guardrail function or string description of a guardrail to validate task output before proceeding to next task",
|
||||
)
|
||||
|
||||
@@ -694,3 +694,150 @@ class TestAgentCheckpoint:
|
||||
cfg = CheckpointConfig(restore_from=loc)
|
||||
restored = Agent.from_checkpoint(cfg)
|
||||
assert restored._kickoff_event_id == "evt-456"
|
||||
|
||||
|
||||
# ---------- Guardrail serialization (issue #5620) ----------
|
||||
|
||||
|
||||
def _sample_guardrail(output):
|
||||
"""Module-level guardrail function used in serialization tests."""
|
||||
return (True, output)
|
||||
|
||||
|
||||
def _another_guardrail(output):
|
||||
"""A second module-level guardrail for multi-guardrail tests."""
|
||||
return (True, output)
|
||||
|
||||
|
||||
class TestGuardrailCheckpointSerialization:
    """Regression coverage for checkpoint serialization of guardrail functions.

    Issue #5620: ``model_dump(mode="json")`` raised
    ``PydanticSerializationError: Unable to serialize unknown type: <class 'function'>``
    when a Task carried callable guardrails.
    """

    @staticmethod
    def _make_agent():
        """Build the minimal agent shared by every test that needs one."""
        return Agent(role="r", goal="g", backstory="b", llm="gpt-4o-mini")

    @classmethod
    def _make_crew(cls, **guardrail_kwargs):
        """Build a one-agent, one-task Crew whose Task receives ``guardrail_kwargs``."""
        agent = cls._make_agent()
        task = Task(
            description="d",
            expected_output="e",
            agent=agent,
            **guardrail_kwargs,
        )
        return Crew(agents=[agent], tasks=[task], verbose=False)

    def test_task_with_callable_guardrail_serializes(self) -> None:
        """A Task holding one callable guardrail dumps to JSON."""
        dumped = Task(
            description="d",
            expected_output="e",
            guardrail=_sample_guardrail,
        ).model_dump(mode="json")

        # The callable is expected to come back as a string naming it.
        serialized = dumped["guardrail"]
        assert isinstance(serialized, str)
        assert "_sample_guardrail" in serialized

    def test_task_with_callable_guardrails_list_serializes(self) -> None:
        """A Task holding a list of callable guardrails dumps to JSON."""
        dumped = Task(
            description="d",
            expected_output="e",
            guardrails=[_sample_guardrail, _another_guardrail],
        ).model_dump(mode="json")

        serialized = dumped["guardrails"]
        assert isinstance(serialized, list)
        assert len(serialized) == 2
        assert all(isinstance(g, str) for g in serialized)
        assert "_sample_guardrail" in serialized[0]
        assert "_another_guardrail" in serialized[1]

    def test_task_with_string_guardrail_serializes(self) -> None:
        """A Task holding a string guardrail still dumps to JSON.

        String guardrails on the ``guardrail`` field require an agent with
        an LLM, so a minimal agent is supplied.
        """
        task = Task(
            description="d",
            expected_output="e",
            agent=self._make_agent(),
            guardrail="Ensure output is valid JSON",
        )
        dumped = task.model_dump(mode="json")
        # String guardrails are converted to LLMGuardrail by the validator;
        # the field is cleared in favour of _guardrail,
        # so only check that the JSON dump does not crash.
        assert isinstance(dumped, dict)

    def test_task_with_mixed_guardrails_serializes(self) -> None:
        """A Task mixing callable and string guardrails dumps to JSON."""
        task = Task(
            description="d",
            expected_output="e",
            agent=self._make_agent(),
            guardrails=[_sample_guardrail, "Ensure output is valid"],
        )
        dumped = task.model_dump(mode="json")
        # The guardrails list may be processed by the validator; just ensure
        # serialization succeeds without PydanticSerializationError.
        assert isinstance(dumped, dict)

    def test_task_with_none_guardrails_serializes(self) -> None:
        """A Task without guardrails dumps both fields as ``None``."""
        dumped = Task(description="d", expected_output="e").model_dump(mode="json")
        assert dumped["guardrail"] is None
        assert dumped["guardrails"] is None

    def test_crew_with_guardrail_task_serializes_for_checkpoint(self) -> None:
        """A Crew whose task has callable guardrails serializes through
        RuntimeState.model_dump (the checkpoint code path)."""
        crew = self._make_crew(guardrails=[_sample_guardrail])
        state = RuntimeState(root=[crew])

        from crewai.state.runtime import _prepare_entities

        _prepare_entities(state.root)
        # This is the exact call that raised PydanticSerializationError
        # before the fix for issue #5620.
        payload = state.model_dump(mode="json")
        assert "entities" in payload

    def test_crew_with_guardrail_task_checkpoints_to_json(self) -> None:
        """End-to-end: a Crew with guardrail tasks checkpoints to disk."""
        crew = self._make_crew(guardrails=[_sample_guardrail])
        state = RuntimeState(root=[crew])
        state._provider = JsonProvider()
        with tempfile.TemporaryDirectory() as d:
            loc = state.checkpoint(d)
            # Verify the checkpoint file was written and is valid JSON.
            with open(loc) as f:
                data = json.load(f)
            assert "entities" in data

    def test_flow_with_guardrail_crew_serializes(self) -> None:
        """A Flow whose state is fully serializable must not fail when the
        RuntimeState also includes a Crew with callable guardrails."""
        crew = self._make_crew(guardrail=_sample_guardrail)
        flow = Flow(checkpoint=True)
        state = RuntimeState(root=[crew, flow])

        from crewai.state.runtime import _prepare_entities

        _prepare_entities(state.root)
        payload = state.model_dump(mode="json")
        assert "entities" in payload
        assert len(payload["entities"]) == 2
|
||||
|
||||
Reference in New Issue
Block a user