From 1fbbfdfad64d2144f5ac3de8a38b0cd3824c3871 Mon Sep 17 00:00:00 2001 From: Qiang Han <70218387+h1679242037@users.noreply.github.com> Date: Mon, 1 Jun 2026 09:31:58 +0800 Subject: [PATCH] fix(context): prefer provider context lengths (#1184) --- .../src/services/hermes/model-context.ts | 74 ++++++++++++++++--- tests/server/model-context.test.ts | 16 ++++ 2 files changed, 81 insertions(+), 9 deletions(-) diff --git a/packages/server/src/services/hermes/model-context.ts b/packages/server/src/services/hermes/model-context.ts index eafbfbfa..b76a7867 100644 --- a/packages/server/src/services/hermes/model-context.ts +++ b/packages/server/src/services/hermes/model-context.ts @@ -39,6 +39,15 @@ interface CustomProviderEntry { models?: Record } +type ConfigProviderModels = Record | string[] + +interface ConfigProviderEntry { + context_length?: number + default_model?: string + model?: string + models?: ConfigProviderModels +} + const MODEL_CACHE_PROVIDER_ALIASES: Record = { gemini: ['google'], moonshot: ['moonshotai'], @@ -122,6 +131,46 @@ function getConfigContextLength(config: any): number | null { return val } +function getConfigProvider(config: any, provider: string | null): ConfigProviderEntry | null { + if (!provider) return null + const providers = config?.providers + if (!providers || typeof providers !== 'object') return null + const exact = providers[provider] + if (exact && typeof exact === 'object') return exact as ConfigProviderEntry + const lower = provider.toLowerCase() + const match = Object.entries(providers).find(([name]) => name.toLowerCase() === lower) + const value = match?.[1] + return value && typeof value === 'object' ? value as ConfigProviderEntry : null +} + +function getPositiveNumber(value: unknown): number | null { + return typeof value === 'number' && Number.isFinite(value) && value > 0 ? 
value : null +} + +function providerHasModel(provider: ConfigProviderEntry, modelName: string): boolean { + if (provider.default_model === modelName || provider.model === modelName) return true + const models = provider.models + if (Array.isArray(models)) return models.includes(modelName) + return !!models && typeof models === 'object' && Object.prototype.hasOwnProperty.call(models, modelName) +} + +function lookupProviderConfigContextLength(config: any, modelName: string, provider: string | null): number | null { + const providerEntry = getConfigProvider(config, provider) + if (!providerEntry) return null + + const models = providerEntry.models + if (models && !Array.isArray(models) && typeof models === 'object') { + const modelEntry = models[modelName] + if (modelEntry && typeof modelEntry === 'object') { + const modelCtx = getPositiveNumber(modelEntry.context_length) + if (modelCtx) return modelCtx + } + } + + if (!providerHasModel(providerEntry, modelName)) return null + return getPositiveNumber(providerEntry.context_length) +} + function normalizeCustomProviderName(name: string): string { return name.trim().toLowerCase().replace(/ /g, '-') } @@ -333,10 +382,13 @@ function lookupContextFromCache(config: any, modelName: string, provider: string /** * Get the context length for the current profile's default model. * Resolution order: - * 1. config.yaml model.context_length (highest priority, user override) - * 2. custom_providers models.<model>.context_length - * 3. models_dev_cache.json, scoped to model.provider when configured - * 4. DEFAULT_CONTEXT_LENGTH (200K hardcoded fallback) + * 1. model_context database override + * 2. provider/model-specific providers.<provider>.models.<model>.context_length + * 3. provider-level providers.<provider>.context_length when the model belongs to that provider + * 4. custom_providers models.<model>.context_length + * 5. top-level model.context_length fallback + * 6. models_dev_cache.json, scoped to model.provider when configured + * 7. 
DEFAULT_CONTEXT_LENGTH */ /** * 从数据库 model_context 表查找上下文长度(最高优先级) @@ -375,18 +427,22 @@ export function getModelContextLength(input?: string | ModelContextLengthOptions const dbCtx = lookupContextFromDatabase(model, provider) if (dbCtx && dbCtx > 0) return dbCtx - // 1. Global context_length override in config.yaml - const configCtx = getConfigContextLength(config) - if (configCtx && configCtx > 0) return configCtx + // 1. Provider-specific context_length in config.yaml + const providerConfigCtx = lookupProviderConfigContextLength(config, model, provider) + if (providerConfigCtx && providerConfigCtx > 0) return providerConfigCtx // 2. Custom provider context_length const customCtx = lookupCustomProviderContextLength(config, model, provider) if (customCtx && customCtx > 0) return customCtx - // 3. models_dev_cache.json + // 3. Global context_length fallback in config.yaml + const configCtx = getConfigContextLength(config) + if (configCtx && configCtx > 0) return configCtx + + // 4. models_dev_cache.json const cached = lookupContextFromCache(config, model, provider) if (cached) return cached - // 4. Fallback + // 5. 
Fallback return DEFAULT_CONTEXT_LENGTH } diff --git a/tests/server/model-context.test.ts b/tests/server/model-context.test.ts index fd89fc1a..155d06e9 100644 --- a/tests/server/model-context.test.ts +++ b/tests/server/model-context.test.ts @@ -109,6 +109,22 @@ describe('getModelContextLength', () => { expect(getModelContextLength()).toBe(1_050_000) }) + it('prefers requested provider model context_length over top-level default context_length', async () => { + writeConfig(`model:\n default: gpt-5.5\n provider: openai-codex\n context_length: 272000\n\nproviders:\n qwen:\n name: Qwen\n default_model: qwen3.6-plus\n models:\n qwen3.6-plus:\n context_length: 1048576\n`) + + const { getModelContextLength } = await loadModelContext() + + expect(getModelContextLength({ provider: 'qwen', model: 'qwen3.6-plus' })).toBe(1_048_576) + }) + + it('uses provider-level context_length when the requested model belongs to that provider', async () => { + writeConfig(`model:\n default: gpt-5.5\n provider: openai-codex\n context_length: 272000\n\nproviders:\n qwen:\n name: Qwen\n default_model: qwen3.6-plus\n models:\n - qwen3.6-plus\n context_length: 1048576\n`) + + const { getModelContextLength } = await loadModelContext() + + expect(getModelContextLength({ provider: 'qwen', model: 'qwen3.6-plus' })).toBe(1_048_576) + }) + it('keeps legacy model-name cache lookup when no provider is configured', async () => { writeConfig(`model:\n default: gpt-5.5\n`) writeModelsCache({