mirror of
https://github.com/nesquena/hermes-webui.git
synced 2026-05-24 18:50:15 +00:00
3d96dc1498
Co-authored-by: nesquena-hermes <nesquena-hermes@users.noreply.github.com> Co-authored-by: nesquena <nesquena@users.noreply.github.com>
348 lines
15 KiB
Python
348 lines
15 KiB
Python
"""Tests for real /steer functionality (follow-up to PR #1062).
|
|
|
|
Covers the new POST /api/chat/steer endpoint which mirrors the CLI's /steer
|
|
command (cli.py:6140-6155): the endpoint looks up the cached AIAgent for the
|
|
session, calls agent.steer(text), and the agent's run loop appends the steer
|
|
text to the next tool-result message — no interruption.
|
|
|
|
Falls back to {"accepted": false, "fallback": "<reason>"} when the agent
|
|
isn't running, isn't cached, or doesn't support steer (older agent versions).
|
|
The frontend uses the fallback signal to drop back to interrupt mode.
|
|
|
|
Plus a leftover-delivery flow: if the agent finishes its turn before the
|
|
steer is consumed (no tool-call boundary), _drain_pending_steer is called
|
|
after run_conversation returns and a `pending_steer_leftover` SSE event is
|
|
emitted so the frontend can queue the leftover text as a next-turn message.
|
|
"""
|
|
import os
import sys
import unittest
from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest

# Put the directory above this test file on sys.path so the project's `api`
# package (imported lazily inside the tests) can be resolved.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def _restore_auth_sessions():
    """Snapshot and restore api.auth._sessions — see test_1058 for the rationale.

    Autouse, so it wraps every test in this module. The mapping is restored
    in place (clear + update) rather than rebound, so the module-level object
    keeps its identity.
    """
    import api.auth as _auth

    # Shallow copy taken before the test body runs.
    snapshot = dict(_auth._sessions)
    yield
    # Teardown: drop whatever the test added, then put the old entries back.
    _auth._sessions.clear()
    _auth._sessions.update(snapshot)
|
|
|
|
|
|
@pytest.fixture
def _clear_caches():
    """Snapshot SESSION_AGENT_CACHE and STREAMS so tests don't bleed.

    Both registries are emptied before the test body runs and reset to their
    pre-test contents afterwards. Every read or write of a registry happens
    while holding its matching lock.
    """
    from api.config import SESSION_AGENT_CACHE, SESSION_AGENT_CACHE_LOCK, STREAMS, STREAMS_LOCK

    # Setup: remember the current contents, then start the test from empty.
    with SESSION_AGENT_CACHE_LOCK:
        cache_snap = dict(SESSION_AGENT_CACHE)
        SESSION_AGENT_CACHE.clear()
    with STREAMS_LOCK:
        streams_snap = dict(STREAMS)
        STREAMS.clear()

    yield

    # Teardown: discard anything the test registered and restore the snapshot
    # in place so the shared dict objects are never rebound.
    with SESSION_AGENT_CACHE_LOCK:
        SESSION_AGENT_CACHE.clear()
        SESSION_AGENT_CACHE.update(cache_snap)
    with STREAMS_LOCK:
        STREAMS.clear()
        STREAMS.update(streams_snap)
|
|
|
|
|
|
def _make_handler():
    """Minimal handler stub matching the methods api.helpers.j() touches.

    Returns:
        A MagicMock whose ``wfile`` is its own MagicMock (so writes are
        recorded) and whose ``headers.get`` returns ``""`` for any header.
    """
    h = MagicMock()
    h.wfile = MagicMock()
    h.headers = MagicMock()
    # Every header lookup yields an empty string, whatever name/default is passed.
    h.headers.get = MagicMock(return_value="")
    return h
