fix(context): prefer provider context lengths (#1184)

This commit is contained in:
Qiang Han
2026-06-01 09:31:58 +08:00
committed by GitHub
parent 8dbf4c7439
commit 1fbbfdfad6
2 changed files with 81 additions and 9 deletions
@@ -39,6 +39,15 @@ interface CustomProviderEntry {
models?: Record<string, { context_length?: number }>
}
/**
 * Accepted shapes for the `models` field of a provider entry:
 * - an array of model names, or
 * - a map keyed by model name whose value is either an object that may carry
 *   `context_length`, or a plain string (the meaning of string values is not
 *   visible in this part of the file).
 */
type ConfigProviderModels = Record<string, { context_length?: number } | string> | string[]
/** One provider entry as read from the `providers` section of the parsed config. */
interface ConfigProviderEntry {
// Provider-wide context length; used as a fallback for models that belong to this provider.
context_length?: number
// Model name that marks a model as belonging to this provider when it matches exactly.
default_model?: string
// Alternative single-model field; compared against the requested model name the same way.
model?: string
// Per-provider model list or model map (see ConfigProviderModels).
models?: ConfigProviderModels
}
const MODEL_CACHE_PROVIDER_ALIASES: Record<string, string[]> = {
gemini: ['google'],
moonshot: ['moonshotai'],
@@ -122,6 +131,46 @@ function getConfigContextLength(config: any): number | null {
return val
}
/**
 * Resolve the `providers.<name>` entry for a provider from the parsed config.
 *
 * The exact key is tried first; if that does not yield an object, the first
 * key that matches case-insensitively is used instead.
 *
 * @param config Parsed configuration object (shape is not validated here).
 * @param provider Provider name to look up; `null` or empty yields `null`.
 * @returns The provider entry, or `null` when there is no usable object entry.
 */
function getConfigProvider(config: any, provider: string | null): ConfigProviderEntry | null {
  if (!provider) return null
  const providers = config?.providers
  if (!providers || typeof providers !== 'object') return null

  // Only own keys count as configured providers. A plain `providers[provider]`
  // would also resolve inherited members, e.g. `__proto__` -> Object.prototype,
  // which would then be treated as a provider entry.
  if (Object.prototype.hasOwnProperty.call(providers, provider)) {
    const exact = providers[provider]
    if (exact && typeof exact === 'object') return exact as ConfigProviderEntry
  }

  // Fall back to the first key that matches ignoring case.
  const lower = provider.toLowerCase()
  for (const [name, value] of Object.entries(providers)) {
    if (name.toLowerCase() === lower) {
      return value && typeof value === 'object' ? (value as ConfigProviderEntry) : null
    }
  }
  return null
}
/**
 * Narrow an arbitrary value to a finite number greater than zero.
 *
 * @param value Any value, typically read from parsed config.
 * @returns The value itself when it is a finite positive number, otherwise `null`.
 */
function getPositiveNumber(value: unknown): number | null {
  if (typeof value !== 'number') return null
  // Rejects NaN and +/-Infinity.
  if (!Number.isFinite(value)) return null
  return value > 0 ? value : null
}
/**
 * Tell whether a model name belongs to a provider entry.
 *
 * A model belongs to the provider when it equals `default_model` or `model`,
 * is listed in an array-form `models`, or is an own key of a map-form `models`.
 *
 * @param provider Provider entry to inspect.
 * @param modelName Model name to look for (compared exactly).
 * @returns `true` when the provider declares the model, otherwise `false`.
 */
function providerHasModel(provider: ConfigProviderEntry, modelName: string): boolean {
  const { default_model: defaultModel, model: singleModel, models: declared } = provider
  if (defaultModel === modelName) return true
  if (singleModel === modelName) return true

  if (Array.isArray(declared)) {
    return declared.includes(modelName)
  }
  if (!declared || typeof declared !== 'object') return false
  // Own keys only, so inherited names such as `toString` never count as models.
  return Object.prototype.hasOwnProperty.call(declared, modelName)
}
/**
 * Look up a context length from the `providers` section of the config.
 *
 * A model-specific `models.<model>.context_length` is preferred. Otherwise the
 * provider-level `context_length` is used, but only when the model belongs to
 * that provider.
 *
 * @param config Parsed configuration object.
 * @param modelName Model whose context length is wanted.
 * @param provider Provider name to scope the lookup to; `null` yields `null`.
 * @returns A positive context length, or `null` when none is configured.
 */
function lookupProviderConfigContextLength(config: any, modelName: string, provider: string | null): number | null {
  const entry = getConfigProvider(config, provider)
  if (!entry) return null

  const declared = entry.models
  // Only the map form of `models` can carry per-model settings; the array form holds names only.
  if (declared && typeof declared === 'object' && !Array.isArray(declared)) {
    const perModel = declared[modelName]
    const perModelCtx = perModel && typeof perModel === 'object'
      ? getPositiveNumber(perModel.context_length)
      : null
    if (perModelCtx) return perModelCtx
  }

  // The provider-wide value applies only to models the provider actually declares.
  return providerHasModel(entry, modelName) ? getPositiveNumber(entry.context_length) : null
}
/**
 * Normalize a custom provider name: trim surrounding whitespace, lowercase it,
 * and replace every space character with a hyphen.
 *
 * @param name Raw provider name.
 * @returns The normalized name.
 */
function normalizeCustomProviderName(name: string): string {
  const lowered = name.trim().toLowerCase()
  // Only literal spaces are replaced; other whitespace (tabs etc.) is left as-is.
  return lowered.split(' ').join('-')
}
@@ -333,10 +382,13 @@ function lookupContextFromCache(config: any, modelName: string, provider: string
/**
* Get the context length for the current profile's default model.
* Resolution order:
* 1. config.yaml model.context_length (highest priority, user override)
* 2. custom_providers models.<model>.context_length
* 3. models_dev_cache.json, scoped to model.provider when configured
* 4. DEFAULT_CONTEXT_LENGTH (200K hardcoded fallback)
* 1. model_context database override
* 2. provider/model-specific providers.<provider>.models.<model>.context_length
* 3. provider-level providers.<provider>.context_length when the model belongs to that provider
* 4. custom_providers models.<model>.context_length
* 5. top-level model.context_length fallback
* 6. models_dev_cache.json, scoped to model.provider when configured
* 7. DEFAULT_CONTEXT_LENGTH
*/
/**
* Look up the context length from the model_context database table (highest priority)
@@ -375,18 +427,22 @@ export function getModelContextLength(input?: string | ModelContextLengthOptions
const dbCtx = lookupContextFromDatabase(model, provider)
if (dbCtx && dbCtx > 0) return dbCtx
// 1. Global context_length override in config.yaml
const configCtx = getConfigContextLength(config)
if (configCtx && configCtx > 0) return configCtx
// 1. Provider-specific context_length in config.yaml
const providerConfigCtx = lookupProviderConfigContextLength(config, model, provider)
if (providerConfigCtx && providerConfigCtx > 0) return providerConfigCtx
// 2. Custom provider context_length
const customCtx = lookupCustomProviderContextLength(config, model, provider)
if (customCtx && customCtx > 0) return customCtx
// 3. models_dev_cache.json
// 3. Global context_length fallback in config.yaml
const configCtx = getConfigContextLength(config)
if (configCtx && configCtx > 0) return configCtx
// 4. models_dev_cache.json
const cached = lookupContextFromCache(config, model, provider)
if (cached) return cached
// 4. Fallback
// 5. Fallback
return DEFAULT_CONTEXT_LENGTH
}
+16
View File
@@ -109,6 +109,22 @@ describe('getModelContextLength', () => {
expect(getModelContextLength()).toBe(1_050_000)
})
// A model-specific providers.<provider>.models.<model>.context_length must win
// over the top-level model.context_length.
it('prefers requested provider model context_length over top-level default context_length', async () => {
  // YAML nesting is indentation-sensitive, so the fixture is written one line
  // per entry with explicit indentation. With flat indentation `qwen` would
  // parse as an empty key and the provider lookup could not find its models.
  const yaml = [
    'model:',
    '  default: gpt-5.5',
    '  provider: openai-codex',
    '  context_length: 272000',
    '',
    'providers:',
    '  qwen:',
    '    name: Qwen',
    '    default_model: qwen3.6-plus',
    '    models:',
    '      qwen3.6-plus:',
    '        context_length: 1048576',
    '',
  ].join('\n')
  writeConfig(yaml)
  const { getModelContextLength } = await loadModelContext()
  expect(getModelContextLength({ provider: 'qwen', model: 'qwen3.6-plus' })).toBe(1_048_576)
})
// The provider-level providers.<provider>.context_length applies when the
// requested model is listed under that provider (array form of `models`).
it('uses provider-level context_length when the requested model belongs to that provider', async () => {
  // YAML nesting is indentation-sensitive, so the fixture is written one line
  // per entry with explicit indentation. `context_length` sits at the provider
  // level, as a sibling of `models`.
  const yaml = [
    'model:',
    '  default: gpt-5.5',
    '  provider: openai-codex',
    '  context_length: 272000',
    '',
    'providers:',
    '  qwen:',
    '    name: Qwen',
    '    default_model: qwen3.6-plus',
    '    models:',
    '      - qwen3.6-plus',
    '    context_length: 1048576',
    '',
  ].join('\n')
  writeConfig(yaml)
  const { getModelContextLength } = await loadModelContext()
  expect(getModelContextLength({ provider: 'qwen', model: 'qwen3.6-plus' })).toBe(1_048_576)
})
it('keeps legacy model-name cache lookup when no provider is configured', async () => {
writeConfig(`model:\n default: gpt-5.5\n`)
writeModelsCache({