""" Tests for periodic session persistence during streaming (Issue #765). Validates: - Session.save(skip_index=True) writes the JSON file but skips the index rebuild - The periodic checkpoint fires when _checkpoint_activity is incremented (as it would be by on_tool() during real agent execution) - Messages stored via pending_user_message survive a simulated server restart """ import json import threading import time from pathlib import Path import pytest import api.models as models from api.models import Session @pytest.fixture(autouse=True) def _isolate_session_dir(tmp_path, monkeypatch): """Redirect SESSION_DIR and SESSION_INDEX_FILE to a temp directory.""" session_dir = tmp_path / "sessions" session_dir.mkdir() index_file = session_dir / "_index.json" monkeypatch.setattr(models, "SESSION_DIR", session_dir) monkeypatch.setattr(models, "SESSION_INDEX_FILE", index_file) models.SESSIONS.clear() yield session_dir, index_file models.SESSIONS.clear() def _make_session(session_id="abc123", messages=None): """Helper to create a Session with a known ID.""" return Session( session_id=session_id, title="Test Session", messages=messages or [{"role": "user", "content": "hello"}], ) class TestSaveSkipIndex: """Tests for the skip_index parameter on Session.save().""" def test_save_writes_json_file(self): """save() always writes the session JSON file, regardless of skip_index.""" s = _make_session("s1") s.save() assert s.path.exists() data = json.loads(s.path.read_text()) assert data["session_id"] == "s1" assert len(data["messages"]) == 1 def test_save_with_skip_index_writes_json(self): """save(skip_index=True) still writes the session JSON file.""" s = _make_session("s2") s.save(skip_index=True) assert s.path.exists() data = json.loads(s.path.read_text()) assert data["session_id"] == "s2" def test_save_with_skip_index_skips_index_rebuild(self): """save(skip_index=True) does NOT create or update the session index.""" s = _make_session("s3") s.save(skip_index=True) index = models.SESSION_INDEX_FILE assert not index.exists(), "Index file should not be created with skip_index=True" def test_save_without_skip_index_creates_index(self): """save() (default) DOES create the session index.""" s = _make_session("s4") s.save() index = models.SESSION_INDEX_FILE assert index.exists(), "Index file should be created by default save()" data = json.loads(index.read_text()) sids = [e["session_id"] for e in data] assert "s4" in sids def test_skip_index_then_full_save_updates_index(self): """After skip_index saves, a full save() correctly builds the index.""" s = _make_session("s5") s.messages.append({"role": "assistant", "content": "hi there"}) s.save(skip_index=True) assert not models.SESSION_INDEX_FILE.exists() s.messages.append({"role": "user", "content": "thanks"}) s.save() assert models.SESSION_INDEX_FILE.exists() data = json.loads(s.path.read_text()) assert len(data["messages"]) == 3 def test_skip_index_save_with_touch_updated_at_false(self): """save(skip_index=True, touch_updated_at=False) preserves updated_at.""" s = _make_session("touch1") original_updated_at = s.updated_at time.sleep(0.05) s.save(skip_index=True, touch_updated_at=False) data = json.loads(s.path.read_text()) assert data["updated_at"] == original_updated_at assert not models.SESSION_INDEX_FILE.exists() class TestPeriodicCheckpoint: """Tests for the periodic checkpoint mechanism during streaming. The checkpoint is keyed off an activity counter (_checkpoint_activity[0]), incremented by on_tool() on each tool.completed event — NOT off s.messages which is never mutated during agent.run_conversation() (the agent copies it). """ def test_checkpoint_fires_on_activity_counter_increment(self): """Checkpoint saves when _checkpoint_activity counter grows. Deterministic: instead of relying on time-based polling windows, we wait for the checkpoint thread's save_count to advance after each increment. Generous timeout guards against CI scheduling jitter. """ s = _make_session("ckpt1") s.pending_user_message = "do a long task" s.save() # initial save (like routes.py does before streaming starts) stop_event = threading.Event() _checkpoint_activity = [0] save_count = [0] save_event = threading.Event() def periodic_checkpoint(): last = 0 while not stop_event.wait(0.02): # fast poll for low-jitter test try: cur = _checkpoint_activity[0] if cur > last: s.save(skip_index=True) last = cur save_count[0] += 1 save_event.set() except Exception: pass t = threading.Thread(target=periodic_checkpoint, daemon=True) t.start() def _wait_for_save(target_count, timeout=3.0): """Wait until save_count[0] >= target_count, or timeout.""" deadline = time.monotonic() + timeout while save_count[0] < target_count and time.monotonic() < deadline: save_event.wait(timeout=0.05) save_event.clear() return save_count[0] >= target_count # Simulate on_tool() completing twice _checkpoint_activity[0] += 1 # first tool completes assert _wait_for_save(1), f"Expected 1 save after first increment; got {save_count[0]}" _checkpoint_activity[0] += 1 # second tool completes assert _wait_for_save(2), f"Expected 2 saves after second increment; got {save_count[0]}" stop_event.set() t.join(timeout=2) assert save_count[0] >= 2, ( "Expected at least 2 checkpoint saves (one per activity increment); " f"got {save_count[0]}" ) # Verify the JSON is on disk and readable data = json.loads(s.path.read_text()) assert data["pending_user_message"] == "do a long task" def test_checkpoint_does_not_fire_without_activity(self): """Checkpoint skips save when activity counter has not changed.""" s = _make_session("ckpt2") s.save() stop_event = threading.Event() _checkpoint_activity = [0] save_count = [0] def periodic_checkpoint(): last = 0 while not stop_event.wait(0.05): cur = _checkpoint_activity[0] if cur > last: s.save(skip_index=True) last = cur save_count[0] += 1 t = threading.Thread(target=periodic_checkpoint, daemon=True) t.start() # No increments — checkpoint should stay quiet time.sleep(0.4) stop_event.set() t.join(timeout=2) assert save_count[0] == 0, ( f"Expected 0 saves when activity is unchanged; got {save_count[0]}" ) def test_checkpoint_stops_on_signal(self): """Checkpoint thread exits cleanly when stop event is set.""" s = _make_session("ckpt3") stop_event = threading.Event() iterations = [0] def periodic_checkpoint(): while not stop_event.wait(0.02): iterations[0] += 1 t = threading.Thread(target=periodic_checkpoint, daemon=True) t.start() time.sleep(0.15) stop_event.set() t.join(timeout=1) assert not t.is_alive(), "Checkpoint thread should have stopped" def test_pending_message_survives_simulated_restart(self): """pending_user_message written before run_conversation survives a restart. This is the minimal guarantee for Issue #765: even if the agent produces no tool calls before a crash, the user's message is not silently lost. """ s = _make_session("survive1", messages=[{"role": "user", "content": "first turn"}]) s.save() # initial full save # Simulate what routes.py does before _run_agent_streaming: s.pending_user_message = "do a long research task" s.pending_started_at = time.time() s.active_stream_id = "stream-abc123" s.save(skip_index=True) # checkpoint-style save # Simulate restart: clear in-memory state, reload from disk del s models.SESSIONS.clear() reloaded = Session.load("survive1") assert reloaded is not None assert reloaded.pending_user_message == "do a long research task" assert reloaded.active_stream_id == "stream-abc123" # Original messages still intact assert len(reloaded.messages) == 1 def test_activity_checkpoint_persists_updated_at(self): """Each checkpoint save updates updated_at, keeping session fresh in sidebar.""" s = _make_session("ts1") s.save() ts_before = s.updated_at time.sleep(0.05) _checkpoint_activity = [1] # simulate one tool completion stop_event = threading.Event() def periodic_checkpoint(): last = 0 while not stop_event.wait(0.05): cur = _checkpoint_activity[0] if cur > last: s.save(skip_index=True) last = cur t = threading.Thread(target=periodic_checkpoint, daemon=True) t.start() time.sleep(0.2) stop_event.set() t.join(timeout=1) data = json.loads(s.path.read_text()) assert data["updated_at"] > ts_before, "Checkpoint should update updated_at" class TestIssue765FollowupHardening: """Regression tests for the follow-up hardening pass on Issue #765. Includes the guard that the outer `finally` must not UnboundLocalError when an exception fires before the checkpoint thread is created. """ def test_same_session_concurrent_saves_use_distinct_temp_files(self, monkeypatch): """Two concurrent saves of the same session must not collide on one tmp path. The key regression guard here is that each save call should reach os.replace() with a distinct source tmp path. With the old shared `.tmp` scheme, both threads would target the same path and the second replace would deterministically fail once the first consume/remove happened. """ s = _make_session("same_sid") s.save(skip_index=True) # seed the file on disk original_replace = models.os.replace barrier = threading.Barrier(2) replace_sources = [] errors = [] def _replace_with_barrier(src, dst): replace_sources.append(str(src)) barrier.wait(timeout=5) return original_replace(src, dst) monkeypatch.setattr(models.os, "replace", _replace_with_barrier) def _save_worker(): try: s.save(skip_index=True) except Exception as e: errors.append(e) t1 = threading.Thread(target=_save_worker) t2 = threading.Thread(target=_save_worker) t1.start() t2.start() t1.join(timeout=5) t2.join(timeout=5) assert not errors, f"Concurrent same-session saves should not fail: {errors}" assert len(replace_sources) == 2, f"Expected 2 replace calls, got {replace_sources}" assert len(set(replace_sources)) == 2, ( "Concurrent same-session saves must use distinct temp files; " f"got {replace_sources}" ) data = json.loads(s.path.read_text(encoding="utf-8")) assert data["session_id"] == "same_sid" def test_success_path_joins_checkpoint_before_session_mutation(self): """Static guard: success path must stop/join checkpoint thread before mutating. This keeps the post-run_conversation session rewrite serialized relative to the periodic checkpoint worker. """ src = (Path(__file__).parent.parent / "api" / "streaming.py").read_text( encoding="utf-8" ) stop_idx = src.find("if _checkpoint_stop is not None:\n _checkpoint_stop.set()") join_idx = src.find("if _ckpt_thread is not None:\n _ckpt_thread.join(timeout=15)") lock_idx = src.find("with _agent_lock:\n _result_messages =") save_idx = src.find("s.context_messages = _next_context_messages") assert stop_idx != -1, "Success path must stop the checkpoint thread" assert join_idx != -1, "Success path must join the checkpoint thread" assert lock_idx != -1, "Success path must serialize mutation with _agent_lock" assert save_idx != -1, "Success path restore/mutation block not found" assert stop_idx < join_idx < lock_idx <= save_idx, ( "Checkpoint stop/join must happen before the success-path session mutation block" ) def test_silent_failure_path_does_not_reacquire_agent_lock(self): """Silent-failure path must not nest `_agent_lock` inside the success lock. Reacquiring the same per-session lock inside the post-run_conversation block deadlocks because `_get_session_agent_lock()` returns a non-reentrant Lock. """ src = (Path(__file__).parent.parent / "api" / "streaming.py").read_text( encoding="utf-8" ) outer_lock_idx = src.find("with _agent_lock:\n _result_messages =") silent_failure_idx = src.find("if not _assistant_added and not _token_sent:") inner_lock_idx = src.find("with _agent_lock:", outer_lock_idx + 1) compression_idx = src.find("# ── Handle context compression side effects ──") assert outer_lock_idx != -1, "Outer success-path _agent_lock block not found" assert silent_failure_idx != -1, "Silent-failure branch not found" assert compression_idx != -1, "Compression marker not found" assert not ( inner_lock_idx != -1 and silent_failure_idx < inner_lock_idx < compression_idx ), "Silent-failure path must not reacquire _agent_lock inside the outer lock" def test_checkpoint_stop_initialised_before_any_raiseable_code(self): """Static check: `_checkpoint_stop = None` must appear before any code that could raise inside _run_agent_streaming's outer try.""" src = (Path(__file__).parent.parent / "api" / "streaming.py").read_text( encoding="utf-8" ) lines = src.splitlines() try_line = next( i for i, ln in enumerate(lines, 1) if ln.rstrip().endswith("try:") and any( lines[j].strip().startswith("_checkpoint_stop = None") for j in range(max(0, i - 4), i - 1) ) ) # The assignment must precede the `try:` — not sit inside the nested # block where an earlier line could raise before it runs. init_line = next( i for i, ln in enumerate(lines, 1) if "_checkpoint_stop = None" in ln ) assert init_line < try_line, ( f"_checkpoint_stop = None (line {init_line}) must precede the outer " f"try block (line {try_line}) so the finally can safely check it." ) def test_finally_path_when_early_exception_does_not_unbound_error(self): """Mirror the _run_agent_streaming try/finally structure — proves that pre-initialising _checkpoint_stop = None outside any raiseable code keeps the finally safe.""" def mimic_run_agent_streaming(): _checkpoint_stop = None # pre-init (the fix) try: # Anything here could raise — simulate early failure raise ValueError("early failure, e.g. get_session KeyError") _checkpoint_stop = threading.Event() # never reached finally: # The guard the PR added — must not itself raise if _checkpoint_stop is not None: _checkpoint_stop.set() with pytest.raises(ValueError, match="early failure"): mimic_run_agent_streaming() def test_agent_lock_null_guard_in_except_block(self): """The except block must not crash with AttributeError when _agent_lock is None (e.g. when get_session succeeds but _get_session_agent_lock hasn't been called yet, or _get_session_agent_lock itself raised). The code must use a nullcontext fallback rather than unconditionally entering `with _agent_lock:`.""" src = (Path(__file__).parent.parent / "api" / "streaming.py").read_text( encoding="utf-8" ) # Verify contextlib.nullcontext is used as a fallback assert "contextlib.nullcontext()" in src, ( "The except block must guard _agent_lock being None by falling " "back to contextlib.nullcontext() instead of unconditionally " "entering `with _agent_lock:`" ) # Verify the except block uses _lock_ctx (the guarded variable) assert "_lock_ctx" in src, ( "The except block must assign _agent_lock / nullcontext to a " "variable and use it, not enter `with _agent_lock:` directly" ) def test_periodic_checkpoint_uses_agent_lock(self): """The periodic checkpoint thread must hold _agent_lock while saving to prevent concurrent mutation races with other endpoints.""" src = (Path(__file__).parent.parent / "api" / "streaming.py").read_text( encoding="utf-8" ) # Find the _periodic_checkpoint function ckpt_idx = src.find("def _periodic_checkpoint():") assert ckpt_idx != -1, "_periodic_checkpoint function not found" ckpt_block = src[ckpt_idx:ckpt_idx + 600] assert "with _agent_lock:" in ckpt_block, ( "_periodic_checkpoint must hold _agent_lock while calling s.save() " "to prevent race conditions with other session-mutating endpoints" ) def test_background_title_update_rebinds_to_canonical_session_instance(self): """Guard against stale Session object mutation after LLM round-trip. _run_background_title_update must re-bind `s` to SESSIONS.get(session_id, s) under LOCK before deciding whether a manual rename should block the generated title write. """ src = (Path(__file__).parent.parent / "api" / "streaming.py").read_text( encoding="utf-8" ) fn_idx = src.find("def _run_background_title_update(") assert fn_idx != -1, "_run_background_title_update not found" fn_block = src[fn_idx:fn_idx + 3200] assert "with LOCK:" in fn_block, ( "_run_background_title_update must acquire LOCK before rebinding " "to canonical cached session instance" ) assert "s = SESSIONS.get(session_id, s)" in fn_block, ( "_run_background_title_update must rebind to canonical cached " "session instance under LOCK" ) def test_cancel_stream_uses_agent_lock(self): """cancel_stream must hold _agent_lock during session cleanup to prevent races with checkpoint saves and other writers.""" src = (Path(__file__).parent.parent / "api" / "streaming.py").read_text( encoding="utf-8" ) cancel_idx = src.find("def cancel_stream(") assert cancel_idx != -1, "cancel_stream function not found" cancel_block = src[cancel_idx:] # Find the session cleanup section cleanup_idx = cancel_block.find("Session cleanup outside STREAMS_LOCK") assert cleanup_idx != -1, "Session cleanup comment not found in cancel_stream" cleanup_section = cancel_block[cleanup_idx:cleanup_idx + 800] assert "_get_session_agent_lock" in cleanup_section, ( "cancel_stream must acquire _get_session_agent_lock during " "session cleanup to serialise with the checkpoint thread and " "other session-mutating endpoints" ) def test_session_ops_retry_undo_hold_agent_lock(self): """retry_last and undo_last must hold _get_session_agent_lock for the entire read-modify-save cycle.""" src = (Path(__file__).parent.parent / "api" / "session_ops.py").read_text( encoding="utf-8" ) assert "_get_session_agent_lock" in src, ( "session_ops must import _get_session_agent_lock" ) # Both functions must use with _get_session_agent_lock(session_id): for func_name in ("retry_last", "undo_last"): func_idx = src.find(f"def {func_name}(") assert func_idx != -1, f"{func_name} not found in session_ops.py" func_block = src[func_idx:func_idx + 1200] assert "with _get_session_agent_lock" in func_block, ( f"{func_name} must wrap its read-modify-save cycle in " f"with _get_session_agent_lock(session_id)" ) def test_periodic_checkpoint_mutation_race_with_undo_last(self, tmp_path, monkeypatch): """Run _periodic_checkpoint against a session whose messages list is concurrently truncated by undo_last; the on-disk JSON must remain parseable and internally consistent. The simulated checkpoint mirrors production by acquiring _get_session_agent_lock around s.save(), and we assert that every on-disk snapshot's messages list is one of the allowed snapshots (never an interleaving of fields from two different saves). """ session_dir = tmp_path / "sessions_undo_race" session_dir.mkdir() index_file = session_dir / "_index.json" monkeypatch.setattr(models, "SESSION_DIR", session_dir) monkeypatch.setattr(models, "SESSION_INDEX_FILE", index_file) models.SESSIONS.clear() try: s = Session( session_id="race_test", title="Race Test", messages=[ {"role": "user", "content": "first"}, {"role": "assistant", "content": "reply 1"}, {"role": "user", "content": "second"}, {"role": "assistant", "content": "reply 2"}, {"role": "user", "content": "third"}, {"role": "assistant", "content": "reply 3"}, ], ) s.save() models.SESSIONS[s.session_id] = s _checkpoint_stop = threading.Event() _checkpoint_activity = [0] errors = [] # Collect every on-disk messages snapshot observed by the # checkpoint thread so we can assert atomicity after the run. checkpoint_snapshots = [] _lock = threading.Lock() from api.config import _get_session_agent_lock _agent_lock = _get_session_agent_lock("race_test") def _periodic_checkpoint(): last = 0 while not _checkpoint_stop.wait(0.01): try: cur = _checkpoint_activity[0] if cur > last: with _agent_lock: s.save(skip_index=True) # Read back the on-disk JSON to verify atomicity try: snap = json.loads(s.path.read_text()) with _lock: checkpoint_snapshots.append(snap.get("messages")) except Exception: pass last = cur except Exception as e: errors.append(e) t = threading.Thread(target=_periodic_checkpoint, daemon=True) t.start() from api.session_ops import undo_last # Collect the allowed message snapshots (each state the session # is in at a point where a checkpoint might observe it). allowed_message_snapshots = [] # The initial state (before any undo) is a valid checkpoint target. allowed_message_snapshots.append( [dict(m) if isinstance(m, dict) else m for m in s.messages] ) for _ in range(5): _checkpoint_activity[0] += 1 time.sleep(0.02) try: undo_last("race_test") except ValueError: pass # Record the post-undo state (before appending new messages) # as an allowed snapshot — the checkpoint may observe this. allowed_message_snapshots.append( [dict(m) if isinstance(m, dict) else m for m in s.messages] ) # Wrap mutation + save in _agent_lock to mirror production # paths and prevent the checkpoint from observing an # intermediate +1-message snapshot. with _agent_lock: s.messages.append({"role": "user", "content": f"msg-{_}"}) s.messages.append({"role": "assistant", "content": f"ans-{_}"}) # Record the in-memory messages list *before* save so we # can verify that every checkpoint snapshot matches one # of these. allowed_message_snapshots.append( [dict(m) if isinstance(m, dict) else m for m in s.messages] ) s.save() _checkpoint_stop.set() t.join(timeout=2) assert not errors, f"Checkpoint thread encountered errors: {errors}" # Verify the on-disk JSON is parseable data = json.loads(s.path.read_text()) assert data["session_id"] == "race_test" # Messages must be a list (not corrupted by concurrent mutation) assert isinstance(data["messages"], list) # Contract assertion: every checkpoint snapshot's messages must # equal one of the allowed in-memory snapshots, never an # interleaving of fields from two different saves. This assertion # has teeth: if the _agent_lock were removed from the checkpoint # or the undo path, concurrent mutations would produce snapshots # that match no allowed state (e.g. a list with some messages # from before undo and some from after). for snap_msgs in checkpoint_snapshots: if snap_msgs is None: continue # Normalize for comparison (strip display-only metadata) normalized = [ {k: v for k, v in m.items() if k in ("role", "content")} if isinstance(m, dict) else m for m in snap_msgs ] matched = False for allowed in allowed_message_snapshots: norm_allowed = [ {k: v for k, v in m.items() if k in ("role", "content")} if isinstance(m, dict) else m for m in allowed ] if normalized == norm_allowed: matched = True break assert matched, ( f"Checkpoint snapshot {normalized!r} does not match any " f"allowed state — this indicates a serialization failure " f"(the _agent_lock is not preventing interleaved writes)." ) finally: models.SESSIONS.clear() def test_cancel_stream_concurrent_checkpoint_produces_valid_json(self, tmp_path, monkeypatch): """Run cancel_stream while a _periodic_checkpoint thread is concurrently saving the same session; the resulting on-disk JSON must be parseable and active_stream_id must be None. The simulated checkpoint mirrors production by acquiring _get_session_agent_lock around s.save(), and we assert that every on-disk snapshot is internally consistent (never an interleaving of fields from two different saves). """ session_dir = tmp_path / "sessions_cancel_race" session_dir.mkdir() index_file = session_dir / "_index.json" monkeypatch.setattr(models, "SESSION_DIR", session_dir) monkeypatch.setattr(models, "SESSION_INDEX_FILE", index_file) models.SESSIONS.clear() try: s = Session( session_id="cancel_race", title="Cancel Race Test", messages=[ {"role": "user", "content": "hello"}, {"role": "assistant", "content": "world"}, ], active_stream_id="stream-abc", ) s.save() models.SESSIONS[s.session_id] = s _checkpoint_stop = threading.Event() _checkpoint_activity = [0] errors = [] # Collect every on-disk snapshot observed by the checkpoint thread. checkpoint_snapshots = [] _snap_lock = threading.Lock() from api.config import _get_session_agent_lock _agent_lock = _get_session_agent_lock("cancel_race") def _periodic_checkpoint(): last = 0 while not _checkpoint_stop.wait(0.01): try: cur = _checkpoint_activity[0] if cur > last: with _agent_lock: s.save(skip_index=True) # Read back the on-disk JSON to verify atomicity try: snap = json.loads(s.path.read_text()) with _snap_lock: checkpoint_snapshots.append(snap) except Exception: pass last = cur except Exception as e: errors.append(e) t = threading.Thread(target=_periodic_checkpoint, daemon=True) t.start() # Simulate cancel_stream session cleanup directly for i in range(10): _checkpoint_activity[0] += 1 time.sleep(0.01) with _get_session_agent_lock("cancel_race"): s.active_stream_id = None s.pending_user_message = None s.pending_attachments = [] s.pending_started_at = None s.save() _checkpoint_stop.set() t.join(timeout=2) assert not errors, f"Checkpoint thread encountered errors: {errors}" data = json.loads(s.path.read_text()) assert data["session_id"] == "cancel_race" assert data["active_stream_id"] is None, ( "active_stream_id must be None after cancel cleanup" ) assert isinstance(data["messages"], list) # Contract assertion: every checkpoint snapshot must be # internally consistent (no interleaving of fields from two # different saves). Because both the cancel cleanup and the # checkpoint hold the same _agent_lock, they are serialized — # but ordering is nondeterministic, so a snapshot taken # *before* cancel will see active_stream_id="stream-abc" and # one taken *after* will see None. The guarantee is that # each snapshot is self-consistent, never a partial mix. # # This assertion has teeth: if the _agent_lock were removed # from either the checkpoint or the cancel path, a snapshot # could see active_stream_id=None while pending_user_message # still holds the pre-cancel value — a partial state that # violates the atomicity contract. for snap in checkpoint_snapshots: assert isinstance(snap.get("messages"), list), ( "Checkpoint snapshot messages must be a list" ) assert snap.get("active_stream_id") in ("stream-abc", None), ( "Checkpoint snapshot active_stream_id must be either " "the initial value or None (serialized, not interleaved), " f"got {snap.get('active_stream_id')!r}" ) # When active_stream_id is None, the cancel cleanup must # have run — so all four cancel fields must be cleared # atomically. A partial state (e.g. active_stream_id=None # but pending_user_message still set) would indicate a # serialization failure. if snap.get("active_stream_id") is None: assert snap.get("pending_user_message") is None, ( "Snapshot with active_stream_id=None must also have " "pending_user_message=None (atomic cancel cleanup " "under _agent_lock)" ) assert snap.get("pending_attachments") == [] or snap.get("pending_attachments") is None, ( "Snapshot with active_stream_id=None must also have " "empty pending_attachments (atomic cancel cleanup " "under _agent_lock)" ) assert snap.get("pending_started_at") is None, ( "Snapshot with active_stream_id=None must also have " "pending_started_at=None (atomic cancel cleanup " "under _agent_lock)" ) finally: models.SESSIONS.clear() def test_lock_identity_preserved_after_session_id_rotation(self): """When compression rotates session_id, the per-session lock must be aliased so that _get_session_agent_lock(new_sid) returns the *same* Lock object as _get_session_agent_lock(old_sid). This is a static guard: it directly simulates the migration that streaming.py performs inside the compression rotation block. """ from api.config import ( _get_session_agent_lock, SESSION_AGENT_LOCKS, SESSION_AGENT_LOCKS_LOCK, ) old_sid = "pre-rotation-id" new_sid = "post-rotation-id" # Acquire the lock under the old ID old_lock = _get_session_agent_lock(old_sid) # Simulate the migration that streaming.py does during compression: # alias new_sid → held _agent_lock reference, then pop old_sid. _agent_lock = old_lock with SESSION_AGENT_LOCKS_LOCK: SESSION_AGENT_LOCKS[new_sid] = _agent_lock SESSION_AGENT_LOCKS.pop(old_sid, None) # Now looking up the new ID must return the exact same Lock object new_lock = _get_session_agent_lock(new_sid) assert new_lock is old_lock, ( f"After rotation, _get_session_agent_lock({new_sid!r}) must " f"return the same Lock object as _get_session_agent_lock({old_sid!r}); " f"got {new_lock!r} vs {old_lock!r}" ) # The old ID entry must no longer exist (it was popped) with SESSION_AGENT_LOCKS_LOCK: assert old_sid not in SESSION_AGENT_LOCKS, ( f"Old session ID {old_sid!r} must be removed from " f"SESSION_AGENT_LOCKS after rotation" ) # Cleanup with SESSION_AGENT_LOCKS_LOCK: SESSION_AGENT_LOCKS.pop(new_sid, None) def test_lock_rotation_migration_survives_old_id_already_pruned(self): """Compression lock migration must not require old_sid to exist in dict. A concurrent /api/session/delete can prune old_sid before rotation code runs. The migration must still succeed by assigning the held _agent_lock reference directly. """ from api.config import ( _get_session_agent_lock, SESSION_AGENT_LOCKS, SESSION_AGENT_LOCKS_LOCK, ) old_sid = "pre-rotation-pruned" new_sid = "post-rotation-pruned" _agent_lock = _get_session_agent_lock(old_sid) with SESSION_AGENT_LOCKS_LOCK: SESSION_AGENT_LOCKS.pop(old_sid, None) # simulate concurrent prune # Must not raise KeyError even though old_sid is absent. with SESSION_AGENT_LOCKS_LOCK: SESSION_AGENT_LOCKS[new_sid] = _agent_lock SESSION_AGENT_LOCKS.pop(old_sid, None) new_lock = _get_session_agent_lock(new_sid) assert new_lock is _agent_lock with SESSION_AGENT_LOCKS_LOCK: SESSION_AGENT_LOCKS.pop(new_sid, None)