|
|
|
|
|
|
def _captured_response(handler):
    """Pull the JSON body that j() wrote to handler.wfile.

    Args:
        handler: Stub whose ``wfile.write`` records its calls in
            ``call_args_list`` (e.g. the object built by ``_make_handler``).

    Returns:
        The parsed JSON payload of the most recent ``wfile.write`` call.

    Raises:
        AssertionError: If nothing was written to ``handler.wfile``.
    """
    import json as _json

    # j() calls handler.wfile.write(body)
    write_calls = handler.wfile.write.call_args_list
    assert write_calls, "no body was written to handler.wfile"
    # Last recorded call -> positional args -> first argument (the body).
    body = write_calls[-1][0][0]
    # json.loads accepts str as well as UTF-8 bytes/bytearray, so the body is
    # parsed directly instead of being decoded first.
    return _json.loads(body)
|
|
|
|
|
|
def _captured_status(handler):
    """Pull the HTTP status passed to handler.send_response().

    Args:
        handler: Stub whose ``send_response`` records its calls in
            ``call_args_list``.

    Returns:
        The first positional argument of the most recent ``send_response`` call.

    Raises:
        AssertionError: If ``send_response`` was never called.
    """
    calls = handler.send_response.call_args_list
    assert calls, "no status was sent"
    # Last recorded call -> positional args -> first argument (the status).
    return calls[-1][0][0]
|
|
|
|
|
|
# ── Backend: the /api/chat/steer endpoint ─────────────────────────────────
|
|
|
|
class TestHandleChatSteerHappyPath:
    """Endpoint accepts text and calls agent.steer() when all gates pass."""

    def test_accepts_when_agent_cached_and_running(self, _clear_caches):
        """Cached agent + registered stream: steer() is invoked and the call is accepted."""
        # Imported up front so no import runs while STREAMS_LOCK is held.
        import queue as _q

        from api.streaming import _handle_chat_steer
        from api.config import SESSION_AGENT_CACHE, SESSION_AGENT_CACHE_LOCK, STREAMS, STREAMS_LOCK

        sid, stream_id = "sid_happy", "stream_happy"
        agent = MagicMock()
        agent.steer = MagicMock(return_value=True)
        with SESSION_AGENT_CACHE_LOCK:
            SESSION_AGENT_CACHE[sid] = (agent, "sig")
        with STREAMS_LOCK:
            STREAMS[stream_id] = _q.Queue()

        # The session points at the stream registered above.
        sess = MagicMock()
        sess.active_stream_id = stream_id
        with patch("api.streaming.get_session", return_value=sess):
            handler = _make_handler()
            _handle_chat_steer(handler, {"session_id": sid, "text": "Use Python instead"})

        agent.steer.assert_called_once_with("Use Python instead")
        body = _captured_response(handler)
        assert body == {"accepted": True, "fallback": None, "stream_id": stream_id}
