mirror of
https://github.com/EKKOLearnAI/hermes-web-ui.git
synced 2026-06-03 01:40:17 +00:00
fix(context): prefer provider context lengths (#1184)
This commit is contained in:
@@ -39,6 +39,15 @@ interface CustomProviderEntry {
|
||||
models?: Record<string, { context_length?: number }>
|
||||
}
|
||||
|
||||
type ConfigProviderModels = Record<string, { context_length?: number } | string> | string[]
|
||||
|
||||
interface ConfigProviderEntry {
|
||||
context_length?: number
|
||||
default_model?: string
|
||||
model?: string
|
||||
models?: ConfigProviderModels
|
||||
}
|
||||
|
||||
const MODEL_CACHE_PROVIDER_ALIASES: Record<string, string[]> = {
|
||||
gemini: ['google'],
|
||||
moonshot: ['moonshotai'],
|
||||
@@ -122,6 +131,46 @@ function getConfigContextLength(config: any): number | null {
|
||||
return val
|
||||
}
|
||||
|
||||
function getConfigProvider(config: any, provider: string | null): ConfigProviderEntry | null {
|
||||
if (!provider) return null
|
||||
const providers = config?.providers
|
||||
if (!providers || typeof providers !== 'object') return null
|
||||
const exact = providers[provider]
|
||||
if (exact && typeof exact === 'object') return exact as ConfigProviderEntry
|
||||
const lower = provider.toLowerCase()
|
||||
const match = Object.entries(providers).find(([name]) => name.toLowerCase() === lower)
|
||||
const value = match?.[1]
|
||||
return value && typeof value === 'object' ? value as ConfigProviderEntry : null
|
||||
}
|
||||
|
||||
function getPositiveNumber(value: unknown): number | null {
|
||||
return typeof value === 'number' && Number.isFinite(value) && value > 0 ? value : null
|
||||
}
|
||||
|
||||
function providerHasModel(provider: ConfigProviderEntry, modelName: string): boolean {
|
||||
if (provider.default_model === modelName || provider.model === modelName) return true
|
||||
const models = provider.models
|
||||
if (Array.isArray(models)) return models.includes(modelName)
|
||||
return !!models && typeof models === 'object' && Object.prototype.hasOwnProperty.call(models, modelName)
|
||||
}
|
||||
|
||||
function lookupProviderConfigContextLength(config: any, modelName: string, provider: string | null): number | null {
|
||||
const providerEntry = getConfigProvider(config, provider)
|
||||
if (!providerEntry) return null
|
||||
|
||||
const models = providerEntry.models
|
||||
if (models && !Array.isArray(models) && typeof models === 'object') {
|
||||
const modelEntry = models[modelName]
|
||||
if (modelEntry && typeof modelEntry === 'object') {
|
||||
const modelCtx = getPositiveNumber(modelEntry.context_length)
|
||||
if (modelCtx) return modelCtx
|
||||
}
|
||||
}
|
||||
|
||||
if (!providerHasModel(providerEntry, modelName)) return null
|
||||
return getPositiveNumber(providerEntry.context_length)
|
||||
}
|
||||
|
||||
function normalizeCustomProviderName(name: string): string {
|
||||
return name.trim().toLowerCase().replace(/ /g, '-')
|
||||
}
|
||||
@@ -333,10 +382,13 @@ function lookupContextFromCache(config: any, modelName: string, provider: string
|
||||
/**
|
||||
* Get the context length for the current profile's default model.
|
||||
* Resolution order:
|
||||
* 1. config.yaml model.context_length (highest priority, user override)
|
||||
* 2. custom_providers models.<model>.context_length
|
||||
* 3. models_dev_cache.json, scoped to model.provider when configured
|
||||
* 4. DEFAULT_CONTEXT_LENGTH (200K hardcoded fallback)
|
||||
* 1. model_context database override
|
||||
* 2. provider/model-specific providers.<provider>.models.<model>.context_length
|
||||
* 3. provider-level providers.<provider>.context_length when the model belongs to that provider
|
||||
* 4. custom_providers models.<model>.context_length
|
||||
* 5. top-level model.context_length fallback
|
||||
* 6. models_dev_cache.json, scoped to model.provider when configured
|
||||
* 7. DEFAULT_CONTEXT_LENGTH
|
||||
*/
|
||||
/**
|
||||
* 从数据库 model_context 表查找上下文长度(最高优先级)
|
||||
@@ -375,18 +427,22 @@ export function getModelContextLength(input?: string | ModelContextLengthOptions
|
||||
const dbCtx = lookupContextFromDatabase(model, provider)
|
||||
if (dbCtx && dbCtx > 0) return dbCtx
|
||||
|
||||
// 1. Global context_length override in config.yaml
|
||||
const configCtx = getConfigContextLength(config)
|
||||
if (configCtx && configCtx > 0) return configCtx
|
||||
// 1. Provider-specific context_length in config.yaml
|
||||
const providerConfigCtx = lookupProviderConfigContextLength(config, model, provider)
|
||||
if (providerConfigCtx && providerConfigCtx > 0) return providerConfigCtx
|
||||
|
||||
// 2. Custom provider context_length
|
||||
const customCtx = lookupCustomProviderContextLength(config, model, provider)
|
||||
if (customCtx && customCtx > 0) return customCtx
|
||||
|
||||
// 3. models_dev_cache.json
|
||||
// 3. Global context_length fallback in config.yaml
|
||||
const configCtx = getConfigContextLength(config)
|
||||
if (configCtx && configCtx > 0) return configCtx
|
||||
|
||||
// 4. models_dev_cache.json
|
||||
const cached = lookupContextFromCache(config, model, provider)
|
||||
if (cached) return cached
|
||||
|
||||
// 4. Fallback
|
||||
// 5. Fallback
|
||||
return DEFAULT_CONTEXT_LENGTH
|
||||
}
|
||||
|
||||
@@ -109,6 +109,22 @@ describe('getModelContextLength', () => {
|
||||
expect(getModelContextLength()).toBe(1_050_000)
|
||||
})
|
||||
|
||||
it('prefers requested provider model context_length over top-level default context_length', async () => {
|
||||
writeConfig(`model:\n default: gpt-5.5\n provider: openai-codex\n context_length: 272000\n\nproviders:\n qwen:\n name: Qwen\n default_model: qwen3.6-plus\n models:\n qwen3.6-plus:\n context_length: 1048576\n`)
|
||||
|
||||
const { getModelContextLength } = await loadModelContext()
|
||||
|
||||
expect(getModelContextLength({ provider: 'qwen', model: 'qwen3.6-plus' })).toBe(1_048_576)
|
||||
})
|
||||
|
||||
it('uses provider-level context_length when the requested model belongs to that provider', async () => {
|
||||
writeConfig(`model:\n default: gpt-5.5\n provider: openai-codex\n context_length: 272000\n\nproviders:\n qwen:\n name: Qwen\n default_model: qwen3.6-plus\n models:\n - qwen3.6-plus\n context_length: 1048576\n`)
|
||||
|
||||
const { getModelContextLength } = await loadModelContext()
|
||||
|
||||
expect(getModelContextLength({ provider: 'qwen', model: 'qwen3.6-plus' })).toBe(1_048_576)
|
||||
})
|
||||
|
||||
it('keeps legacy model-name cache lookup when no provider is configured', async () => {
|
||||
writeConfig(`model:\n default: gpt-5.5\n`)
|
||||
writeModelsCache({
|
||||
|
||||
Reference in New Issue
Block a user