fix(xai-oauth): quarantine dead tokens on terminal refresh failure

resolve_xai_oauth_runtime_credentials() called _refresh_xai_oauth_tokens()
with no try/except. A terminal refresh failure (HTTP 400/401/403 —
invalid_grant, token revoked) propagated without clearing the dead
access_token / refresh_token from auth.json, causing every subsequent
session to retry the same doomed network request.

Add a try/except around the refresh call that mirrors the existing
credential_pool.py quarantine: when _is_terminal_xai_oauth_refresh_error
identifies a non-retryable failure, clear the dead token fields from
auth.json and write a last_auth_error diagnostic marker so future calls
fail fast with a clear relogin_required error instead of hitting the
network.

active_provider is preserved (set_active=False) so multi-provider users
whose chosen provider is not xai-oauth are unaffected.

Tests: two new cases in test_auth_xai_oauth_provider.py cover terminal
quarantine and transient pass-through.
This commit is contained in:
EloquentBrush0x
2026-05-18 22:53:51 +03:00
committed by Teknium
parent 7321b3c2db
commit b3e714e8b7
2 changed files with 152 additions and 7 deletions
+35 -7
View File
@@ -3546,13 +3546,41 @@ def resolve_xai_oauth_runtime_credentials(
if should_refresh:
if not token_endpoint:
token_endpoint = _xai_oauth_discovery(refresh_timeout_seconds)["token_endpoint"]
tokens = _refresh_xai_oauth_tokens(
tokens,
token_endpoint=token_endpoint,
redirect_uri=redirect_uri,
timeout_seconds=refresh_timeout_seconds,
)
access_token = str(tokens.get("access_token", "") or "").strip()
try:
tokens = _refresh_xai_oauth_tokens(
tokens,
token_endpoint=token_endpoint,
redirect_uri=redirect_uri,
timeout_seconds=refresh_timeout_seconds,
)
access_token = str(tokens.get("access_token", "") or "").strip()
except AuthError as exc:
if _is_terminal_xai_oauth_refresh_error(exc):
# Terminal failure (HTTP 400/401/403 — invalid_grant, token revoked).
# Clear dead tokens from auth.json so subsequent sessions fail fast
# without a network retry. Mirrors credential_pool.py quarantine.
try:
_q_store = _load_auth_store()
_q_state = _load_provider_state(_q_store, "xai-oauth") or {}
_q_tokens = dict(_q_state.get("tokens") or {})
_q_tokens.pop("access_token", None)
_q_tokens.pop("refresh_token", None)
_q_state["tokens"] = _q_tokens
_q_state["last_auth_error"] = {
"provider": "xai-oauth",
"code": exc.code or "xai_refresh_failed",
"message": str(exc),
"reason": "runtime_refresh_failure",
"relogin_required": True,
"at": datetime.now(timezone.utc).isoformat(),
}
_store_provider_state(_q_store, "xai-oauth", _q_state, set_active=False)
_save_auth_store(_q_store)
except Exception as _save_exc:
logger.debug(
"xAI OAuth: failed to persist quarantined state: %s", _save_exc,
)
raise
base_url = (
os.getenv("HERMES_XAI_BASE_URL", "").strip().rstrip("/")
@@ -553,6 +553,123 @@ def test_resolve_xai_runtime_credentials_honours_env_base_url(tmp_path, monkeypa
assert creds["base_url"] == "https://custom.x.ai/v1"
# ---------------------------------------------------------------------------
# Quarantine: terminal refresh failure clears dead tokens (#28155 sibling)
# ---------------------------------------------------------------------------
_STALE_XAI_OAUTH_STATE = {
"tokens": {
"access_token": "dead-access-token",
"refresh_token": "dead-refresh-token",
"id_token": "",
"expires_in": 3600,
"token_type": "Bearer",
},
"discovery": {"token_endpoint": "https://auth.x.ai/oauth2/token"},
"redirect_uri": "http://127.0.0.1:51827/callback",
"last_refresh": "2000-01-01T00:00:00Z",
"auth_mode": "oauth_pkce",
}
def _seed_xai_oauth_state(
    hermes_home: Path, state: dict, *, active_provider: str = "xai-oauth"
) -> None:
    """Write an ``auth.json`` under *hermes_home* holding *state* for xai-oauth.

    The directory is created if it does not exist. The written store has
    ``version`` 1, the given *active_provider*, and a single ``xai-oauth``
    entry under ``providers``.
    """
    hermes_home.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(
        {
            "version": 1,
            "active_provider": active_provider,
            "providers": {"xai-oauth": state},
        },
        indent=2,
    )
    hermes_home.joinpath("auth.json").write_text(serialized)
def test_resolve_credentials_quarantines_dead_tokens_on_terminal_refresh_failure(
    tmp_path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A terminal refresh error must quarantine the stored xai-oauth tokens.

    The refresh helper is stubbed to raise ``AuthError`` with
    ``relogin_required=True`` and ``code="xai_refresh_failed"``. Once that
    error has propagated, auth.json must no longer contain ``access_token``
    or ``refresh_token`` and must carry a ``last_auth_error`` marker, which
    is the singleton/direct-resolve counterpart of the credential_pool.py
    quarantine. The seeded ``active_provider`` ("nous") must not change.
    """
    home_dir = tmp_path / "hermes"
    _seed_xai_oauth_state(home_dir, dict(_STALE_XAI_OAUTH_STATE), active_provider="nous")
    monkeypatch.setenv("HERMES_HOME", str(home_dir))

    def _raise_terminal(tokens, **kwargs):
        # Stand-in for the real refresh call: always fails terminally.
        raise AuthError(
            "xAI token refresh failed. Response: invalid_grant",
            provider="xai-oauth",
            code="xai_refresh_failed",
            relogin_required=True,
        )

    monkeypatch.setattr("hermes_cli.auth._refresh_xai_oauth_tokens", _raise_terminal)

    with pytest.raises(AuthError) as caught:
        resolve_xai_oauth_runtime_credentials(force_refresh=True)

    # The original error must reach the caller unchanged.
    raised = caught.value
    assert raised.code == "xai_refresh_failed"
    assert raised.relogin_required is True

    auth_path = home_dir / "auth.json"
    persisted = json.loads(auth_path.read_text())
    provider_state = persisted["providers"]["xai-oauth"]
    stored_tokens = provider_state["tokens"]

    # Dead OAuth credentials are gone from disk.
    assert "access_token" not in stored_tokens
    assert "refresh_token" not in stored_tokens

    # Metadata that is not a credential survives the quarantine.
    assert stored_tokens.get("token_type") == "Bearer"

    # A structured diagnostic marker has been recorded.
    marker = provider_state.get("last_auth_error")
    assert isinstance(marker, dict)
    assert marker["provider"] == "xai-oauth"
    assert marker["code"] == "xai_refresh_failed"
    assert marker["reason"] == "runtime_refresh_failure"
    assert marker["relogin_required"] is True
    assert "at" in marker

    # Quarantining xai-oauth must not switch the active provider.
    assert persisted["active_provider"] == "nous"
def test_resolve_credentials_does_not_quarantine_on_transient_refresh_failure(
    tmp_path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A transient refresh error must leave the stored tokens alone.

    The refresh helper is stubbed to raise ``AuthError`` with
    ``relogin_required=False`` (the shape used for e.g. 429 / 5xx failures).
    The error must still propagate, but auth.json must keep both tokens and
    must not gain a ``last_auth_error`` entry, so a later attempt can retry.
    """
    home_dir = tmp_path / "hermes"
    _seed_xai_oauth_state(home_dir, dict(_STALE_XAI_OAUTH_STATE))
    monkeypatch.setenv("HERMES_HOME", str(home_dir))

    def _raise_transient(tokens, **kwargs):
        # Stand-in for the real refresh call: always fails, but retryably.
        raise AuthError(
            "xAI token refresh failed: connection error",
            provider="xai-oauth",
            code="xai_refresh_failed",
            relogin_required=False,
        )

    monkeypatch.setattr("hermes_cli.auth._refresh_xai_oauth_tokens", _raise_transient)

    with pytest.raises(AuthError) as caught:
        resolve_xai_oauth_runtime_credentials(force_refresh=True)

    assert caught.value.relogin_required is False

    # Nothing on disk may change: no quarantine for transient errors.
    auth_path = home_dir / "auth.json"
    persisted = json.loads(auth_path.read_text())
    provider_state = persisted["providers"]["xai-oauth"]
    stored_tokens = provider_state["tokens"]
    assert stored_tokens["refresh_token"] == "dead-refresh-token"
    assert stored_tokens["access_token"] == "dead-access-token"
    assert "last_auth_error" not in provider_state
# ---------------------------------------------------------------------------
# Auth status surface
# ---------------------------------------------------------------------------