|
|
|
|
|
|
class TestHandleChatSteerFallbacks:
    """Each gate that fails returns a structured fallback the frontend can branch on."""

    def test_no_cached_agent(self, _clear_caches):
        """Nothing cached for the session id -> fallback 'no_cached_agent'."""
        from api.streaming import _handle_chat_steer

        handler = _make_handler()
        _handle_chat_steer(handler, {"session_id": "sid_x", "text": "hint"})

        body = _captured_response(handler)
        assert body["accepted"] is False
        assert body["fallback"] == "no_cached_agent"

    def test_agent_lacks_steer_method(self, _clear_caches):
        """Cached agent has no steer attribute -> fallback 'agent_lacks_steer'."""
        from api.streaming import _handle_chat_steer
        from api.config import SESSION_AGENT_CACHE, SESSION_AGENT_CACHE_LOCK

        sid = "sid_old"
        # Older agent without steer() — use spec to suppress MagicMock auto-create
        agent = MagicMock(spec=["interrupt", "run_conversation"])
        with SESSION_AGENT_CACHE_LOCK:
            SESSION_AGENT_CACHE[sid] = (agent, "sig")

        handler = _make_handler()
        _handle_chat_steer(handler, {"session_id": sid, "text": "hint"})

        body = _captured_response(handler)
        assert body["accepted"] is False
        assert body["fallback"] == "agent_lacks_steer"

    def test_session_not_found(self, _clear_caches):
        """get_session raising KeyError -> fallback 'session_not_found', steer untouched."""
        from api.streaming import _handle_chat_steer
        from api.config import SESSION_AGENT_CACHE, SESSION_AGENT_CACHE_LOCK

        sid = "sid_missing"
        agent = MagicMock()
        agent.steer = MagicMock(return_value=True)
        with SESSION_AGENT_CACHE_LOCK:
            SESSION_AGENT_CACHE[sid] = (agent, "sig")

        with patch("api.streaming.get_session", side_effect=KeyError(sid)):
            handler = _make_handler()
            _handle_chat_steer(handler, {"session_id": sid, "text": "hint"})

        body = _captured_response(handler)
        assert body["accepted"] is False
        assert body["fallback"] == "session_not_found"
        agent.steer.assert_not_called()  # never reached the steer call

    def test_session_not_running(self, _clear_caches):
        """Session without an active stream id -> fallback 'not_running'."""
        from api.streaming import _handle_chat_steer
        from api.config import SESSION_AGENT_CACHE, SESSION_AGENT_CACHE_LOCK

        sid = "sid_idle"
        agent = MagicMock()
        agent.steer = MagicMock(return_value=True)
        with SESSION_AGENT_CACHE_LOCK:
            SESSION_AGENT_CACHE[sid] = (agent, "sig")

        sess = MagicMock()
        sess.active_stream_id = None  # idle session
        with patch("api.streaming.get_session", return_value=sess):
            handler = _make_handler()
            _handle_chat_steer(handler, {"session_id": sid, "text": "hint"})

        body = _captured_response(handler)
        assert body["accepted"] is False
        assert body["fallback"] == "not_running"
        agent.steer.assert_not_called()

    def test_stream_dead(self, _clear_caches):
        """Session has active_stream_id but the stream is gone from STREAMS (e.g. crashed)."""
        from api.streaming import _handle_chat_steer
        from api.config import SESSION_AGENT_CACHE, SESSION_AGENT_CACHE_LOCK

        sid = "sid_zombie"
        agent = MagicMock()
        agent.steer = MagicMock(return_value=True)
        with SESSION_AGENT_CACHE_LOCK:
            SESSION_AGENT_CACHE[sid] = (agent, "sig")

        # STREAMS was emptied by the _clear_caches fixture and nothing is
        # registered here, so this stream id has no matching entry.
        sess = MagicMock()
        sess.active_stream_id = "stream_zombie"
        with patch("api.streaming.get_session", return_value=sess):
            handler = _make_handler()
            _handle_chat_steer(handler, {"session_id": sid, "text": "hint"})

        body = _captured_response(handler)
        assert body["accepted"] is False
        assert body["fallback"] == "stream_dead"
        agent.steer.assert_not_called()

    def test_steer_raises(self, _clear_caches):
        """If agent.steer() raises, return steer_error rather than 500."""
        # Imported up front so no import runs while STREAMS_LOCK is held.
        import queue as _q

        from api.streaming import _handle_chat_steer
        from api.config import SESSION_AGENT_CACHE, SESSION_AGENT_CACHE_LOCK, STREAMS, STREAMS_LOCK

        sid, stream_id = "sid_throws", "stream_throws"
        agent = MagicMock()
        agent.steer = MagicMock(side_effect=RuntimeError("boom"))
        with SESSION_AGENT_CACHE_LOCK:
            SESSION_AGENT_CACHE[sid] = (agent, "sig")
        with STREAMS_LOCK:
            STREAMS[stream_id] = _q.Queue()

        sess = MagicMock()
        sess.active_stream_id = stream_id
        with patch("api.streaming.get_session", return_value=sess):
            handler = _make_handler()
            _handle_chat_steer(handler, {"session_id": sid, "text": "hint"})

        body = _captured_response(handler)
        assert body["accepted"] is False
        assert body["fallback"] == "steer_error"
