From eefa1bbad8a987dc549fdc70d3c0a304558fcd74 Mon Sep 17 00:00:00 2001 From: Frank Song Date: Wed, 29 Apr 2026 07:54:48 +0800 Subject: [PATCH] fix(models): preserve model cache metadata --- api/config.py | 31 +++++++-- tests/test_model_cache_metadata.py | 101 +++++++++++++++++++++++++++++ 2 files changed, 128 insertions(+), 4 deletions(-) create mode 100644 tests/test_model_cache_metadata.py diff --git a/api/config.py b/api/config.py index 29551a3f..d498e4ff 100644 --- a/api/config.py +++ b/api/config.py @@ -1211,15 +1211,29 @@ def _delete_models_cache_on_disk() -> None: pass # already absent +def _is_valid_models_cache(cache: object) -> bool: + """Return True when a disk cache payload has the full /api/models shape.""" + if not isinstance(cache, dict): + return False + if not {"active_provider", "default_model", "groups"}.issubset(cache): + return False + active_provider = cache.get("active_provider") + return ( + (active_provider is None or isinstance(active_provider, str)) + and isinstance(cache.get("default_model"), str) + and isinstance(cache.get("groups"), list) + ) + + def _load_models_cache_from_disk() -> dict | None: - """Load groups dict from disk cache if it exists and is valid.""" + """Load /api/models cache from disk if it exists and has current metadata.""" try: import json as _j if not _models_cache_path.exists(): return None with open(_models_cache_path, encoding="utf-8") as f: cache = _j.load(f) - return cache if isinstance(cache, dict) and "groups" in cache else None + return cache if _is_valid_models_cache(cache) else None except Exception: return None @@ -1227,10 +1241,19 @@ def _load_models_cache_from_disk() -> dict | None: def _save_models_cache_to_disk(cache: dict) -> None: """Save cache to disk so it survives server restarts.""" try: - import time as _cache_time + if not _is_valid_models_cache(cache): + return tmp = str(_models_cache_path) + f".{os.getpid()}.tmp" with open(tmp, "w", encoding="utf-8") as f: - json.dump({"groups": cache.get("groups", [])}, f, indent=2) + json.dump( + { + "active_provider": cache["active_provider"], + "default_model": cache["default_model"], + "groups": cache["groups"], + }, + f, + indent=2, + ) os.rename(tmp, str(_models_cache_path)) except Exception: pass # Non-fatal -- cache will rebuild on next call diff --git a/tests/test_model_cache_metadata.py b/tests/test_model_cache_metadata.py new file mode 100644 index 00000000..2a5d02de --- /dev/null +++ b/tests/test_model_cache_metadata.py @@ -0,0 +1,101 @@ +"""Regression tests for /api/models disk cache metadata.""" + +import json + +import api.config as config + + +def _reset_memory_cache() -> None: + with config._available_models_cache_lock: + config._available_models_cache = None + config._available_models_cache_ts = 0.0 + config._cache_build_in_progress = False + config._cache_build_cv.notify_all() + + +def test_save_models_cache_to_disk_preserves_response_metadata(tmp_path, monkeypatch): + cache_path = tmp_path / "models_cache.json" + monkeypatch.setattr(config, "_models_cache_path", cache_path) + + payload = { + "active_provider": "openai", + "default_model": "gpt-5.4-mini", + "groups": [ + { + "provider": "OpenAI", + "provider_id": "openai", + "models": [{"id": "gpt-5.4-mini", "label": "GPT 5.4 Mini"}], + } + ], + } + + config._save_models_cache_to_disk(payload) + + assert json.loads(cache_path.read_text(encoding="utf-8")) == payload + assert config._load_models_cache_from_disk() == payload + + +def test_load_models_cache_from_disk_rejects_legacy_groups_only_cache(tmp_path, monkeypatch): + cache_path = tmp_path / "models_cache.json" + monkeypatch.setattr(config, "_models_cache_path", cache_path) + cache_path.write_text( + json.dumps( + { + "groups": [ + { + "provider": "Legacy", + "provider_id": "legacy", + "models": [{"id": "legacy-model", "label": "Legacy Model"}], + } + ] + } + ), + encoding="utf-8", + ) + + assert config._load_models_cache_from_disk() is None + + +def test_get_available_models_ignores_legacy_disk_cache_and_rebuilds( + tmp_path, + monkeypatch, +): + cache_path = tmp_path / "models_cache.json" + monkeypatch.setattr(config, "_models_cache_path", cache_path) + cache_path.write_text( + json.dumps( + { + "groups": [ + { + "provider": "Legacy", + "provider_id": "legacy", + "models": [{"id": "legacy-model", "label": "Legacy Model"}], + } + ] + } + ), + encoding="utf-8", + ) + _reset_memory_cache() + + saved_mtime = config._cfg_mtime + try: + try: + config._cfg_mtime = config.Path(config._get_config_path()).stat().st_mtime + except OSError: + config._cfg_mtime = 0.0 + + result = config.get_available_models() + finally: + config._cfg_mtime = saved_mtime + _reset_memory_cache() + + assert "active_provider" in result + assert "default_model" in result + assert "groups" in result + assert not any(group.get("provider") == "Legacy" for group in result["groups"]) + + written = json.loads(cache_path.read_text(encoding="utf-8")) + assert "active_provider" in written + assert "default_model" in written + assert "groups" in written