From cdbdc28f5cd0074ca6056cd4f74398a215b97a76 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=B5=A9=E7=94=9F?= Date: Thu, 7 May 2026 20:25:04 +0800 Subject: [PATCH] fix(config): custom named provider API key resolution in WebUI - add robust custom provider credential/base_url resolver - apply fallback in streaming and routes agent init/self-heal paths - support slug normalization and config fallbacks for custom:* providers --- api/config.py | 96 ++++++++++++++++++++++++++++++++++++++++++++++++ api/routes.py | 25 ++++++++++++- api/streaming.py | 23 ++++++++++++ 3 files changed, 143 insertions(+), 1 deletion(-) diff --git a/api/config.py b/api/config.py index ef07c0ff..b93d9ba2 100644 --- a/api/config.py +++ b/api/config.py @@ -1596,6 +1596,102 @@ def resolve_model_provider(model_id: str) -> tuple: return model_id, config_provider, config_base_url +def resolve_custom_provider_connection(provider_id: str) -> tuple[str | None, str | None]: + """Return (api_key, base_url) for a named ``custom:*`` provider. + + Supports ``custom_providers[].api_key`` as either a literal key or + ``${ENV_VAR}``, and ``custom_providers[].key_env`` as an env-var hint. + Returns ``(None, None)`` when no named custom provider matches. + """ + pid = str(provider_id or "").strip().lower() + if not pid.startswith("custom:"): + return None, None + + def _slugify(value: str) -> str: + s = str(value or "").strip().lower().replace("_", "-").replace(" ", "-") + while "--" in s: + s = s.replace("--", "-") + return s.strip("-") + + slug = _slugify(pid.split(":", 1)[1].strip()) + if not slug: + return None, None + + # Read the live config snapshot to avoid stale module-level cache edge + # cases after profile switches or runtime config edits. + cfg_data = get_config() + + def _resolve_key(raw_api_key, raw_key_env) -> str | None: + api_key = None + if raw_api_key is not None: + key_text = str(raw_api_key).strip() + if key_text.startswith("${") and key_text.endswith("}") and len(key_text) > 3: + api_key = os.getenv(key_text[2:-1], "").strip() or None + elif key_text: + api_key = key_text + if not api_key: + key_env = str(raw_key_env or "").strip() + if key_env: + api_key = os.getenv(key_env, "").strip() or None + return api_key + + custom_providers = cfg_data.get("custom_providers", []) + if not isinstance(custom_providers, list): + custom_providers = [] + + for entry in custom_providers: + if not isinstance(entry, dict): + continue + name = str(entry.get("name") or "").strip() + if not name: + continue + entry_slug = _slugify(name) + if entry_slug != slug: + continue + + base_url = str(entry.get("base_url") or "").strip() or None + api_key = _resolve_key(entry.get("api_key"), entry.get("key_env")) + return api_key, base_url + + # If exactly one custom provider is configured, use it as a pragmatic + # fallback for mismatched slugs (e.g. punctuation differences). + if len(custom_providers) == 1 and isinstance(custom_providers[0], dict): + entry = custom_providers[0] + return ( + _resolve_key(entry.get("api_key"), entry.get("key_env")), + str(entry.get("base_url") or "").strip() or None, + ) + + # Fallbacks for setups that don't use custom_providers names directly. + providers_cfg = cfg_data.get("providers", {}) + provider_specific = providers_cfg.get(pid, {}) if isinstance(providers_cfg, dict) else {} + provider_custom = providers_cfg.get("custom", {}) if isinstance(providers_cfg, dict) else {} + + model_cfg = cfg_data.get("model", {}) + model_provider = str(model_cfg.get("provider") or "").strip().lower() if isinstance(model_cfg, dict) else "" + + fallback_base = None + for candidate in (provider_specific, provider_custom, model_cfg): + if isinstance(candidate, dict): + _base = str(candidate.get("base_url") or "").strip() + if _base: + fallback_base = _base + break + + fallback_key = None + if isinstance(provider_specific, dict): + fallback_key = _resolve_key(provider_specific.get("api_key"), provider_specific.get("key_env")) + if not fallback_key and isinstance(provider_custom, dict): + fallback_key = _resolve_key(provider_custom.get("api_key"), provider_custom.get("key_env")) + if not fallback_key and isinstance(model_cfg, dict) and model_provider in {"custom", pid, slug}: + fallback_key = _resolve_key(model_cfg.get("api_key"), model_cfg.get("key_env")) + + if fallback_key or fallback_base: + return fallback_key, fallback_base or None + + return None, None + + def model_with_provider_context(model_id: str, model_provider: str | None = None) -> str: """Return the model string to pass to ``resolve_model_provider()``. diff --git a/api/routes.py b/api/routes.py index 96a0048f..e524c9ed 100644 --- a/api/routes.py +++ b/api/routes.py @@ -6639,7 +6639,10 @@ def _handle_chat_sync(handler, body): from run_agent import AIAgent with CHAT_LOCK: - from api.config import resolve_model_provider + from api.config import ( + resolve_model_provider, + resolve_custom_provider_connection, + ) _model, _provider, _base_url = resolve_model_provider( model_with_provider_context(s.model, getattr(s, "model_provider", None)) @@ -6665,6 +6668,12 @@ def _handle_chat_sync(handler, body): f"[webui] WARNING: resolve_runtime_provider failed: {_e}", flush=True, ) + if isinstance(_provider, str) and _provider.startswith("custom:"): + _cp_key, _cp_base = resolve_custom_provider_connection(_provider) + if not _api_key and _cp_key: + _api_key = _cp_key + if not _base_url and _cp_base: + _base_url = _cp_base agent = AIAgent( model=_model, provider=_provider, @@ -7427,6 +7436,13 @@ def _handle_session_compress(handler, body): except Exception as _e: logger.warning("resolve_runtime_provider failed for compression: %s", _e) + if isinstance(resolved_provider, str) and resolved_provider.startswith("custom:"): + _cp_key, _cp_base = _cfg.resolve_custom_provider_connection(resolved_provider) + if not resolved_api_key and _cp_key: + resolved_api_key = _cp_key + if not resolved_base_url and _cp_base: + resolved_base_url = _cp_base + if not resolved_api_key: return bad(handler, "No provider configured -- cannot compress.") @@ -8041,6 +8057,13 @@ def _handle_handoff_summary(handler, body): except Exception as _e: logger.warning("resolve_runtime_provider failed for handoff summary: %s", _e) + if isinstance(resolved_provider, str) and resolved_provider.startswith("custom:"): + _cp_key, _cp_base = _cfg.resolve_custom_provider_connection(resolved_provider) + if not resolved_api_key and _cp_key: + resolved_api_key = _cp_key + if not resolved_base_url and _cp_base: + resolved_base_url = _cp_base + if not resolved_api_key: summary_text = _fallback_handoff_summary(msgs) try: diff --git a/api/streaming.py b/api/streaming.py index f8f63e1e..0829cff0 100644 --- a/api/streaming.py +++ b/api/streaming.py @@ -26,6 +26,7 @@ from api.config import ( _get_session_agent_lock, _set_thread_env, _clear_thread_env, SESSION_AGENT_LOCKS, SESSION_AGENT_LOCKS_LOCK, resolve_model_provider, + resolve_custom_provider_connection, model_with_provider_context, ) from api.helpers import redact_session_data, _redact_text @@ -2266,6 +2267,16 @@ def _run_agent_streaming( except Exception as _e: print(f"[webui] WARNING: resolve_runtime_provider failed: {_e}", flush=True) + # Named custom providers (custom:slug) may not be resolvable by + # hermes_cli.runtime_provider directly. Fall back to config.yaml + # custom_providers[] so WebUI can pass explicit creds/base_url. + if isinstance(resolved_provider, str) and resolved_provider.startswith("custom:"): + _cp_key, _cp_base = resolve_custom_provider_connection(resolved_provider) + if not resolved_api_key and _cp_key: + resolved_api_key = _cp_key + if not resolved_base_url and _cp_base: + resolved_base_url = _cp_base + # Read per-profile config at call time (not module-level snapshot) from api.config import get_config as _get_config _cfg = _get_config() @@ -2725,6 +2736,12 @@ def _run_agent_streaming( resolved_provider = _heal_rt.get('provider') if not resolved_base_url: resolved_base_url = _heal_rt.get('base_url') + if isinstance(resolved_provider, str) and resolved_provider.startswith('custom:'): + _cp_key, _cp_base = resolve_custom_provider_connection(resolved_provider) + if not resolved_api_key and _cp_key: + resolved_api_key = _cp_key + if not resolved_base_url and _cp_base: + resolved_base_url = _cp_base # Rebuild agent kwargs and create a fresh agent _agent_kwargs['api_key'] = resolved_api_key _agent_kwargs['base_url'] = resolved_base_url @@ -3284,6 +3301,12 @@ def _run_agent_streaming( resolved_provider = _heal_rt.get('provider') if not resolved_base_url: resolved_base_url = _heal_rt.get('base_url') + if isinstance(resolved_provider, str) and resolved_provider.startswith('custom:'): + _cp_key, _cp_base = resolve_custom_provider_connection(resolved_provider) + if not resolved_api_key and _cp_key: + resolved_api_key = _cp_key + if not resolved_base_url and _cp_base: + resolved_base_url = _cp_base # Build a fresh agent with the new credentials _heal_kwargs = dict(_agent_kwargs) if '_agent_kwargs' in dir() else {} _heal_kwargs['api_key'] = resolved_api_key