fix(config): custom named provider API key resolution in WebUI

- add robust custom provider credential/base_url resolver
- apply fallback in streaming and routes agent init/self-heal paths
- support slug normalization and config fallbacks for custom:* providers
This commit is contained in:
王浩生
2026-05-07 20:25:04 +08:00
committed by nesquena-hermes
parent 6253032b53
commit cdbdc28f5c
3 changed files with 143 additions and 1 deletions
+96
View File
@@ -1596,6 +1596,102 @@ def resolve_model_provider(model_id: str) -> tuple:
return model_id, config_provider, config_base_url
def resolve_custom_provider_connection(provider_id: str) -> tuple[str | None, str | None]:
"""Return (api_key, base_url) for a named ``custom:*`` provider.
Supports ``custom_providers[].api_key`` as either a literal key or
``${ENV_VAR}``, and ``custom_providers[].key_env`` as an env-var hint.
Returns ``(None, None)`` when no named custom provider matches.
"""
pid = str(provider_id or "").strip().lower()
if not pid.startswith("custom:"):
return None, None
def _slugify(value: str) -> str:
s = str(value or "").strip().lower().replace("_", "-").replace(" ", "-")
while "--" in s:
s = s.replace("--", "-")
return s.strip("-")
slug = _slugify(pid.split(":", 1)[1].strip())
if not slug:
return None, None
# Read the live config snapshot to avoid stale module-level cache edge
# cases after profile switches or runtime config edits.
cfg_data = get_config()
def _resolve_key(raw_api_key, raw_key_env) -> str | None:
api_key = None
if raw_api_key is not None:
key_text = str(raw_api_key).strip()
if key_text.startswith("${") and key_text.endswith("}") and len(key_text) > 3:
api_key = os.getenv(key_text[2:-1], "").strip() or None
elif key_text:
api_key = key_text
if not api_key:
key_env = str(raw_key_env or "").strip()
if key_env:
api_key = os.getenv(key_env, "").strip() or None
return api_key
custom_providers = cfg_data.get("custom_providers", [])
if not isinstance(custom_providers, list):
custom_providers = []
for entry in custom_providers:
if not isinstance(entry, dict):
continue
name = str(entry.get("name") or "").strip()
if not name:
continue
entry_slug = _slugify(name)
if entry_slug != slug:
continue
base_url = str(entry.get("base_url") or "").strip() or None
api_key = _resolve_key(entry.get("api_key"), entry.get("key_env"))
return api_key, base_url
# If exactly one custom provider is configured, use it as a pragmatic
# fallback for mismatched slugs (e.g. punctuation differences).
if len(custom_providers) == 1 and isinstance(custom_providers[0], dict):
entry = custom_providers[0]
return (
_resolve_key(entry.get("api_key"), entry.get("key_env")),
str(entry.get("base_url") or "").strip() or None,
)
# Fallbacks for setups that don't use custom_providers names directly.
providers_cfg = cfg_data.get("providers", {})
provider_specific = providers_cfg.get(pid, {}) if isinstance(providers_cfg, dict) else {}
provider_custom = providers_cfg.get("custom", {}) if isinstance(providers_cfg, dict) else {}
model_cfg = cfg_data.get("model", {})
model_provider = str(model_cfg.get("provider") or "").strip().lower() if isinstance(model_cfg, dict) else ""
fallback_base = None
for candidate in (provider_specific, provider_custom, model_cfg):
if isinstance(candidate, dict):
_base = str(candidate.get("base_url") or "").strip()
if _base:
fallback_base = _base
break
fallback_key = None
if isinstance(provider_specific, dict):
fallback_key = _resolve_key(provider_specific.get("api_key"), provider_specific.get("key_env"))
if not fallback_key and isinstance(provider_custom, dict):
fallback_key = _resolve_key(provider_custom.get("api_key"), provider_custom.get("key_env"))
if not fallback_key and isinstance(model_cfg, dict) and model_provider in {"custom", pid, slug}:
fallback_key = _resolve_key(model_cfg.get("api_key"), model_cfg.get("key_env"))
if fallback_key or fallback_base:
return fallback_key, fallback_base or None
return None, None
def model_with_provider_context(model_id: str, model_provider: str | None = None) -> str:
"""Return the model string to pass to ``resolve_model_provider()``.
+24 -1
View File
@@ -6639,7 +6639,10 @@ def _handle_chat_sync(handler, body):
from run_agent import AIAgent
with CHAT_LOCK:
from api.config import resolve_model_provider
from api.config import (
resolve_model_provider,
resolve_custom_provider_connection,
)
_model, _provider, _base_url = resolve_model_provider(
model_with_provider_context(s.model, getattr(s, "model_provider", None))
@@ -6665,6 +6668,12 @@ def _handle_chat_sync(handler, body):
f"[webui] WARNING: resolve_runtime_provider failed: {_e}",
flush=True,
)
if isinstance(_provider, str) and _provider.startswith("custom:"):
_cp_key, _cp_base = resolve_custom_provider_connection(_provider)
if not _api_key and _cp_key:
_api_key = _cp_key
if not _base_url and _cp_base:
_base_url = _cp_base
agent = AIAgent(
model=_model,
provider=_provider,
@@ -7427,6 +7436,13 @@ def _handle_session_compress(handler, body):
except Exception as _e:
logger.warning("resolve_runtime_provider failed for compression: %s", _e)
if isinstance(resolved_provider, str) and resolved_provider.startswith("custom:"):
_cp_key, _cp_base = _cfg.resolve_custom_provider_connection(resolved_provider)
if not resolved_api_key and _cp_key:
resolved_api_key = _cp_key
if not resolved_base_url and _cp_base:
resolved_base_url = _cp_base
if not resolved_api_key:
return bad(handler, "No provider configured -- cannot compress.")
@@ -8041,6 +8057,13 @@ def _handle_handoff_summary(handler, body):
except Exception as _e:
logger.warning("resolve_runtime_provider failed for handoff summary: %s", _e)
if isinstance(resolved_provider, str) and resolved_provider.startswith("custom:"):
_cp_key, _cp_base = _cfg.resolve_custom_provider_connection(resolved_provider)
if not resolved_api_key and _cp_key:
resolved_api_key = _cp_key
if not resolved_base_url and _cp_base:
resolved_base_url = _cp_base
if not resolved_api_key:
summary_text = _fallback_handoff_summary(msgs)
try:
+23
View File
@@ -26,6 +26,7 @@ from api.config import (
_get_session_agent_lock, _set_thread_env, _clear_thread_env,
SESSION_AGENT_LOCKS, SESSION_AGENT_LOCKS_LOCK,
resolve_model_provider,
resolve_custom_provider_connection,
model_with_provider_context,
)
from api.helpers import redact_session_data, _redact_text
@@ -2266,6 +2267,16 @@ def _run_agent_streaming(
except Exception as _e:
print(f"[webui] WARNING: resolve_runtime_provider failed: {_e}", flush=True)
# Named custom providers (custom:slug) may not be resolvable by
# hermes_cli.runtime_provider directly. Fall back to config.yaml
# custom_providers[] so WebUI can pass explicit creds/base_url.
if isinstance(resolved_provider, str) and resolved_provider.startswith("custom:"):
_cp_key, _cp_base = resolve_custom_provider_connection(resolved_provider)
if not resolved_api_key and _cp_key:
resolved_api_key = _cp_key
if not resolved_base_url and _cp_base:
resolved_base_url = _cp_base
# Read per-profile config at call time (not module-level snapshot)
from api.config import get_config as _get_config
_cfg = _get_config()
@@ -2725,6 +2736,12 @@ def _run_agent_streaming(
resolved_provider = _heal_rt.get('provider')
if not resolved_base_url:
resolved_base_url = _heal_rt.get('base_url')
if isinstance(resolved_provider, str) and resolved_provider.startswith('custom:'):
_cp_key, _cp_base = resolve_custom_provider_connection(resolved_provider)
if not resolved_api_key and _cp_key:
resolved_api_key = _cp_key
if not resolved_base_url and _cp_base:
resolved_base_url = _cp_base
# Rebuild agent kwargs and create a fresh agent
_agent_kwargs['api_key'] = resolved_api_key
_agent_kwargs['base_url'] = resolved_base_url
@@ -3284,6 +3301,12 @@ def _run_agent_streaming(
resolved_provider = _heal_rt.get('provider')
if not resolved_base_url:
resolved_base_url = _heal_rt.get('base_url')
if isinstance(resolved_provider, str) and resolved_provider.startswith('custom:'):
_cp_key, _cp_base = resolve_custom_provider_connection(resolved_provider)
if not resolved_api_key and _cp_key:
resolved_api_key = _cp_key
if not resolved_base_url and _cp_base:
resolved_base_url = _cp_base
# Build a fresh agent with the new credentials
_heal_kwargs = dict(_agent_kwargs) if '_agent_kwargs' in dir() else {}
_heal_kwargs['api_key'] = resolved_api_key