feat: live context window status tracking during streaming

This commit is contained in:
dobby-d-elf
2026-05-10 06:51:46 -06:00
parent 8a653bac20
commit 1cf0ff01b5
3 changed files with 176 additions and 0 deletions
+153
View File
@@ -1937,6 +1937,97 @@ def _run_agent_streaming(
STREAM_REASONING_TEXT[stream_id] = '' # start accumulating reasoning trace (#1361 §A)
STREAM_LIVE_TOOL_CALLS[stream_id] = [] # start accumulating tool calls (#1361 §B)
agent = None
_live_prompt_estimate_tokens = [0]
_live_prompt_estimate_seen_ids = set()
def _seed_live_prompt_estimate() -> int:
    """Initialise the live prompt estimate from the last known prompt size.

    The value is taken from the running agent's ``context_compressor``
    (``last_prompt_tokens``) when one is available and non-zero; otherwise
    from the stored session object's ``last_prompt_tokens``. Any lookup
    failure is treated as "unknown" (0).

    Returns:
        The seeded (or previously seeded) estimate, as an int.

    NOTE(review): once the estimate is positive it is returned as-is on
    every later call; it is not refreshed from the compressor or session.
    """
    cached = _live_prompt_estimate_tokens[0]
    if cached > 0:
        # Already seeded (or already bumped) -- keep the running value.
        return cached

    baseline = 0
    live_agent = agent
    if live_agent is not None:
        try:
            compressor = getattr(live_agent, 'context_compressor', None)
            if compressor:
                baseline = getattr(compressor, 'last_prompt_tokens', 0) or 0
        except Exception:
            baseline = 0

    if not baseline:
        # Nothing usable on the agent: fall back to the stored session.
        try:
            baseline = getattr(get_session(session_id), 'last_prompt_tokens', 0) or 0
        except Exception:
            baseline = 0

    _live_prompt_estimate_tokens[0] = int(baseline or 0)
    return _live_prompt_estimate_tokens[0]
def _bump_live_prompt_estimate(messages) -> int:
    """Add a rough token count for ``messages`` to the live prompt estimate.

    Args:
        messages: chat-style message dicts to size with
            ``estimate_messages_tokens_rough``. A falsy value is a no-op.

    Returns:
        The live estimate after the (possible) increment.
    """
    added = 0
    if messages:
        try:
            # Imported lazily; if the estimator is unavailable or fails the
            # estimate is simply left unchanged.
            from agent.model_metadata import estimate_messages_tokens_rough
            added = int(estimate_messages_tokens_rough(messages) or 0)
        except Exception:
            added = 0

    if added > 0:
        # Make sure the estimate starts from the last known prompt size
        # before the delta is stacked on top of it.
        _seed_live_prompt_estimate()
        _live_prompt_estimate_tokens[0] += added

    return _live_prompt_estimate_tokens[0]
def _live_usage_snapshot():
    """Build a best-effort usage payload for mid-stream UI updates.

    Values are read from the running agent first (its session token and
    cost counters, plus its ``context_compressor``). Any field that is
    still zero/empty afterwards is filled from the stored session object.
    Finally ``last_prompt_tokens`` is raised to the live prompt estimate
    when that estimate is larger.

    Returns:
        A dict with the keys ``input_tokens``, ``output_tokens``,
        ``estimated_cost``, ``context_length``, ``threshold_tokens`` and
        ``last_prompt_tokens``; every lookup failure leaves the field at 0.
    """
    fields = (
        'input_tokens',
        'output_tokens',
        'estimated_cost',
        'context_length',
        'threshold_tokens',
        'last_prompt_tokens',
    )
    usage = dict.fromkeys(fields, 0)

    try:
        stored_session = get_session(session_id)
    except Exception:
        stored_session = None

    live_agent = agent
    if live_agent is not None:
        # Session-wide counters kept on the agent itself.
        try:
            for key, attr in (
                ('input_tokens', 'session_prompt_tokens'),
                ('output_tokens', 'session_completion_tokens'),
                ('estimated_cost', 'session_estimated_cost_usd'),
            ):
                usage[key] = getattr(live_agent, attr, 0) or 0
        except Exception:
            pass
        # Context-window figures kept on the agent's compressor.
        try:
            compressor = getattr(live_agent, 'context_compressor', None)
            if compressor:
                for key in ('context_length', 'threshold_tokens', 'last_prompt_tokens'):
                    usage[key] = getattr(compressor, key, 0) or 0
        except Exception:
            pass

    if stored_session is not None:
        # Only fields the agent could not supply fall back to the session.
        for key in fields:
            if usage[key]:
                continue
            try:
                usage[key] = getattr(stored_session, key, 0) or 0
            except Exception:
                pass

    estimate = _live_prompt_estimate_tokens[0]
    if estimate > (usage['last_prompt_tokens'] or 0):
        usage['last_prompt_tokens'] = estimate
    return usage
# Register this stream with the global streaming meter
meter().begin_session(stream_id)
@@ -1954,6 +2045,7 @@ def _run_agent_streaming(
break # stream was cancelled or ended — exit
stats = meter().get_stats()
stats['session_id'] = stream_id
stats['usage'] = _live_usage_snapshot()
put('metering', stats)
_metering_thread = threading.Thread(target=_metering_ticker, daemon=True)
@@ -2200,6 +2292,35 @@ def _run_agent_streaming(
# block is reordered later (Issue #765).
_checkpoint_activity = [0]
def _record_live_tool_start(tool_call_id, name, args):
    """Count a newly started tool call towards the live prompt estimate.

    Builds an assistant message carrying a single function tool call and
    passes it to ``_bump_live_prompt_estimate``. Each ``tool_call_id`` is
    counted at most once.

    Args:
        tool_call_id: identifier of the tool call; a falsy id is ignored.
        name: tool name (coerced to ``str``; ``None`` becomes ``''``).
        args: tool arguments. Only a ``dict`` is serialised; any other
            type is treated as empty arguments.
    """
    if not tool_call_id or tool_call_id in _live_prompt_estimate_seen_ids:
        return
    _live_prompt_estimate_seen_ids.add(tool_call_id)

    # The id is already marked as seen, so a serialisation failure here
    # would permanently drop this call from the estimate. The JSON text is
    # only used for sizing, so stringifying unknown values is acceptable.
    try:
        _arguments = json.dumps(
            args if isinstance(args, dict) else {},
            ensure_ascii=False,
            sort_keys=True,
            default=str,
        )
    except (TypeError, ValueError):
        # Unsortable / unsupported keys or circular references.
        _arguments = str(args)

    _tool_call = {
        'id': tool_call_id,
        'type': 'function',
        'function': {
            'name': str(name or ''),
            'arguments': _arguments,
        },
    }
    _bump_live_prompt_estimate([{
        'role': 'assistant',
        'content': '',
        'tool_calls': [_tool_call],
    }])
def _record_live_tool_complete(tool_call_id, name, function_result):
    """Count a finished tool call's result towards the live prompt estimate.

    Args:
        tool_call_id: identifier of the tool call; a falsy id is ignored.
        name: tool name (coerced to ``str``; ``None`` becomes ``''``).
        function_result: raw tool result, reduced to text by
            ``_tool_result_snippet`` before being sized.
    """
    if tool_call_id:
        snippet = _tool_result_snippet(function_result)
        tool_message = {
            'role': 'tool',
            'name': str(name or ''),
            'tool_call_id': tool_call_id,
            'content': snippet,
        }
        _bump_live_prompt_estimate([tool_message])
def on_tool(*cb_args, **cb_kwargs):
nonlocal _reasoning_text
event_type = None
@@ -2255,6 +2376,10 @@ def _run_agent_streaming(
'preview': preview,
'args': args_snap,
})
_tool_stats = meter().get_stats()
_tool_stats['session_id'] = stream_id
_tool_stats['usage'] = _live_usage_snapshot()
put('metering', _tool_stats)
# Fallback: poll for pending approval in case notify_cb wasn't
# registered (e.g. older approval module without gateway support).
try:
@@ -2298,8 +2423,32 @@ def _run_agent_streaming(
'duration': cb_kwargs.get('duration'),
'is_error': bool(cb_kwargs.get('is_error', False)),
})
_tool_stats = meter().get_stats()
_tool_stats['session_id'] = stream_id
_tool_stats['usage'] = _live_usage_snapshot()
put('metering', _tool_stats)
return
def on_tool_start(tool_call_id, name, args):
    """Tool-start callback: update the live estimate and emit metering.

    Records the tool call in the live prompt estimate, then puts a
    'metering' event whose stats carry this stream's id and a fresh usage
    snapshot. Any failure is logged at debug level and swallowed.
    """
    try:
        _record_live_tool_start(tool_call_id, name, args)
        payload = meter().get_stats()
        payload['session_id'] = stream_id
        payload['usage'] = _live_usage_snapshot()
        put('metering', payload)
    except Exception:
        logger.debug('Failed to update live prompt estimate on tool start', exc_info=True)
def on_tool_complete(tool_call_id, name, args, function_result):
    """Tool-complete callback: update the live estimate and emit metering.

    Records the tool result in the live prompt estimate, then puts a
    'metering' event whose stats carry this stream's id and a fresh usage
    snapshot. ``args`` is accepted but not used. Any failure is logged at
    debug level and swallowed.
    """
    try:
        _record_live_tool_complete(tool_call_id, name, function_result)
        payload = meter().get_stats()
        payload['session_id'] = stream_id
        payload['usage'] = _live_usage_snapshot()
        put('metering', payload)
    except Exception:
        logger.debug('Failed to update live prompt estimate on tool completion', exc_info=True)
_AIAgent = _get_ai_agent()
if _AIAgent is None:
raise ImportError(_aiagent_import_error_detail())
@@ -2481,6 +2630,10 @@ def _run_agent_streaming(
_agent_kwargs['reasoning_config'] = _reasoning_config
if 'interim_assistant_callback' in _agent_params:
_agent_kwargs['interim_assistant_callback'] = on_interim_assistant
if 'tool_start_callback' in _agent_params:
_agent_kwargs['tool_start_callback'] = on_tool_start
if 'tool_complete_callback' in _agent_params:
_agent_kwargs['tool_complete_callback'] = on_tool_complete
if 'status_callback' in _agent_params:
_agent_kwargs['status_callback'] = _agent_status_callback
if 'max_iterations' in _agent_params and _max_iterations_cfg is not None: