mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-03 06:08:15 +00:00
fix(checkpoint,task): serialize Task class refs and propagate JSON mode through events
This commit is contained in:
@@ -8,7 +8,14 @@ from __future__ import annotations
|
||||
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from pydantic import BaseModel, BeforeValidator, Field, PlainSerializer, PrivateAttr
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
BeforeValidator,
|
||||
Field,
|
||||
PrivateAttr,
|
||||
SerializationInfo,
|
||||
field_serializer,
|
||||
)
|
||||
|
||||
from crewai.events.base_events import BaseEvent
|
||||
from crewai.utilities.rw_lock import RWLock
|
||||
@@ -66,10 +73,24 @@ class EventNode(BaseModel):
|
||||
event: Annotated[
|
||||
BaseEvent,
|
||||
BeforeValidator(_resolve_event),
|
||||
PlainSerializer(lambda v: v.model_dump()),
|
||||
]
|
||||
edges: dict[EdgeType, list[str]] = Field(default_factory=dict)
|
||||
|
||||
@field_serializer("event")
|
||||
def _serialize_event(
|
||||
self, value: BaseEvent, info: SerializationInfo
|
||||
) -> dict[str, Any]:
|
||||
"""Dump the event, propagating JSON mode to nested fields.
|
||||
|
||||
Without this the default ``v.model_dump()`` discards JSON mode, so any
|
||||
non-JSON-native nested values (e.g. ``type[BaseModel]`` references on
|
||||
a Task payload) are passed raw to ``json.dumps`` and explode with
|
||||
``PydanticSerializationError``.
|
||||
"""
|
||||
if info.mode == "json":
|
||||
return value.model_dump(mode="json")
|
||||
return value.model_dump()
|
||||
|
||||
def add_edge(self, edge_type: EdgeType, target_id: str) -> None:
|
||||
"""Add an edge from this node to another.
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ from pydantic import (
|
||||
BaseModel,
|
||||
BeforeValidator,
|
||||
Field,
|
||||
PlainSerializer,
|
||||
PrivateAttr,
|
||||
field_validator,
|
||||
model_validator,
|
||||
@@ -86,6 +87,58 @@ from crewai.utilities.printer import PRINTER
|
||||
from crewai.utilities.string_utils import interpolate_only
|
||||
|
||||
|
||||
def _serialize_class_ref(value: Any) -> str | None:
|
||||
"""Serialize a class reference to a ``module.qualname`` string.
|
||||
|
||||
Pydantic's default JSON serializer cannot handle ``type[BaseModel]``
|
||||
and similar class-valued fields, which raises
|
||||
``PydanticSerializationError`` during checkpointing. We emit a
|
||||
dotted import path so the value is round-trippable.
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
if isinstance(value, type):
|
||||
module = getattr(value, "__module__", None)
|
||||
qualname = getattr(value, "__qualname__", None) or getattr(
|
||||
value, "__name__", None
|
||||
)
|
||||
if module and qualname:
|
||||
return f"{module}.{qualname}"
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def _validate_class_ref(value: Any) -> Any:
|
||||
"""Resolve a serialized class reference back into a class.
|
||||
|
||||
Accepts an existing class/``None`` unchanged. A string is interpreted as
|
||||
a ``module.qualname`` path; if it cannot be imported, ``None`` is
|
||||
returned so restoration degrades gracefully (user code re-instantiates
|
||||
the Task with the correct class anyway).
|
||||
"""
|
||||
if value is None or isinstance(value, type):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
import importlib
|
||||
|
||||
module_path, _, qualname = value.rpartition(".")
|
||||
if not module_path or not qualname:
|
||||
return None
|
||||
try:
|
||||
module = importlib.import_module(module_path)
|
||||
except ImportError:
|
||||
return None
|
||||
obj: Any = module
|
||||
for part in qualname.split("."):
|
||||
obj = getattr(obj, part, None)
|
||||
if obj is None:
|
||||
return None
|
||||
return obj if isinstance(obj, type) else None
|
||||
return value
|
||||
|
||||
|
||||
class Task(BaseModel):
|
||||
"""Class that represents a task to be executed.
|
||||
|
||||
@@ -141,15 +194,33 @@ class Task(BaseModel):
|
||||
description="Whether the task should be executed asynchronously or not.",
|
||||
default=False,
|
||||
)
|
||||
output_json: type[BaseModel] | None = Field(
|
||||
output_json: Annotated[
|
||||
type[BaseModel] | None,
|
||||
BeforeValidator(_validate_class_ref),
|
||||
PlainSerializer(
|
||||
_serialize_class_ref, return_type=str | None, when_used="json"
|
||||
),
|
||||
] = Field(
|
||||
description="A Pydantic model to be used to create a JSON output.",
|
||||
default=None,
|
||||
)
|
||||
output_pydantic: type[BaseModel] | None = Field(
|
||||
output_pydantic: Annotated[
|
||||
type[BaseModel] | None,
|
||||
BeforeValidator(_validate_class_ref),
|
||||
PlainSerializer(
|
||||
_serialize_class_ref, return_type=str | None, when_used="json"
|
||||
),
|
||||
] = Field(
|
||||
description="A Pydantic model to be used to create a Pydantic output.",
|
||||
default=None,
|
||||
)
|
||||
response_model: type[BaseModel] | None = Field(
|
||||
response_model: Annotated[
|
||||
type[BaseModel] | None,
|
||||
BeforeValidator(_validate_class_ref),
|
||||
PlainSerializer(
|
||||
_serialize_class_ref, return_type=str | None, when_used="json"
|
||||
),
|
||||
] = Field(
|
||||
description="A Pydantic model for structured LLM outputs using native provider features.",
|
||||
default=None,
|
||||
)
|
||||
@@ -189,7 +260,13 @@ class Task(BaseModel):
|
||||
description="Whether the task should instruct the agent to return the final answer formatted in Markdown",
|
||||
default=False,
|
||||
)
|
||||
converter_cls: type[Converter] | None = Field(
|
||||
converter_cls: Annotated[
|
||||
type[Converter] | None,
|
||||
BeforeValidator(_validate_class_ref),
|
||||
PlainSerializer(
|
||||
_serialize_class_ref, return_type=str | None, when_used="json"
|
||||
),
|
||||
] = Field(
|
||||
description="A converter class used to export structured output",
|
||||
default=None,
|
||||
)
|
||||
@@ -1052,6 +1129,27 @@ Follow these guidelines:
|
||||
tools=cloned_tools,
|
||||
)
|
||||
|
||||
def _normalize_agent_result(
|
||||
self, result: Any
|
||||
) -> tuple[str, BaseModel | None, dict[str, Any] | None]:
|
||||
"""Convert an agent execution result into ``(raw, pydantic, json)``.
|
||||
|
||||
The agent may return either a string or a Pydantic model (when the
|
||||
task uses ``output_pydantic``/``response_model`` and the LLM returned
|
||||
a structured payload). ``TaskOutput.raw`` is typed as ``str`` so the
|
||||
Pydantic model has to be serialized to JSON before it can be stored
|
||||
on a ``TaskOutput`` (e.g. during a guardrail-triggered retry).
|
||||
"""
|
||||
if isinstance(result, BaseModel):
|
||||
raw = result.model_dump_json()
|
||||
if self.output_pydantic:
|
||||
return raw, result, None
|
||||
if self.output_json:
|
||||
return raw, None, result.model_dump()
|
||||
return raw, None, None
|
||||
pydantic_output, json_output = self._export_output(result)
|
||||
return result, pydantic_output, json_output
|
||||
|
||||
def _export_output(
|
||||
self, result: str
|
||||
) -> tuple[BaseModel | None, dict[str, Any] | None]:
|
||||
@@ -1241,12 +1339,12 @@ Follow these guidelines:
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
pydantic_output, json_output = self._export_output(result)
|
||||
raw, pydantic_output, json_output = self._normalize_agent_result(result)
|
||||
task_output = TaskOutput(
|
||||
name=self.name or self.description,
|
||||
description=self.description,
|
||||
expected_output=self.expected_output,
|
||||
raw=result,
|
||||
raw=raw,
|
||||
pydantic=pydantic_output,
|
||||
json_dict=json_output,
|
||||
agent=agent.role,
|
||||
@@ -1337,12 +1435,12 @@ Follow these guidelines:
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
pydantic_output, json_output = self._export_output(result)
|
||||
raw, pydantic_output, json_output = self._normalize_agent_result(result)
|
||||
task_output = TaskOutput(
|
||||
name=self.name or self.description,
|
||||
description=self.description,
|
||||
expected_output=self.expected_output,
|
||||
raw=result,
|
||||
raw=raw,
|
||||
pydantic=pydantic_output,
|
||||
json_dict=json_output,
|
||||
agent=agent.role,
|
||||
|
||||
@@ -562,3 +562,110 @@ class TestKickoffFromCheckpoint:
|
||||
)
|
||||
assert mock_restored.checkpoint.restore_from is None
|
||||
assert result == "flow_result"
|
||||
|
||||
|
||||
# ---------- Pydantic model serialization in checkpoints (issue #5544) ----------
|
||||
|
||||
|
||||
class TestPydanticTypeFieldSerialization:
|
||||
"""Issue #5544 (Issue I): checkpoint serialization must not blow up on
|
||||
fields that hold ``type[BaseModel]`` references — e.g. a Task's
|
||||
``output_pydantic`` / ``output_json`` / ``response_model`` — nor on
|
||||
events that wrap such tasks in their payload.
|
||||
"""
|
||||
|
||||
def test_task_dumps_type_class_field_to_dotted_path(self) -> None:
|
||||
from pydantic import BaseModel as PydanticModel
|
||||
|
||||
class FamilyList(PydanticModel):
|
||||
families: list[str]
|
||||
|
||||
task = Task(
|
||||
description="d",
|
||||
expected_output="e",
|
||||
output_pydantic=FamilyList,
|
||||
)
|
||||
dumped = task.model_dump(mode="json")
|
||||
# The class is serialized as ``module.qualname``
|
||||
assert isinstance(dumped["output_pydantic"], str)
|
||||
assert dumped["output_pydantic"].endswith("FamilyList")
|
||||
|
||||
def test_task_round_trip_restores_class_reference(self) -> None:
|
||||
from pydantic import BaseModel as PydanticModel
|
||||
|
||||
global _CheckpointReplyModel # noqa: PLW0603
|
||||
|
||||
class _CheckpointReplyModel(PydanticModel):
|
||||
value: int
|
||||
|
||||
task = Task(
|
||||
description="d",
|
||||
expected_output="e",
|
||||
output_pydantic=_CheckpointReplyModel,
|
||||
)
|
||||
dumped_json = task.model_dump_json()
|
||||
restored = Task.model_validate_json(
|
||||
dumped_json, context={"from_checkpoint": True}
|
||||
)
|
||||
assert restored.output_pydantic is _CheckpointReplyModel
|
||||
|
||||
def test_task_round_trip_unknown_class_path_degrades_gracefully(self) -> None:
|
||||
# Mirrors a checkpoint produced in a different process / repo where
|
||||
# the class is no longer importable. We accept a None restore over
|
||||
# blowing up — user code re-instantiates the Task with the right
|
||||
# class anyway.
|
||||
restored = Task.model_validate(
|
||||
{
|
||||
"description": "d",
|
||||
"expected_output": "e",
|
||||
"output_pydantic": "no_such_module.NoSuchClass",
|
||||
},
|
||||
context={"from_checkpoint": True},
|
||||
)
|
||||
assert restored.output_pydantic is None
|
||||
|
||||
def test_runtime_state_with_event_carrying_pydantic_task_dumps_to_json(
|
||||
self,
|
||||
) -> None:
|
||||
"""End-to-end regression for issue #5544 Issue I.
|
||||
|
||||
A Crew + Task with ``output_pydantic`` produces events whose payload
|
||||
carries the Task. Without the field-level JSON serialization on
|
||||
``EventNode.event``, this dump explodes with PydanticSerializationError
|
||||
on the embedded ``type[BaseModel]`` reference.
|
||||
"""
|
||||
from pydantic import BaseModel as PydanticModel
|
||||
|
||||
from crewai import Agent, Crew
|
||||
from crewai.events.types.task_events import TaskCompletedEvent
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
|
||||
class FamilyList(PydanticModel):
|
||||
families: list[str]
|
||||
|
||||
agent = Agent(role="r", goal="g", backstory="b", llm="gpt-4o-mini")
|
||||
task = Task(
|
||||
description="d",
|
||||
expected_output="e",
|
||||
agent=agent,
|
||||
output_pydantic=FamilyList,
|
||||
)
|
||||
crew = Crew(agents=[agent], tasks=[task], verbose=False)
|
||||
state = RuntimeState(root=[crew])
|
||||
|
||||
event = TaskCompletedEvent(
|
||||
task=task,
|
||||
output=TaskOutput(
|
||||
description="d",
|
||||
expected_output="e",
|
||||
raw="{}",
|
||||
agent="r",
|
||||
),
|
||||
)
|
||||
state._event_record.add(event)
|
||||
|
||||
# Should not raise PydanticSerializationError.
|
||||
payload = state.model_dump(mode="json")
|
||||
# And it should round-trip through json.dumps (the actual checkpoint
|
||||
# writer does this immediately after).
|
||||
json.dumps(payload)
|
||||
|
||||
@@ -768,3 +768,59 @@ def test_per_guardrail_independent_retry_tracking():
|
||||
assert call_counts["g3"] == 1
|
||||
|
||||
assert "G3(1)" in result.raw
|
||||
|
||||
|
||||
def test_guardrail_retry_with_pydantic_agent_result():
|
||||
"""Regression test for issue #5544 (Issue II).
|
||||
|
||||
When a task has ``output_pydantic`` set and the LLM returns a structured
|
||||
Pydantic model, the agent's execute result is the Pydantic instance — not
|
||||
a string. On a guardrail retry, ``TaskOutput.raw`` is typed ``str``, so
|
||||
feeding the model directly to ``raw=`` blew up with a ValidationError and
|
||||
aborted the retry path. The retry should normalize the model to JSON
|
||||
before constructing ``TaskOutput``.
|
||||
"""
|
||||
from pydantic import BaseModel
|
||||
|
||||
class Family(BaseModel):
|
||||
family_id: int
|
||||
name: str
|
||||
size: int
|
||||
|
||||
class FamilyList(BaseModel):
|
||||
families: list[Family]
|
||||
|
||||
bad = FamilyList(families=[Family(family_id=1, name="X", size=2)])
|
||||
good = FamilyList(
|
||||
families=[Family(family_id=1, name="Smiths", size=2)]
|
||||
)
|
||||
|
||||
def is_family_guardrail(result: TaskOutput) -> tuple[bool, str]:
|
||||
if result.pydantic is None:
|
||||
return (False, "No pydantic output")
|
||||
bad_names = [f for f in result.pydantic.families if len(f.name) < 3]
|
||||
if bad_names:
|
||||
return (False, "Family name too short, must be >= 3 chars")
|
||||
return (True, result)
|
||||
|
||||
agent = Mock()
|
||||
agent.role = "test_agent"
|
||||
agent.execute_task.side_effect = [bad, good]
|
||||
agent.crew = None
|
||||
agent.last_messages = []
|
||||
|
||||
task = create_smart_task(
|
||||
description="Test pydantic retry",
|
||||
expected_output="JSON list of families",
|
||||
output_pydantic=FamilyList,
|
||||
guardrails=[is_family_guardrail],
|
||||
guardrail_max_retries=2,
|
||||
)
|
||||
|
||||
result = task.execute_sync(agent=agent)
|
||||
|
||||
assert isinstance(result, TaskOutput)
|
||||
assert isinstance(result.raw, str)
|
||||
assert isinstance(result.pydantic, FamilyList)
|
||||
assert result.pydantic.families[0].name == "Smiths"
|
||||
assert agent.execute_task.call_count == 2
|
||||
|
||||
Reference in New Issue
Block a user