# Source: hermes-webui/tests/test_stale_stream_writeback.py
# Snapshot: 2026-05-13 10:23:03 +08:00 (116 lines, 4.4 KiB, Python)

import queue
import threading
from pathlib import Path
from unittest.mock import Mock
import pytest
import api.config as config
import api.models as models
import api.streaming as streaming
from api.models import Session
@pytest.fixture(autouse=True)
def _isolate_sessions(tmp_path, monkeypatch):
    """Redirect session persistence to a per-test directory and reset registries.

    Each test gets a fresh ``sessions`` directory under ``tmp_path`` (with an
    ``_index.json`` index path inside it). ``api.models``, ``api.streaming``
    and ``api.config`` are monkeypatched to point at it. The in-memory
    registries (sessions, streams, cancel flags, agent instances, per-session
    agent locks) are emptied both before the test body runs and again at
    teardown.
    """
    sessions_root = tmp_path / "sessions"
    sessions_root.mkdir()
    index_path = sessions_root / "_index.json"

    monkeypatch.setattr(models, "SESSION_DIR", sessions_root)
    monkeypatch.setattr(models, "SESSION_INDEX_FILE", index_path)
    monkeypatch.setattr(streaming, "SESSION_DIR", sessions_root)
    # raising=False: tolerate api.config not already defining this attribute.
    monkeypatch.setattr(config, "SESSION_INDEX_FILE", index_path, raising=False)

    def _reset_registries():
        # Same registries, same order, on both setup and teardown.
        for registry in (
            models.SESSIONS,
            config.STREAMS,
            config.CANCEL_FLAGS,
            config.AGENT_INSTANCES,
            config.SESSION_AGENT_LOCKS,
        ):
            registry.clear()

    _reset_registries()
    yield
    _reset_registries()
def test_stream_writeback_requires_active_stream_ownership():
    """Writeback is permitted only while the caller's stream is the session's active stream."""
    session = Session(session_id="ownership", messages=[])
    cases = (
        ("current-stream", True),   # caller still owns the session
        (None, False),              # no active stream at all
        ("newer-stream", False),    # ownership moved to a different stream
    )
    for active_id, expected in cases:
        session.active_stream_id = active_id
        outcome = streaming._stream_writeback_is_current(session, "current-stream")
        assert outcome is expected
def test_cancel_stream_does_not_append_marker_after_stream_ownership_rotated():
    """Cancelling a stale stream must leave a session owned by a newer stream untouched.

    The session is registered with ``active_stream_id == "newer-stream"`` while
    queue, cancel-flag and agent entries still exist for ``old-stream``.
    Cancelling ``old-stream`` should report success, but the session's active
    stream, pending prompt and message list must be unchanged. In particular,
    no ``*Task cancelled.*`` marker may be appended.
    """
    session_id = "rotated_cancel_sid"
    stale_stream_id = "old-stream"

    session = Session(
        session_id=session_id,
        title="Rotated stream",
        messages=[{"role": "user", "content": "newer prompt"}],
    )
    session.active_stream_id = "newer-stream"
    session.pending_user_message = "newer prompt"
    session.pending_started_at = 456.0
    session.save()
    models.SESSIONS[session_id] = session

    # Runtime bookkeeping exists only for the stale stream; the session itself
    # already points at the newer one.
    config.STREAMS[stale_stream_id] = queue.Queue()
    config.CANCEL_FLAGS[stale_stream_id] = threading.Event()
    stale_agent = Mock(session_id=session_id, interrupt=Mock())
    config.AGENT_INSTANCES[stale_stream_id] = stale_agent

    assert streaming.cancel_stream(stale_stream_id) is True

    assert session.active_stream_id == "newer-stream"
    assert session.pending_user_message == "newer prompt"
    contents = [message["content"] for message in session.messages]
    assert contents == ["newer prompt"]
    assert not any(
        message.get("content") == "*Task cancelled.*" for message in session.messages
    )
def test_success_path_checks_stream_ownership_before_persisting_result():
    """The success path must check stream ownership before writing results back.

    This is a static check over the source of the imported ``api.streaming``
    module. The first occurrence of the ownership guard must precede both the
    ``result.get('messages')`` merge and the context-compression side-effect
    handling.

    The source is read from ``streaming.__file__`` rather than a CWD-relative
    ``api/streaming.py``. That way the test does not depend on pytest being
    launched from the repository root, and it always inspects the module it
    actually imported.
    """
    src = Path(streaming.__file__).read_text(encoding="utf-8")
    guard = "if not ephemeral and not _stream_writeback_is_current(s, stream_id):"
    guard_pos = src.find(guard)
    result_merge_pos = src.find(
        "_result_messages = result.get('messages') or _previous_context_messages"
    )
    compression_pos = src.find("Handle context compression side effects")
    assert guard_pos != -1, "stream ownership guard not found in api/streaming.py"
    assert result_merge_pos != -1, "success-path result merge not found"
    assert compression_pos != -1, "context compression handling not found"
    assert guard_pos < result_merge_pos, "ownership guard must precede result merge"
    assert guard_pos < compression_pos, "ownership guard must precede compression handling"
def test_self_heal_retry_success_checks_stream_ownership_before_writeback():
    """The self-heal retry success path must check ownership before writeback.

    The check is restricted to the source between the "self-heal (except
    path): retrying stream" log call and the "retry succeeded" log call. Within
    that slice, the ownership guard must come before the ``_heal_result``
    message merge and before the first ``s.save()``.

    The source is read from ``streaming.__file__`` so the test is independent
    of the working directory pytest is launched from.
    """
    src = Path(streaming.__file__).read_text(encoding="utf-8")
    start = src.index("logger.info('[webui] self-heal (except path): retrying stream")
    end = src.index("logger.info('[webui] self-heal (except path): retry succeeded')", start)
    block = src[start:end]
    guard = "if not ephemeral and not _stream_writeback_is_current(s, stream_id):"
    assert guard in block, "self-heal retry path lacks the stream ownership guard"
    guard_pos = block.index(guard)
    assert guard_pos < block.index(
        "_result_messages = _heal_result.get('messages') or _previous_context_messages"
    )
    assert guard_pos < block.index("s.save()")
def test_outer_exception_path_checks_stream_ownership_before_error_writeback():
    """The outer exception path must check ownership before persisting the error.

    The check is restricted to the source that follows the outer
    ``_provider_error_payload(...)`` call. It spans from the "Persist the error
    so it survives page reload." comment up to ``put('apperror', ...)``. Within
    that slice, the ownership guard must precede:

    - materialising the pending user turn,
    - clearing ``active_stream_id``,
    - appending the error message.

    The source is read from ``streaming.__file__`` so the test is independent
    of the working directory pytest is launched from.
    """
    src = Path(streaming.__file__).read_text(encoding="utf-8")
    outer_error_payload = src.index(
        "_error_payload = _provider_error_payload(err_str, _exc_type, _exc_hint)"
    )
    start = src.index("# Persist the error so it survives page reload.", outer_error_payload)
    end = src.index("put('apperror', _error_payload)", start)
    block = src[start:end]
    guard = "if not ephemeral and not _stream_writeback_is_current(s, stream_id):"
    assert guard in block, "outer exception path lacks the stream ownership guard"
    guard_pos = block.index(guard)
    assert guard_pos < block.index("_materialize_pending_user_turn_before_error(s)")
    assert guard_pos < block.index("s.active_stream_id = None")
    assert guard_pos < block.index("s.messages.append(_error_message)")