|
|
|
|
|
|
class TestHandleChatSteerInputValidation:
    """Bad input → 400 Bad Request, not silent acceptance."""

    def test_missing_session_id(self, _clear_caches):
        """Payload without 'session_id' is answered with status 400."""
        from api.streaming import _handle_chat_steer

        handler = _make_handler()
        _handle_chat_steer(handler, {"text": "hint"})
        assert _captured_status(handler) == 400

    def test_missing_text(self, _clear_caches):
        """Payload without 'text' is answered with status 400."""
        from api.streaming import _handle_chat_steer

        handler = _make_handler()
        _handle_chat_steer(handler, {"session_id": "sid"})
        assert _captured_status(handler) == 400

    def test_empty_text_after_strip(self, _clear_caches):
        """Whitespace-only 'text' is answered with status 400."""
        from api.streaming import _handle_chat_steer

        handler = _make_handler()
        _handle_chat_steer(handler, {"session_id": "sid", "text": "   \n\t  "})
        assert _captured_status(handler) == 400
|
|
|
|
|
|
# ── Routing ───────────────────────────────────────────────────────────────
|
|
|
|
class TestRouting:
    """The POST handler must dispatch /api/chat/steer to _handle_chat_steer."""

    def test_route_registered(self):
        """api/routes.py mentions both the route path and the handler name.

        This is a source-text check only; it does not exercise the dispatcher.
        """
        routes_path = Path(__file__).parent.parent / "api" / "routes.py"
        src = routes_path.read_text(encoding="utf-8")
        assert '/api/chat/steer' in src, "routes.py must reference /api/chat/steer"
        assert '_handle_chat_steer' in src, "routes.py must reference _handle_chat_steer"
|
|
|
|
|
|
# ── Frontend: cmdSteer + busy-mode steer use the new endpoint ────────────
|
|
|
|
class TestFrontendWiring:
    """The slash command and busy-mode steer paths must call /api/chat/steer.

    All checks are substring searches over the static JavaScript sources,
    limited to a fixed-size window after the anchor each test looks for.
    """

    @classmethod
    def setup_class(cls):
        """Read the frontend sources once for every test in the class."""
        static_dir = Path(__file__).parent.parent / "static"
        cls.cmds = (static_dir / "commands.js").read_text(encoding="utf-8")
        cls.msgs = (static_dir / "messages.js").read_text(encoding="utf-8")
        cls.i18n = (static_dir / "i18n.js").read_text(encoding="utf-8")

    def test_cmd_steer_calls_endpoint(self):
        """cmdSteer delegates to _trySteer within its first 600 characters."""
        idx = self.cmds.find("async function cmdSteer(")
        assert idx >= 0
        body = self.cmds[idx:idx + 600]
        # Should call _trySteer (which calls the endpoint), not directly cancelStream
        assert "_trySteer" in body, "cmdSteer must delegate to _trySteer"

    def test_try_steer_calls_endpoint(self):
        """_trySteer POSTs to /api/chat/steer."""
        idx = self.cmds.find("async function _trySteer(")
        assert idx >= 0
        body = self.cmds[idx:idx + 1500]
        assert "/api/chat/steer" in body, "_trySteer must POST to /api/chat/steer"
        assert "method:'POST'" in body or 'method:"POST"' in body

    def test_try_steer_handles_fallback(self):
        """_trySteer inspects result.accepted and falls back to queue + cancel."""
        idx = self.cmds.find("async function _trySteer(")
        # Without this guard a missing function gives idx == -1 and the slice
        # below silently starts at the file's last character.
        assert idx >= 0, "commands.js must define async function _trySteer("
        body = self.cmds[idx:idx + 1500]
        # Must check result.accepted and fall back via queueSessionMessage + cancelStream
        assert "result&&result.accepted" in body or "result.accepted" in body
        assert "queueSessionMessage" in body
        assert "cancelStream" in body, "fallback path must cancel the stream"

    def test_send_busy_steer_uses_try_steer(self):
        """The busyMode==='steer' branch in messages.js delegates to _trySteer."""
        # send() in messages.js: when busyMode === 'steer', should call _trySteer
        idx = self.msgs.find("busyMode==='steer'")
        assert idx >= 0
        block = self.msgs[idx:idx + 800]
        assert "_trySteer" in block, "send()'s steer branch must delegate to _trySteer"

    def test_pending_steer_leftover_listener(self):
        """Frontend must listen for pending_steer_leftover SSE events and queue them."""
        idx = self.msgs.find("addEventListener('pending_steer_leftover'")
        assert idx >= 0, "messages.js must add a listener for pending_steer_leftover"
        block = self.msgs[idx:idx + 600]
        assert "queueSessionMessage" in block, (
            "pending_steer_leftover handler must queue the leftover text for the next turn"
        )
