mirror of
https://github.com/nesquena/hermes-webui.git
synced 2026-05-19 13:47:04 +00:00
47c210899e
Co-authored-by: Michael Lam <michael@example.local>
196 lines
7.6 KiB
Python
196 lines
7.6 KiB
Python
import queue
|
|
import sys
|
|
import types
|
|
from unittest import mock
|
|
|
|
# Sentinel for sys.modules save/restore — distinguishes "key wasn't there" from None.
|
|
_MISSING = object()
|
|
|
|
|
|
def test_stream_completion_overwrites_session_usage_with_latest_turn(cleanup_test_sessions):
|
|
"""#1857: completed turns must not add prompt tokens to stale session totals."""
|
|
import api.streaming as streaming
|
|
|
|
saved_snapshots = []
|
|
|
|
class FakeSession:
|
|
def __init__(self):
|
|
self.session_id = "issue1857_usage_overwrite"
|
|
self.title = "Existing title"
|
|
self.workspace = "/tmp"
|
|
self.model = "gpt-5.4"
|
|
self.model_provider = None
|
|
self.profile = None
|
|
self.personality = None
|
|
self.messages = [
|
|
{"role": "user", "content": "old"},
|
|
{"role": "assistant", "content": "old answer"},
|
|
]
|
|
self.context_messages = list(self.messages)
|
|
self.input_tokens = 9000
|
|
self.output_tokens = 800
|
|
self.estimated_cost = 12.34
|
|
self.cache_read_tokens = 1000
|
|
self.cache_write_tokens = 200
|
|
self.tool_calls = []
|
|
self.gateway_routing = None
|
|
self.gateway_routing_history = []
|
|
self.active_stream_id = None
|
|
self.pending_user_message = None
|
|
self.pending_attachments = []
|
|
self.pending_started_at = None
|
|
self.context_length = 0
|
|
self.threshold_tokens = 0
|
|
self.last_prompt_tokens = 0
|
|
self.llm_title_generated = True
|
|
|
|
def save(self, *args, **kwargs):
|
|
saved_snapshots.append(
|
|
{
|
|
"input_tokens": self.input_tokens,
|
|
"output_tokens": self.output_tokens,
|
|
"estimated_cost": self.estimated_cost,
|
|
"cache_read_tokens": self.cache_read_tokens,
|
|
"cache_write_tokens": self.cache_write_tokens,
|
|
"kwargs": kwargs,
|
|
}
|
|
)
|
|
|
|
def compact(self):
|
|
return {
|
|
"session_id": self.session_id,
|
|
"title": self.title,
|
|
"workspace": self.workspace,
|
|
"model": self.model,
|
|
"created_at": 0,
|
|
"updated_at": 0,
|
|
"pinned": False,
|
|
"archived": False,
|
|
"project_id": None,
|
|
"profile": self.profile,
|
|
"input_tokens": self.input_tokens,
|
|
"output_tokens": self.output_tokens,
|
|
"estimated_cost": self.estimated_cost,
|
|
"cache_read_tokens": self.cache_read_tokens,
|
|
"cache_write_tokens": self.cache_write_tokens,
|
|
"personality": self.personality,
|
|
}
|
|
|
|
class UsageAgent:
|
|
def __init__(
|
|
self,
|
|
model=None,
|
|
provider=None,
|
|
base_url=None,
|
|
api_key=None,
|
|
platform=None,
|
|
quiet_mode=False,
|
|
enabled_toolsets=None,
|
|
fallback_model=None,
|
|
session_id=None,
|
|
session_db=None,
|
|
stream_delta_callback=None,
|
|
reasoning_callback=None,
|
|
tool_progress_callback=None,
|
|
clarify_callback=None,
|
|
):
|
|
self.session_id = session_id
|
|
self.context_compressor = None
|
|
self.session_prompt_tokens = 123
|
|
self.session_completion_tokens = 45
|
|
self.session_estimated_cost_usd = 0.067
|
|
self.session_cache_read_tokens = 9000
|
|
self.session_cache_write_tokens = 1000
|
|
self.reasoning_config = None
|
|
self.ephemeral_system_prompt = None
|
|
self._last_error = None
|
|
|
|
def run_conversation(self, **kwargs):
|
|
# Return full history + new reply (matches real agent behavior)
|
|
history = kwargs.get("conversation_history", [])
|
|
return {
|
|
"messages": history + [
|
|
{"role": "user", "content": kwargs["persist_user_message"]},
|
|
{"role": "assistant", "content": "new answer"},
|
|
]
|
|
}
|
|
|
|
def interrupt(self, _message):
|
|
pass
|
|
|
|
fake_session = FakeSession()
|
|
fake_stream_id = "stream_issue1857_usage_overwrite"
|
|
fake_session.active_stream_id = fake_stream_id
|
|
fake_queue = queue.Queue()
|
|
fake_runtime_module = types.ModuleType("hermes_cli.runtime_provider")
|
|
fake_runtime_module.resolve_runtime_provider = mock.Mock(
|
|
return_value={
|
|
"provider": "openai",
|
|
"base_url": None,
|
|
"api_key": "sk-test",
|
|
"api_mode": "chat_completions",
|
|
"command": None,
|
|
"args": [],
|
|
"credential_pool": None,
|
|
}
|
|
)
|
|
fake_hermes_cli = types.ModuleType("hermes_cli")
|
|
fake_hermes_cli.runtime_provider = fake_runtime_module
|
|
fake_hermes_state = types.ModuleType("hermes_state")
|
|
fake_hermes_state.SessionDB = mock.Mock(return_value=None)
|
|
|
|
# NOTE: We deliberately avoid mock.patch.dict(sys.modules, ...) here.
|
|
# patch.dict tracks original keys at __enter__ and on __exit__ DELETES any
|
|
# keys added during the patch that weren't in the original snapshot. That
|
|
# silently evicts lazily-imported submodules (e.g. pydantic.root_model)
|
|
# that other tests rely on, producing KeyError: 'pydantic.root_model' in
|
|
# downstream tests (notably tests/test_mcp_server.py via fastmcp imports).
|
|
# Manual save/restore only touches the three keys we explicitly inject.
|
|
_injected = {
|
|
"hermes_cli": fake_hermes_cli,
|
|
"hermes_cli.runtime_provider": fake_runtime_module,
|
|
"hermes_state": fake_hermes_state,
|
|
}
|
|
_saved = {k: sys.modules.get(k, _MISSING) for k in _injected}
|
|
sys.modules.update(_injected)
|
|
try:
|
|
with mock.patch.object(streaming, "get_session", return_value=fake_session), \
|
|
mock.patch.object(streaming, "_get_ai_agent", return_value=UsageAgent), \
|
|
mock.patch.object(streaming, "resolve_model_provider", return_value=("gpt-5.4", "openai", None)), \
|
|
mock.patch("api.config.get_config", return_value={}), \
|
|
mock.patch("api.config._resolve_cli_toolsets", return_value=[]):
|
|
streaming.STREAMS[fake_stream_id] = fake_queue
|
|
streaming._run_agent_streaming(
|
|
session_id=fake_session.session_id,
|
|
msg_text="new turn",
|
|
model="gpt-5.4",
|
|
workspace="/tmp",
|
|
stream_id=fake_stream_id,
|
|
)
|
|
finally:
|
|
for k, prev in _saved.items():
|
|
if prev is _MISSING:
|
|
sys.modules.pop(k, None)
|
|
else:
|
|
sys.modules[k] = prev
|
|
|
|
assert fake_session.input_tokens == 123
|
|
assert fake_session.output_tokens == 45
|
|
assert fake_session.estimated_cost == 0.067
|
|
assert fake_session.cache_read_tokens == 9000
|
|
assert fake_session.cache_write_tokens == 1000
|
|
assert any(
|
|
event == "done"
|
|
and payload["usage"]["input_tokens"] == 123
|
|
and payload["usage"]["output_tokens"] == 45
|
|
and payload["usage"]["estimated_cost"] == 0.067
|
|
and payload["usage"]["cache_read_tokens"] == 9000
|
|
and payload["usage"]["cache_write_tokens"] == 1000
|
|
for event, payload in list(fake_queue.queue)
|
|
)
|
|
assert saved_snapshots[-1]["input_tokens"] == 123
|
|
assert saved_snapshots[-1]["output_tokens"] == 45
|
|
assert saved_snapshots[-1]["estimated_cost"] == 0.067
|
|
assert saved_snapshots[-1]["cache_read_tokens"] == 9000
|
|
assert saved_snapshots[-1]["cache_write_tokens"] == 1000
|