""" Tests for auth session lifecycle — session creation, verification, expiry, and lazy pruning of expired entries. """ import time import unittest from pathlib import Path import tempfile import os # Isolate state dir so we don't touch real sessions _TEST_STATE = Path(tempfile.mkdtemp()) os.environ["HERMES_WEBUI_STATE_DIR"] = str(_TEST_STATE) import sys sys.path.insert(0, str(Path(__file__).parent.parent)) import importlib # Force re-import of auth module so it picks up our TEST_STATE_DIR auth = importlib.import_module("api.auth") class TestSessionPruning(unittest.TestCase): """Verify expired session cleanup works correctly.""" def setUp(self): # Clear any leftover sessions from other tests auth._sessions.clear() def test_session_created_valid(self): """A fresh session token should verify as valid.""" token = auth.create_session() self.assertTrue(auth.verify_session(token)) def test_expired_session_pruned(self): """Manually inserting an expired entry should be pruned on next verify_session call.""" # Insert sessions that have already expired auth._sessions["fake_token"] = time.time() - 100 auth._sessions["another_fake"] = time.time() - 50 # Insert one valid session (far future) auth._sessions["good_token"] = time.time() + 3600 # _sessions has 3 entries, 2 expired self.assertEqual(len(auth._sessions), 3) # Call verify_session — this triggers _prune_expired_sessions() # Cookie format is token.signature, so we need a dot to pass the early check auth.verify_session("fake_token.fake_sig") # After verification, only the valid session should remain self.assertEqual(len(auth._sessions), 1) self.assertIn("good_token", auth._sessions) self.assertNotIn("fake_token", auth._sessions) self.assertNotIn("another_fake", auth._sessions) def test_prune_does_not_remove_valid_sessions(self): """_prune_expired_sessions should never remove sessions that are still active.""" auth._sessions["active_1"] = time.time() + 86400 # 24 hours from now auth._sessions["active_2"] = time.time() + 7200 # 2 hours from now auth._sessions["expired_1"] = time.time() - 10 auth._prune_expired_sessions() self.assertEqual(len(auth._sessions), 2) self.assertIn("active_1", auth._sessions) self.assertIn("active_2", auth._sessions) self.assertNotIn("expired_1", auth._sessions) def test_verify_session_prunes_before_verification(self): """verify_session should prune expired entries before checking the target token. This ensures that _prune_expired_sessions() is called at the very top of verify_session(), so cleanup happens on every auth check. """ auth._sessions["expired_for_test"] = time.time() - 999 # verify_session with an invalid cookie triggers the full path: # _prune_expired_sessions -> signature check -> return False result = auth.verify_session("nonexistent.bad_sig") self.assertFalse(result) # The expired entry should have been cleaned up self.assertNotIn("expired_for_test", auth._sessions) def test_prune_handles_empty_dict(self): """_prune_expired_sessions should be safe on an empty dict.""" auth._sessions.clear() auth._prune_expired_sessions() self.assertEqual(len(auth._sessions), 0) def test_session_ttl_is_24_hours(self): """Newly created sessions should have the expected 24-hour TTL.""" auth._sessions.clear() token_hex = auth.create_session().split(".")[0] # The _sessions dict stores token -> expiry_time # We can check the expiry is approximately SESSION_TTL seconds from now # by looking up the raw entry via the token from api.auth import _sessions, SESSION_TTL # find our entry for t, exp in _sessions.items(): if t == token_hex: # expiry should be within 5 seconds of now + SESSION_TTL expected = time.time() + SESSION_TTL self.assertAlmostEqual(exp, expected, delta=5) break else: self.fail("Session token not found in _sessions") class TestSessionInvalidation(unittest.TestCase): """Test session logout / invalidation.""" def setUp(self): auth._sessions.clear() def test_invalidate_session_removes_token(self): """Calling invalidate_session should remove the token from _sessions.""" token = auth.create_session() self.assertTrue(auth.verify_session(token)) auth.invalidate_session(token) # Token should be gone self.assertFalse(auth.verify_session(token)) def test_invalidate_unknown_token_is_safe(self): """Invalidating a non-existent token should not raise.""" auth._sessions.clear() auth.invalidate_session("nonexistent_token") # Should not raise if __name__ == "__main__": unittest.main() class TestSessionTtlResolution(unittest.TestCase): """Verify the three-layer TTL resolution (env > settings > default).""" def setUp(self): # Snapshot environment + load_settings so each test starts clean. self._saved_env = { k: os.environ.get(k) for k in ("HERMES_WEBUI_SESSION_TTL",) } os.environ.pop("HERMES_WEBUI_SESSION_TTL", None) self._saved_load_settings = auth.load_settings def tearDown(self): for k, v in self._saved_env.items(): if v is None: os.environ.pop(k, None) else: os.environ[k] = v auth.load_settings = self._saved_load_settings def test_env_var_overrides_settings(self): """HERMES_WEBUI_SESSION_TTL env var should take priority.""" os.environ["HERMES_WEBUI_SESSION_TTL"] = "3600" from api.auth import _resolve_session_ttl self.assertEqual(_resolve_session_ttl(), 3600) def test_clamps_minimum(self): """Values below 60 seconds fall through to settings/default (do not honor).""" os.environ["HERMES_WEBUI_SESSION_TTL"] = "10" auth.load_settings = lambda: {} from api.auth import _resolve_session_ttl # Out-of-range env values are rejected; falls through to default 30 days. self.assertEqual(_resolve_session_ttl(), auth.SESSION_TTL) def test_clamps_maximum(self): """Values above 1 year fall through to settings/default (do not honor).""" os.environ["HERMES_WEBUI_SESSION_TTL"] = "100000000" auth.load_settings = lambda: {} from api.auth import _resolve_session_ttl # Out-of-range env values are rejected; falls through to default 30 days. self.assertEqual(_resolve_session_ttl(), auth.SESSION_TTL) def test_invalid_env_falls_through(self): """Non-integer env var falls through to default.""" os.environ["HERMES_WEBUI_SESSION_TTL"] = "not-a-number" auth.load_settings = lambda: {} from api.auth import _resolve_session_ttl self.assertEqual(_resolve_session_ttl(), auth.SESSION_TTL) def test_empty_env_falls_through(self): """Empty env var falls through to default.""" os.environ["HERMES_WEBUI_SESSION_TTL"] = "" auth.load_settings = lambda: {} from api.auth import _resolve_session_ttl self.assertEqual(_resolve_session_ttl(), auth.SESSION_TTL) def test_settings_path_returns_value(self): """settings.json session_ttl_seconds path works when env is unset.""" os.environ.pop("HERMES_WEBUI_SESSION_TTL", None) auth.load_settings = lambda: {"session_ttl_seconds": 7200} from api.auth import _resolve_session_ttl self.assertEqual(_resolve_session_ttl(), 7200) def test_session_uses_dynamic_ttl(self): """Newly created sessions should honor the resolved TTL.""" auth._sessions.clear() os.environ["HERMES_WEBUI_SESSION_TTL"] = "3600" token_hex = auth.create_session().split(".")[0] from api.auth import _sessions for t, exp in _sessions.items(): if t == token_hex: # The resolved env-var value (3600s) should be applied, not # the SESSION_TTL fallback default. expected = time.time() + 3600 self.assertAlmostEqual(exp, expected, delta=5) break else: self.fail("Session token not found in _sessions")