|
|
|
|
|
|
# ── i18n keys ─────────────────────────────────────────────────────────────
|
|
|
|
class TestI18nKeys:
    """The two new keys (cmd_steer_delivered, steer_leftover_queued) must be in all 6 locales."""

    @classmethod
    def setup_class(cls):
        """Read static/i18n.js once for both tests."""
        cls.i18n = (Path(__file__).parent.parent / "static" / "i18n.js").read_text(encoding="utf-8")

    def test_cmd_steer_delivered_in_all_locales(self):
        """'cmd_steer_delivered:' occurs at least six times in i18n.js."""
        # Counts raw occurrences of the key text; it does not parse locales.
        occurrences = self.i18n.count("cmd_steer_delivered:")
        assert occurrences >= 6, (
            f"cmd_steer_delivered appears {occurrences} times; "
            "expected ≥6 (one per locale)"
        )

    def test_steer_leftover_queued_in_all_locales(self):
        """'steer_leftover_queued:' occurs at least six times in i18n.js."""
        occurrences = self.i18n.count("steer_leftover_queued:")
        assert occurrences >= 6, (
            f"steer_leftover_queued appears {occurrences} times; "
            "expected ≥6 (one per locale)"
        )
|
|
|
|
|
|
# ── Leftover SSE delivery: streaming.py emits pending_steer_leftover ─────
|
|
|
|
class TestLeftoverDelivery:
    """After run_conversation returns, _drain_pending_steer is called and a
    pending_steer_leftover SSE event is emitted if there's still text stashed.

    Both tests are source-text checks against api/streaming.py; neither runs
    the streaming code.
    """

    def test_leftover_drain_call_in_streaming(self):
        """Verify the streaming.py source contains the drain call before put('done', ...)."""
        src = (Path(__file__).parent.parent / "api" / "streaming.py").read_text(encoding="utf-8")
        assert "_drain_pending_steer" in src, (
            "_run_agent_streaming must call agent._drain_pending_steer() to deliver leftovers"
        )
        assert "pending_steer_leftover" in src, (
            "_run_agent_streaming must emit a pending_steer_leftover SSE event"
        )

    def test_leftover_drain_runs_before_done_event(self):
        """The drain must happen BEFORE put('done', ...) so frontend gets both events
        on the same turn."""
        src = (Path(__file__).parent.parent / "api" / "streaming.py").read_text(encoding="utf-8")
        # Find the drain invocation and the next put('done', ...) AFTER it
        drain_idx = src.find("_drain_pending_steer()")
        assert drain_idx >= 0
        done_idx = src.find("put('done'", drain_idx)
        assert done_idx >= 0
        # No put('done', ...) should appear BEFORE the drain in the same code block
        # (we already check the drain is in the file; ordering matters within the
        # non-ephemeral success path)
        # NOTE(review): because the search for put('done' starts at drain_idx,
        # any hit is necessarily located after the drain, so the comparison
        # below cannot fail once the two assertions above have passed. It does
        # not detect a put('done', ...) that precedes the drain.
        assert drain_idx < done_idx, (
            "_drain_pending_steer must run before put('done', ...) so the SSE listener "
            "sees the leftover before stream_end fires"
        )
|