diff --git a/gateway/platforms/yuanbao.py b/gateway/platforms/yuanbao.py index 89b2a82942..18d0787c97 100644 --- a/gateway/platforms/yuanbao.py +++ b/gateway/platforms/yuanbao.py @@ -1410,33 +1410,43 @@ class RecallGuardMiddleware(InboundMiddleware): logger.warning("[%s] Recall: failed to resolve session: %s", adapter.name, exc) return - # Load transcript from canonical store (state.db). See Branch A below - # for why we can no longer match by platform `message_id`. + # Load transcript from canonical store (state.db). Since PR #29278 + # added a ``platform_message_id`` column to the messages table and + # ``append_to_transcript`` wires the incoming dict's ``message_id`` + # into it, ``load_transcript`` returns rows with ``message_id`` set + # for any message that was observed with one — Branch A1 (exact id + # match) is the canonical path again. try: transcript = store.load_transcript(sid) except Exception as exc: logger.warning("[%s] Recall: failed to load transcript: %s", adapter.name, exc) return - # Branch A: content-match redaction. state.db does NOT preserve the - # platform `message_id` (only its own autoincrement primary key), so we - # cannot redact by exact id. Match by content instead. Most yuanbao - # recalls carry the recalled text via `recalled_content`, which is - # sufficient for any non-duplicate message. - # - # TODO: add a `platform_message_id` column to state.db messages to - # restore exact-id matching. Tracked separately. + # Branch A1: exact platform message_id match. Authoritative when the + # row was persisted with a platform_message_id (observed group + # messages and any inbound message whose adapter carried a msg_id). target = None - if recalled_content: + branch_label = "" + for entry in transcript: + if entry.get("message_id") == recalled_id: + target = entry + branch_label = "branch A1: id match" + break + # Branch A2: content-match fallback for messages that lack an exact + # platform id on the row — e.g. agent-processed @bot messages + # (run.py doesn't carry msg_id through) or older rows persisted + # before the platform_message_id column existed. + if target is None and recalled_content: for entry in transcript: if entry.get("role") == "user" and entry.get("content") == recalled_content: target = entry + branch_label = "branch A2: content match" break if target is not None: target["content"] = cls._REDACTED try: store.rewrite_transcript(sid, transcript) - logger.info("[%s] Recall: redacted msg_id=%s (branch A: content match)", adapter.name, recalled_id) + logger.info("[%s] Recall: redacted msg_id=%s (%s)", adapter.name, recalled_id, branch_label) except Exception as exc: logger.warning("[%s] Recall: rewrite_transcript failed: %s", adapter.name, exc) return diff --git a/gateway/session.py b/gateway/session.py index 4ad2600c1e..648f8cddf1 100644 --- a/gateway/session.py +++ b/gateway/session.py @@ -1271,6 +1271,12 @@ class SessionStore: reasoning_details=message.get("reasoning_details") if message.get("role") == "assistant" else None, codex_reasoning_items=message.get("codex_reasoning_items") if message.get("role") == "assistant" else None, codex_message_items=message.get("codex_message_items") if message.get("role") == "assistant" else None, + # Platform-side message id (yuanbao msg_id, telegram update_id, …). + # Accept either explicit ``platform_message_id`` or the legacy + # ``message_id`` key the JSONL transcript used. + platform_message_id=( + message.get("platform_message_id") or message.get("message_id") + ), ) except Exception as e: logger.debug("Session DB operation failed: %s", e) diff --git a/hermes_state.py b/hermes_state.py index e8e8947c05..5804437198 100644 --- a/hermes_state.py +++ b/hermes_state.py @@ -33,7 +33,7 @@ T = TypeVar("T") DEFAULT_DB_PATH = get_hermes_home() / "state.db" -SCHEMA_VERSION = 11 +SCHEMA_VERSION = 12 # --------------------------------------------------------------------------- # WAL-compatibility fallback @@ -236,7 +236,8 @@ CREATE TABLE IF NOT EXISTS messages ( reasoning_content TEXT, reasoning_details TEXT, codex_reasoning_items TEXT, - codex_message_items TEXT + codex_message_items TEXT, + platform_message_id TEXT ); CREATE TABLE IF NOT EXISTS state_meta ( @@ -571,6 +572,19 @@ class SessionDB: # column gets created here. self._reconcile_columns(cursor) + # Indexes that reference reconciler-added columns must be created + # AFTER _reconcile_columns runs — declaring them in SCHEMA_SQL + # makes the initial executescript fail on legacy DBs (the index's + # WHERE clause references a column that doesn't exist yet). + try: + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_messages_platform_msg_id " + "ON messages(session_id, platform_message_id) " + "WHERE platform_message_id IS NOT NULL" + ) + except sqlite3.OperationalError as exc: + logger.debug("idx_messages_platform_msg_id create skipped: %s", exc) + # ── Schema version bookkeeping ───────────────────────────────── # Bump to current so future data migrations (if any) can gate on # version. No version-gated column additions remain. @@ -1445,12 +1459,19 @@ class SessionDB: reasoning_details: Any = None, codex_reasoning_items: Any = None, codex_message_items: Any = None, + platform_message_id: str = None, ) -> int: """ Append a message to a session. Returns the message row ID. Also increments the session's message_count (and tool_call_count if role is 'tool' or tool_calls is present). + + ``platform_message_id`` is the external messaging platform's own + message ID (e.g. Telegram update_id, Yuanbao msg_id). It is + independent of the SQLite autoincrement primary key and is used by + platform-specific flows like yuanbao's recall guard to redact a + message by its platform-side identifier. """ # Serialize structured fields to JSON before entering the write txn reasoning_details_json = ( @@ -1480,8 +1501,8 @@ class SessionDB: """INSERT INTO messages (session_id, role, content, tool_call_id, tool_calls, tool_name, timestamp, token_count, finish_reason, reasoning, reasoning_content, reasoning_details, codex_reasoning_items, - codex_message_items) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", + codex_message_items, platform_message_id) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", ( session_id, role, @@ -1497,6 +1518,7 @@ class SessionDB: reasoning_details_json, codex_items_json, codex_message_items_json, + platform_message_id, ), ) msg_id = cursor.lastrowid @@ -1558,13 +1580,18 @@ class SessionDB: json.dumps(codex_message_items) if codex_message_items else None ) tool_calls_json = json.dumps(tool_calls) if tool_calls else None + # Accept either `platform_message_id` (new explicit name) or + # `message_id` (yuanbao's existing convention on message dicts). + platform_msg_id = ( + msg.get("platform_message_id") or msg.get("message_id") + ) conn.execute( """INSERT INTO messages (session_id, role, content, tool_call_id, tool_calls, tool_name, timestamp, token_count, finish_reason, reasoning, reasoning_content, reasoning_details, codex_reasoning_items, - codex_message_items) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", + codex_message_items, platform_message_id) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", ( session_id, role, @@ -1580,6 +1607,7 @@ class SessionDB: reasoning_details_json, codex_items_json, codex_message_items_json, + platform_msg_id, ), ) total_messages += 1 @@ -1897,7 +1925,7 @@ class SessionDB: rows = self._conn.execute( "SELECT role, content, tool_call_id, tool_calls, tool_name, " "finish_reason, reasoning, reasoning_content, reasoning_details, " - "codex_reasoning_items, codex_message_items " + "codex_reasoning_items, codex_message_items, platform_message_id " f"FROM messages WHERE session_id IN ({placeholders}) ORDER BY id", tuple(session_ids), ).fetchall() @@ -1918,6 +1946,13 @@ class SessionDB: except (json.JSONDecodeError, TypeError): logger.warning("Failed to deserialize tool_calls in conversation replay, falling back to []") msg["tool_calls"] = [] + # Surface the platform-side message id (e.g. yuanbao msg_id, + # telegram update_id) so platform-specific flows like recall + # can match by external identifier instead of having to fall + # back to content-match heuristics. Exposed as ``message_id`` + # for backward compatibility with the JSONL transcript shape. + if row["platform_message_id"]: + msg["message_id"] = row["platform_message_id"] # Restore reasoning fields on assistant messages so providers # that replay reasoning (OpenRouter, OpenAI, Nous) receive # coherent multi-turn reasoning context. diff --git a/tests/gateway/platforms/test_yuanbao_recall_db_only.py b/tests/gateway/platforms/test_yuanbao_recall_db_only.py index f54a5f3467..3b8cd6d912 100644 --- a/tests/gateway/platforms/test_yuanbao_recall_db_only.py +++ b/tests/gateway/platforms/test_yuanbao_recall_db_only.py @@ -1,31 +1,88 @@ -"""Yuanbao recall: branch A (content-match) works against DB-only transcripts.""" +"""Yuanbao recall: branch A1 (exact id) and A2 (content-match) against DB-only transcripts. + +state.db persists the platform-side ``message_id`` via the +``platform_message_id`` column (added in the salvage of PR #29211) and +``load_transcript`` surfaces it back on each message dict as ``message_id`` +— so the recall guard's exact-id match path stays canonical even with the +JSONL file gone. When a row has no platform id (e.g. agent-processed +@bot messages whose adapter didn't carry a msg_id, or pre-column legacy +rows), recall falls through to content-match. +""" from gateway.session import SessionStore from gateway.config import GatewayConfig -def test_recall_content_match_finds_target_in_db_transcript(tmp_path, monkeypatch): - """state.db doesn't preserve message_id, so recall uses content-match. - - Pin DEFAULT_DB_PATH to tmp_path so SessionDB() can't write to the real - ~/.hermes/state.db. (Module-level constant snapshot, see test_load_transcript_db_only.) - """ +def _pin_db(monkeypatch, tmp_path): + """Force SessionDB() to write into tmp_path instead of the real ~/.hermes.""" import hermes_state monkeypatch.setattr(hermes_state, "DEFAULT_DB_PATH", tmp_path / "state.db") + +def test_recall_branch_a1_exact_id_match_round_trips_through_db(tmp_path, monkeypatch): + """A user message persisted with ``message_id`` must round-trip through + state.db so recall can find and redact it by exact id (branch A1).""" + _pin_db(monkeypatch, tmp_path) + config = GatewayConfig() store = SessionStore(sessions_dir=tmp_path, config=config) - sid = "test-yuanbao-recall" + sid = "test-yuanbao-recall-a1" store._db.create_session(session_id=sid, source="yuanbao:group:G") - store.append_to_transcript(sid, {"role": "user", "content": "sensitive content", "timestamp": 1.0}) - store.append_to_transcript(sid, {"role": "assistant", "content": "ack", "timestamp": 2.0}) + store.append_to_transcript(sid, { + "role": "user", + "content": "sensitive content", + "timestamp": 1.0, + "message_id": "platform-msg-abc", + }) + store.append_to_transcript(sid, { + "role": "assistant", + "content": "ack", + "timestamp": 2.0, + }) - # DB-only history carries no platform message_id (PR #29211 dropped that path). history = store.load_transcript(sid) - assert all("message_id" not in msg for msg in history) + # The user row must carry its platform id back so the recall guard can + # match by exact id; the assistant row had no platform id so it should + # not gain one spuriously. + user_msg = next(m for m in history if m["role"] == "user") + assistant_msg = next(m for m in history if m["role"] == "assistant") + assert user_msg.get("message_id") == "platform-msg-abc" + assert "message_id" not in assistant_msg - # Branch A: content match finds the target row that recall would redact. - target = next((m for m in history - if m.get("role") == "user" and m.get("content") == "sensitive content"), None) + # Branch A1: locate the row by exact platform id — no content heuristics. + target = next( + (m for m in history if m.get("message_id") == "platform-msg-abc"), + None, + ) + assert target is not None + assert target["content"] == "sensitive content" + + +def test_recall_branch_a2_content_match_when_no_platform_id(tmp_path, monkeypatch): + """Rows that lack a platform_message_id (e.g. agent-processed @bot + messages) still match by content as a fallback.""" + _pin_db(monkeypatch, tmp_path) + + config = GatewayConfig() + store = SessionStore(sessions_dir=tmp_path, config=config) + + sid = "test-yuanbao-recall-a2" + store._db.create_session(session_id=sid, source="yuanbao:group:G") + # No message_id on the dict — simulates an agent-processed message + # that did not carry the platform msg_id through. + store.append_to_transcript(sid, { + "role": "user", + "content": "sensitive content", + "timestamp": 1.0, + }) + + history = store.load_transcript(sid) + assert all("message_id" not in m for m in history) + + # Branch A2: content match recovers the target. + target = next( + (m for m in history + if m.get("role") == "user" and m.get("content") == "sensitive content"), + None, + ) assert target is not None - # Caller would then redact: target["content"] = REDACTED; store.rewrite_transcript(sid, history) diff --git a/tests/test_hermes_state.py b/tests/test_hermes_state.py index 2676457f58..7c3cae7552 100644 --- a/tests/test_hermes_state.py +++ b/tests/test_hermes_state.py @@ -316,6 +316,42 @@ class TestMessageStorage: assert conv[0] == {"role": "user", "content": "Hello"} assert conv[1] == {"role": "assistant", "content": "Hi!"} + def test_platform_message_id_round_trips(self, db): + """Platform-side message ids (yuanbao msg_id, telegram update_id, …) + survive append → get_messages_as_conversation under the + ``message_id`` key so platform recall flows can match by exact id.""" + db.create_session(session_id="s_pmi", source="yuanbao") + db.append_message( + "s_pmi", + role="user", + content="hi", + platform_message_id="abc-123", + ) + db.append_message("s_pmi", role="assistant", content="hello") + + conv = db.get_messages_as_conversation("s_pmi") + user_msg = next(m for m in conv if m["role"] == "user") + assistant_msg = next(m for m in conv if m["role"] == "assistant") + assert user_msg.get("message_id") == "abc-123" + # Assistant row had no platform id — must not gain one spuriously. + assert "message_id" not in assistant_msg + + def test_replace_messages_preserves_platform_message_id(self, db): + """``rewrite_transcript`` (which goes through replace_messages) must + keep the platform_message_id round-trip working for /retry, /undo, + /compress and yuanbao's recall rewrite path.""" + db.create_session(session_id="s_rep", source="yuanbao") + db.replace_messages( + "s_rep", + [ + {"role": "user", "content": "x", "message_id": "ext-1"}, + {"role": "assistant", "content": "y"}, + ], + ) + conv = db.get_messages_as_conversation("s_rep") + assert next(m for m in conv if m["role"] == "user").get("message_id") == "ext-1" + assert "message_id" not in next(m for m in conv if m["role"] == "assistant") + def test_get_messages_as_conversation_includes_ancestor_chain(self, db): db.create_session("root", "tui") db.append_message("root", role="user", content="first prompt") @@ -1462,9 +1498,10 @@ class TestSchemaInit: assert "schema_version" in tables def test_schema_version(self, db): + from hermes_state import SCHEMA_VERSION cursor = db._conn.execute("SELECT version FROM schema_version") version = cursor.fetchone()[0] - assert version == 11 + assert version == SCHEMA_VERSION def test_title_column_exists(self, db): """Verify the title column was created in the sessions table.""" @@ -1760,8 +1797,9 @@ class TestSchemaInit: migrated_db = SessionDB(db_path=db_path) # Verify migration + from hermes_state import SCHEMA_VERSION cursor = migrated_db._conn.execute("SELECT version FROM schema_version") - assert cursor.fetchone()[0] == 11 + assert cursor.fetchone()[0] == SCHEMA_VERSION # Verify title column exists and is NULL for existing sessions session = migrated_db.get_session("existing") @@ -2952,11 +2990,12 @@ class TestFTS5ToolCallMigration: assert len(session_db.search_messages("LEGACYARG")) == 1, \ "v11 migration must backfill tool_calls JSON into FTS" # schema_version bumped + from hermes_state import SCHEMA_VERSION row = session_db._conn.execute( "SELECT version FROM schema_version LIMIT 1" ).fetchone() version = row["version"] if hasattr(row, "keys") else row[0] - assert version == 11 + assert version == SCHEMA_VERSION finally: session_db.close()