fix: align fork-from-here with merged messaging history

This commit is contained in:
Michael Lam
2026-05-17 15:01:12 -07:00
parent f1d399b437
commit f986507809
5 changed files with 151 additions and 33 deletions
+4
View File
@@ -2,6 +2,10 @@
## [Unreleased]
### Fixed
- **PR #2480** by @Michaelyklam (refs #2472) — Make "Fork from here" use the same merged messaging-session transcript coordinate space that `/api/session` exposes, so forking an older message no longer silently copies the full sidecar when CLI/Gateway history inflated the visible message offset. The frontend now snapshots the source session id across the async full-history load, reloads the forked transcript fully after creation, and the branch handler best-effort saves the source session before slicing to keep undo/retry state coherent.
## [v0.51.85] — 2026-05-17 — Release BI (stage-378 — 3-PR batch — workspace-prefix display leakage fix + release-tag update banner + Slice 3a cancel-control gate RFC)
### Fixed
+56 -30
View File
@@ -1903,6 +1903,46 @@ def _messages_include_tool_metadata(messages) -> bool:
return False
def _merged_session_messages_for_display(session, cli_messages=None) -> list:
"""Return the message coordinate space exposed by ``GET /api/session``.
Messaging sessions can have a WebUI sidecar transcript plus messages from
the Agent/CLI store. The frontend computes fork keep-counts against this
merged display list, so branch/fork must slice the same list rather than
the sidecar-only ``session.messages`` array.
"""
cli_messages = list(cli_messages or [])
sidecar_messages = list(getattr(session, "messages", []) or [])
if cli_messages:
if sidecar_messages and sidecar_messages != cli_messages:
merged_messages = []
seen_message_keys = set()
for msg in sorted(list(cli_messages) + list(sidecar_messages), key=lambda m: (
float(m.get("timestamp") or 0),
str(m.get("role") or ""),
str(m.get("content") or ""),
)):
message_identity = msg.get("id") or msg.get("message_id")
if message_identity:
key = ("message_id", str(message_identity))
else:
key = (
"legacy",
str(msg.get("role") or ""),
str(msg.get("content") or ""),
str(msg.get("timestamp") or ""),
str(msg.get("tool_call_id") or ""),
str(msg.get("tool_name") or msg.get("name") or ""),
)
if key in seen_message_keys:
continue
seen_message_keys.add(key)
merged_messages.append(msg)
return merged_messages
return sidecar_messages if len(sidecar_messages) > len(cli_messages) else cli_messages
return sidecar_messages
def _session_requires_cli_metadata_lookup(session) -> bool:
"""Return True when a sidecar/session row still needs CLI metadata.
@@ -3568,7 +3608,6 @@ def handle_get(handler, parsed) -> bool:
_t3 = _time.monotonic()
if load_messages:
if is_messaging_session and cli_messages:
sidecar_messages = getattr(s, "messages", []) or []
# Recovery/aggregate sidecars can intentionally contain a
# longer visible conversation than the single state.db
# segment for this messaging session id. Prefer the longer
@@ -3576,33 +3615,7 @@ def handle_get(handler, parsed) -> bool:
# canonical per-segment transcript. When both sources carry
# different slices of the same stitched conversation, merge
# them chronologically and dedupe exact repeats.
if sidecar_messages and sidecar_messages != cli_messages:
merged_messages = []
seen_message_keys = set()
for msg in sorted(list(cli_messages) + list(sidecar_messages), key=lambda m: (
float(m.get("timestamp") or 0),
str(m.get("role") or ""),
str(m.get("content") or ""),
)):
message_identity = msg.get("id") or msg.get("message_id")
if message_identity:
key = ("message_id", str(message_identity))
else:
key = (
"legacy",
str(msg.get("role") or ""),
str(msg.get("content") or ""),
str(msg.get("timestamp") or ""),
str(msg.get("tool_call_id") or ""),
str(msg.get("tool_name") or msg.get("name") or ""),
)
if key in seen_message_keys:
continue
seen_message_keys.add(key)
merged_messages.append(msg)
_all_msgs = merged_messages
else:
_all_msgs = sidecar_messages if len(sidecar_messages) > len(cli_messages) else cli_messages
_all_msgs = _merged_session_messages_for_display(s, cli_messages)
else:
_all_msgs = s.messages
else:
@@ -4886,8 +4899,21 @@ def handle_post(handler, parsed) -> bool:
if custom_title:
custom_title = str(custom_title).strip()[:80] or None
# Build messages slice
source_messages = source.messages or []
# Build messages slice in the same coordinate space exposed by GET
# /api/session so frontend keep_count values from merged messaging
# transcripts do not silently become full sidecar copies.
try:
source.save()
except Exception:
pass
cli_meta = _lookup_cli_session_metadata(source.session_id) if _session_requires_cli_metadata_lookup(source) else {}
is_messaging_session = _is_messaging_session_record(source) or _is_messaging_session_record(cli_meta)
cli_messages = get_cli_session_messages(source.session_id) if is_messaging_session else []
source_messages = (
_merged_session_messages_for_display(source, cli_messages)
if is_messaging_session and cli_messages
else list(source.messages or [])
)
if keep_count is not None:
forked_messages = source_messages[:keep_count]
else:
+4 -1
View File
@@ -1188,6 +1188,7 @@ async function cmdBranch(args){
// which resets _oldestIdx to 0 after its wholesale replace. See #2184.
async function forkFromMessage(msgIdx){
if(!S.session||S.busy)return;
const initialSid = S.session.session_id;
// Capture the absolute keep_count before any async work that may
// reset _oldestIdx. _oldestIdx is 0 when the full transcript is
// already loaded, so short/already-full sessions send msgIdx unchanged.
@@ -1197,16 +1198,18 @@ async function forkFromMessage(msgIdx){
if(typeof _ensureAllMessagesLoaded==='function'){
await _ensureAllMessagesLoaded();
}
if(!S.session || S.session.session_id !== initialSid) return;
try{
const data=await api('/api/session/branch',{
method:'POST',
body:JSON.stringify({
session_id:S.session.session_id,
session_id:initialSid,
keep_count:absoluteKeepCount,
}),
});
if(data&&data.session_id){
await loadSession(data.session_id);
if(typeof _ensureAllMessagesLoaded==='function') await _ensureAllMessagesLoaded();
if(typeof renderSessionList==='function') await renderSessionList();
showToast(t('branch_forked'));
}
@@ -0,0 +1,83 @@
"""Regression coverage for #2472 fork-from-here on messaging sessions."""
from __future__ import annotations
from pathlib import Path
from types import SimpleNamespace
import api.routes as routes
REPO = Path(__file__).resolve().parents[1]
COMMANDS_JS = (REPO / "static" / "commands.js").read_text(encoding="utf-8")
ROUTES_PY = (REPO / "api" / "routes.py").read_text(encoding="utf-8")
def _function_body(src: str, name: str) -> str:
start = src.index(f"async function {name}")
brace = src.index("{", start)
depth = 0
for i in range(brace, len(src)):
if src[i] == "{":
depth += 1
elif src[i] == "}":
depth -= 1
if depth == 0:
return src[start : i + 1]
raise AssertionError(f"function {name!r} body not found")
def test_messaging_merge_helper_matches_session_get_coordinate_space():
session = SimpleNamespace(
messages=[
{"role": "user", "content": "sidecar only", "timestamp": 2},
{"role": "assistant", "content": "shared", "timestamp": 3},
]
)
cli_messages = [
{"role": "user", "content": "cli earlier", "timestamp": 1},
{"role": "assistant", "content": "shared", "timestamp": 3},
{"role": "assistant", "content": "cli later", "timestamp": 4},
]
merged = routes._merged_session_messages_for_display(session, cli_messages)
assert [m["content"] for m in merged] == [
"cli earlier",
"sidecar only",
"shared",
"cli later",
]
def test_branch_handler_uses_merged_messaging_messages_for_keep_count():
branch_idx = ROUTES_PY.index('parsed.path == "/api/session/branch":')
block = ROUTES_PY[branch_idx : branch_idx + 2600]
assert "_merged_session_messages_for_display(source, cli_messages)" in block
assert "get_cli_session_messages(source.session_id)" in block
assert "source_messages = source.messages or []" not in block
def test_branch_handler_best_effort_saves_source_before_fork_slice():
branch_idx = ROUTES_PY.index('parsed.path == "/api/session/branch":')
block = ROUTES_PY[branch_idx : branch_idx + 2600]
assert "source.save()" in block
assert block.index("source.save()") < block.index("source_messages =")
def test_fork_from_message_snapshots_session_id_across_async_load():
body = _function_body(COMMANDS_JS, "forkFromMessage")
assert "const initialSid = S.session.session_id;" in body
assert "S.session.session_id !== initialSid" in body
assert "session_id:initialSid" in body
assert "session_id:S.session.session_id" not in body
def test_fork_loads_full_fork_transcript_after_branch():
body = _function_body(COMMANDS_JS, "forkFromMessage")
load_idx = body.index("await loadSession(data.session_id)")
after_load = body[load_idx:]
assert "await _ensureAllMessagesLoaded()" in after_load
@@ -279,5 +279,7 @@ def test_messaging_session_loader_prefers_longer_sidecar_transcript():
handler = _extract_handler("handle_get")
old = "if is_messaging_session and cli_messages:\n _all_msgs = cli_messages"
assert old not in handler
assert "sidecar_messages = getattr(s, \"messages\", []) or []" in handler
assert "len(sidecar_messages) > len(cli_messages)" in handler
assert "_all_msgs = _merged_session_messages_for_display(s, cli_messages)" in handler
src = (REPO / "api" / "routes.py").read_text(encoding="utf-8")
assert "sidecar_messages = list(getattr(session, \"messages\", []) or [])" in src
assert "len(sidecar_messages) > len(cli_messages)" in src