mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-13 14:32:47 +00:00
fix: deterministic message override via system boundary, drop dead _location
This commit is contained in:
@@ -552,14 +552,24 @@ async def _run_checkpoint_tui_async(location: str) -> None:
|
||||
preview = preview[:77] + "..."
|
||||
click.echo(f" Task {task_idx + 1}: {desc}")
|
||||
click.echo(f" -> {preview}")
|
||||
# Update the assistant message in the executor's history
|
||||
# so the LLM sees the override in its conversation
|
||||
agent = crew.tasks[task_idx].agent
|
||||
if agent and agent.agent_executor:
|
||||
for msg in reversed(agent.agent_executor.messages):
|
||||
if msg.get("role") == "assistant":
|
||||
msg["content"] = new_output
|
||||
break
|
||||
nth = sum(1 for t in crew.tasks[:task_idx] if t.agent is agent)
|
||||
messages = agent.agent_executor.messages
|
||||
system_positions = [
|
||||
i for i, m in enumerate(messages) if m.get("role") == "system"
|
||||
]
|
||||
if nth < len(system_positions):
|
||||
seg_start = system_positions[nth]
|
||||
seg_end = (
|
||||
system_positions[nth + 1]
|
||||
if nth + 1 < len(system_positions)
|
||||
else len(messages)
|
||||
)
|
||||
for j in range(seg_end - 1, seg_start, -1):
|
||||
if messages[j].get("role") == "assistant":
|
||||
messages[j]["content"] = new_output
|
||||
break
|
||||
overridden_agents.add(id(agent))
|
||||
|
||||
earliest = min(task_overrides)
|
||||
|
||||
@@ -10,7 +10,6 @@ via ``RuntimeState.model_rebuild()``.
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
import uuid
|
||||
|
||||
@@ -38,19 +37,6 @@ if TYPE_CHECKING:
|
||||
from crewai import Entity
|
||||
|
||||
|
||||
def _base_location(location: str, provider: BaseProvider) -> str:
|
||||
"""Extract the base storage location from a restore path.
|
||||
|
||||
For SQLite (``db_path#id``), returns ``db_path``.
|
||||
For JSON (a file path), returns the parent directory.
|
||||
"""
|
||||
from crewai.state.provider.sqlite_provider import SqliteProvider
|
||||
|
||||
if isinstance(provider, SqliteProvider):
|
||||
return location.rsplit("#", 1)[0]
|
||||
return str(Path(location).parent)
|
||||
|
||||
|
||||
def _sync_checkpoint_fields(entity: object) -> None:
|
||||
"""Copy private runtime attrs into checkpoint fields before serializing.
|
||||
|
||||
@@ -125,7 +111,6 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
|
||||
_checkpoint_id: str | None = PrivateAttr(default=None)
|
||||
_parent_id: str | None = PrivateAttr(default=None)
|
||||
_branch: str = PrivateAttr(default="main")
|
||||
_location: str | None = PrivateAttr(default=None)
|
||||
_trigger: str | None = PrivateAttr(default=None)
|
||||
|
||||
@property
|
||||
@@ -253,7 +238,6 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
|
||||
raw = provider.from_checkpoint(location)
|
||||
state = cls.model_validate_json(raw, **kwargs)
|
||||
state._provider = provider
|
||||
state._location = _base_location(location, provider)
|
||||
checkpoint_id = provider.extract_id(location)
|
||||
state._checkpoint_id = checkpoint_id
|
||||
state._parent_id = checkpoint_id
|
||||
@@ -281,7 +265,6 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
|
||||
raw = await provider.afrom_checkpoint(location)
|
||||
state = cls.model_validate_json(raw, **kwargs)
|
||||
state._provider = provider
|
||||
state._location = _base_location(location, provider)
|
||||
checkpoint_id = provider.extract_id(location)
|
||||
state._checkpoint_id = checkpoint_id
|
||||
state._parent_id = checkpoint_id
|
||||
|
||||
Reference in New Issue
Block a user