feat(proxy): add xai upstream adapter for Grok via OAuth

This commit is contained in:
yannsunn
2026-05-19 00:23:26 +09:00
committed by Teknium
parent bde6313e34
commit 1d6f3753de
5 changed files with 265 additions and 3 deletions
+1 -1
View File
@@ -10264,7 +10264,7 @@ def main():
proxy_start.add_argument(
"--provider",
default="nous",
help="Upstream provider (default: nous). See `hermes proxy providers`.",
help="Upstream provider: nous or xai (default: nous). See `hermes proxy providers`.",
)
proxy_start.add_argument(
"--host",
+2
View File
@@ -9,11 +9,13 @@ from typing import Dict, Type
from hermes_cli.proxy.adapters.base import UpstreamAdapter
from hermes_cli.proxy.adapters.nous_portal import NousPortalAdapter
from hermes_cli.proxy.adapters.xai import XAIGrokAdapter
# Registry of available adapter classes keyed by provider name as used on
# the ``hermes proxy start --provider <name>`` CLI flag.
ADAPTERS: Dict[str, Type[UpstreamAdapter]] = {
"nous": NousPortalAdapter,
"xai": XAIGrokAdapter,
}
+136
View File
@@ -0,0 +1,136 @@
"""xAI Grok OAuth upstream adapter."""
from __future__ import annotations
import logging
import threading
from typing import FrozenSet, Optional
from agent.credential_pool import CredentialPool, PooledCredential, load_pool
from hermes_cli.auth import DEFAULT_XAI_OAUTH_BASE_URL
from hermes_cli.proxy.adapters.base import UpstreamAdapter, UpstreamCredential
logger = logging.getLogger(__name__)
_POOL_PROVIDER = "xai-oauth"
# xAI's public API is OpenAI-compatible for the endpoints Hermes commonly
# uses. The Responses endpoint is included because Hermes' native xAI runtime
# uses codex_responses mode.
_ALLOWED_PATHS: FrozenSet[str] = frozenset(
{
"/responses",
"/chat/completions",
"/completions",
"/embeddings",
"/models",
}
)
class XAIGrokAdapter(UpstreamAdapter):
"""Proxy upstream for xAI Grok via Hermes-managed OAuth credentials."""
auth_hint = "hermes auth add xai-oauth --type oauth"
def __init__(self) -> None:
self._lock = threading.Lock()
self._pool: Optional[CredentialPool] = None
@property
def name(self) -> str:
return "xai"
@property
def display_name(self) -> str:
return "xAI Grok OAuth"
@property
def allowed_paths(self) -> FrozenSet[str]:
return _ALLOWED_PATHS
def is_authenticated(self) -> bool:
pool = self._load_pool()
return bool(pool and pool.has_available())
def get_credential(self) -> UpstreamCredential:
with self._lock:
pool = self._load_pool()
if pool is None or not pool.has_credentials():
raise RuntimeError(
"No xAI OAuth credentials found. Run "
"`hermes auth add xai-oauth --type oauth` first."
)
entry = pool.select()
if entry is None:
raise RuntimeError(
"No available xAI OAuth credentials found. Run "
"`hermes auth reset xai-oauth` or re-authenticate with "
"`hermes auth add xai-oauth --type oauth`."
)
self._pool = pool
return self._credential_from_entry(entry)
def get_retry_credential(
self,
*,
failed_credential: UpstreamCredential,
status_code: int,
) -> Optional[UpstreamCredential]:
if status_code != 401:
return None
with self._lock:
pool = self._pool or self._load_pool()
if pool is None:
return None
refreshed = pool.try_refresh_current()
if refreshed is None:
refreshed = pool.mark_exhausted_and_rotate(status_code=status_code)
if refreshed is None:
return None
retry_cred = self._credential_from_entry(refreshed)
if retry_cred.bearer == failed_credential.bearer:
return None
logger.info("proxy: xAI upstream rejected bearer; retrying with refreshed pool credential")
return retry_cred
def _load_pool(self) -> Optional[CredentialPool]:
try:
return load_pool(_POOL_PROVIDER)
except Exception as exc:
logger.warning("proxy: failed to load xAI OAuth credential pool: %s", exc)
return None
def _credential_from_entry(self, entry: PooledCredential) -> UpstreamCredential:
bearer = (
getattr(entry, "runtime_api_key", None)
or getattr(entry, "access_token", "")
or ""
)
bearer = str(bearer).strip()
if not bearer:
raise RuntimeError(
"xAI OAuth credential pool entry did not contain an access token. "
"Re-authenticate with `hermes auth add xai-oauth --type oauth`."
)
base_url = (
getattr(entry, "runtime_base_url", None)
or getattr(entry, "base_url", None)
or DEFAULT_XAI_OAUTH_BASE_URL
)
base_url = str(base_url or DEFAULT_XAI_OAUTH_BASE_URL).strip().rstrip("/")
return UpstreamCredential(
bearer=bearer,
base_url=base_url or DEFAULT_XAI_OAUTH_BASE_URL,
expires_at=getattr(entry, "expires_at", None),
)
__all__ = ["XAIGrokAdapter"]
+3 -2
View File
@@ -44,9 +44,10 @@ def cmd_proxy_start(args: Any) -> int:
return 2
if not adapter.is_authenticated():
auth_hint = getattr(adapter, "auth_hint", f"hermes login {adapter.name}")
print(
f"Not logged into {adapter.display_name}. "
f"Run `hermes login {adapter.name}` first.",
f"Run `{auth_hint}` first.",
file=sys.stderr,
)
return 2
@@ -122,7 +123,7 @@ def cmd_proxy(args: Any) -> int:
"OAuth-authenticated provider credentials to outbound requests.\n"
"\n"
"Subcommands:\n"
" hermes proxy start [--provider nous] [--host 127.0.0.1] [--port 8645]\n"
" hermes proxy start [--provider nous|xai] [--host 127.0.0.1] [--port 8645]\n"
" Run the proxy in the foreground.\n"
" hermes proxy status\n"
" Show which upstream adapters are ready.\n"
+123
View File
@@ -15,6 +15,7 @@ import pytest
from hermes_cli.proxy.adapters import ADAPTERS, get_adapter
from hermes_cli.proxy.adapters.base import UpstreamAdapter, UpstreamCredential
from hermes_cli.proxy.adapters.nous_portal import NousPortalAdapter
from hermes_cli.proxy.adapters.xai import XAIGrokAdapter
# ---------------------------------------------------------------------------
@@ -26,15 +27,26 @@ def test_registry_lists_nous():
assert "nous" in ADAPTERS
def test_registry_lists_xai():
assert "xai" in ADAPTERS
def test_get_adapter_returns_instance():
adapter = get_adapter("nous")
assert isinstance(adapter, NousPortalAdapter)
assert isinstance(adapter, UpstreamAdapter)
def test_get_adapter_returns_xai_instance():
adapter = get_adapter("xai")
assert isinstance(adapter, XAIGrokAdapter)
assert isinstance(adapter, UpstreamAdapter)
def test_get_adapter_case_insensitive():
assert isinstance(get_adapter("NOUS"), NousPortalAdapter)
assert isinstance(get_adapter(" Nous "), NousPortalAdapter)
assert isinstance(get_adapter("XAI"), XAIGrokAdapter)
def test_get_adapter_unknown_provider_raises():
@@ -327,6 +339,117 @@ def test_nous_adapter_concurrent_refresh_serialized(tmp_path, monkeypatch):
assert all(r.startswith("key-") for r in results)
# ---------------------------------------------------------------------------
# XAIGrokAdapter
# ---------------------------------------------------------------------------
def _write_xai_pool_entry(
hermes_home: Path,
*,
access_token: str = "xai-access-token",
refresh_token: str = "xai-refresh-token",
base_url: str = "https://api.x.ai/v1",
source: str = "manual:xai_pkce",
) -> Path:
"""Write an xai-oauth pool entry into a hermetic HERMES_HOME."""
auth_path = hermes_home / "auth.json"
auth_path.write_text(json.dumps({
"version": 1,
"providers": {},
"credential_pool": {
"xai-oauth": [
{
"id": "xai123",
"label": "xai-test",
"auth_type": "oauth",
"priority": 0,
"source": source,
"access_token": access_token,
"refresh_token": refresh_token,
"base_url": base_url,
}
]
},
}))
return auth_path
def test_xai_adapter_metadata():
adapter = XAIGrokAdapter()
assert adapter.name == "xai"
assert adapter.display_name == "xAI Grok OAuth"
assert "/responses" in adapter.allowed_paths
assert "/chat/completions" in adapter.allowed_paths
assert "/models" in adapter.allowed_paths
def test_xai_adapter_not_authenticated_when_no_pool_entry(tmp_path, monkeypatch):
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
(tmp_path / "auth.json").write_text(json.dumps({
"version": 1,
"providers": {},
"credential_pool": {},
}))
assert not XAIGrokAdapter().is_authenticated()
def test_xai_adapter_authenticated_with_pool_entry(tmp_path, monkeypatch):
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
_write_xai_pool_entry(tmp_path)
assert XAIGrokAdapter().is_authenticated()
def test_xai_adapter_get_credential_uses_oauth_pool(tmp_path, monkeypatch):
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
_write_xai_pool_entry(
tmp_path,
access_token="pool-access-token",
base_url="https://api.x.ai/v1/",
)
cred = XAIGrokAdapter().get_credential()
assert cred.bearer == "pool-access-token"
assert cred.base_url == "https://api.x.ai/v1"
assert cred.token_type == "Bearer"
def test_xai_adapter_get_credential_defaults_base_url(tmp_path, monkeypatch):
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
_write_xai_pool_entry(tmp_path, base_url="")
cred = XAIGrokAdapter().get_credential()
assert cred.base_url == "https://api.x.ai/v1"
def test_xai_adapter_retry_refreshes_current_pool_entry(tmp_path, monkeypatch):
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
_write_xai_pool_entry(tmp_path, access_token="old-access-token")
def fake_refresh(access_token, refresh_token, **kwargs):
assert access_token == "old-access-token"
assert refresh_token == "xai-refresh-token"
return {
"access_token": "new-access-token",
"refresh_token": "new-refresh-token",
"last_refresh": "2026-05-19T00:00:00Z",
}
monkeypatch.setattr("hermes_cli.auth.refresh_xai_oauth_pure", fake_refresh)
adapter = XAIGrokAdapter()
failed = adapter.get_credential()
retry = adapter.get_retry_credential(
failed_credential=failed,
status_code=401,
)
assert retry is not None
assert retry.bearer == "new-access-token"
# ---------------------------------------------------------------------------
# Server: path filtering + forwarding
#