fix: integration hardening for gateway token tracking
Follow-up to 58dbd81 — ensures smooth transition for existing users:
- Backward compat: old session files without last_prompt_tokens
default to 0 via data.get('last_prompt_tokens', 0)
- /compress, /undo, /retry: reset last_prompt_tokens to 0 after
rewriting transcripts (stale token counts would under-report)
- Auto-compression hygiene: reset last_prompt_tokens after rewriting
- update_session: use None sentinel (not 0) as default so callers
can explicitly reset to 0 while normal calls don't clobber
- 6 new tests covering: default value, serialization roundtrip,
old-format migration, set/reset/no-change semantics
- /reset: new SessionEntry naturally gets last_prompt_tokens=0
2942 tests pass.
This commit is contained in:
parent
5eb62ef423
commit
909e048ad4
3 changed files with 128 additions and 2 deletions
|
|
@ -1083,6 +1083,8 @@ class GatewayRunner:
|
||||||
self.session_store.rewrite_transcript(
|
self.session_store.rewrite_transcript(
|
||||||
session_entry.session_id, _compressed
|
session_entry.session_id, _compressed
|
||||||
)
|
)
|
||||||
|
# Reset stored token count — transcript was rewritten
|
||||||
|
session_entry.last_prompt_tokens = 0
|
||||||
history = _compressed
|
history = _compressed
|
||||||
_new_count = len(_compressed)
|
_new_count = len(_compressed)
|
||||||
_new_tokens = estimate_messages_tokens_rough(
|
_new_tokens = estimate_messages_tokens_rough(
|
||||||
|
|
@ -1747,6 +1749,8 @@ class GatewayRunner:
|
||||||
# Truncate history to before the last user message and persist
|
# Truncate history to before the last user message and persist
|
||||||
truncated = history[:last_user_idx]
|
truncated = history[:last_user_idx]
|
||||||
self.session_store.rewrite_transcript(session_entry.session_id, truncated)
|
self.session_store.rewrite_transcript(session_entry.session_id, truncated)
|
||||||
|
# Reset stored token count — transcript was truncated
|
||||||
|
session_entry.last_prompt_tokens = 0
|
||||||
|
|
||||||
# Re-send by creating a fake text event with the old message
|
# Re-send by creating a fake text event with the old message
|
||||||
retry_event = MessageEvent(
|
retry_event = MessageEvent(
|
||||||
|
|
@ -1778,6 +1782,8 @@ class GatewayRunner:
|
||||||
removed_msg = history[last_user_idx].get("content", "")
|
removed_msg = history[last_user_idx].get("content", "")
|
||||||
removed_count = len(history) - last_user_idx
|
removed_count = len(history) - last_user_idx
|
||||||
self.session_store.rewrite_transcript(session_entry.session_id, history[:last_user_idx])
|
self.session_store.rewrite_transcript(session_entry.session_id, history[:last_user_idx])
|
||||||
|
# Reset stored token count — transcript was truncated
|
||||||
|
session_entry.last_prompt_tokens = 0
|
||||||
|
|
||||||
preview = removed_msg[:40] + "..." if len(removed_msg) > 40 else removed_msg
|
preview = removed_msg[:40] + "..." if len(removed_msg) > 40 else removed_msg
|
||||||
return f"↩️ Undid {removed_count} message(s).\nRemoved: \"{preview}\""
|
return f"↩️ Undid {removed_count} message(s).\nRemoved: \"{preview}\""
|
||||||
|
|
@ -1911,6 +1917,10 @@ class GatewayRunner:
|
||||||
)
|
)
|
||||||
|
|
||||||
self.session_store.rewrite_transcript(session_entry.session_id, compressed)
|
self.session_store.rewrite_transcript(session_entry.session_id, compressed)
|
||||||
|
# Reset stored token count — transcript changed, old value is stale
|
||||||
|
self.session_store.update_session(
|
||||||
|
session_entry.session_key, last_prompt_tokens=0,
|
||||||
|
)
|
||||||
new_count = len(compressed)
|
new_count = len(compressed)
|
||||||
new_tokens = estimate_messages_tokens_rough(compressed)
|
new_tokens = estimate_messages_tokens_rough(compressed)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -556,7 +556,7 @@ class SessionStore:
|
||||||
session_key: str,
|
session_key: str,
|
||||||
input_tokens: int = 0,
|
input_tokens: int = 0,
|
||||||
output_tokens: int = 0,
|
output_tokens: int = 0,
|
||||||
last_prompt_tokens: int = 0,
|
last_prompt_tokens: int = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Update a session's metadata after an interaction."""
|
"""Update a session's metadata after an interaction."""
|
||||||
self._ensure_loaded()
|
self._ensure_loaded()
|
||||||
|
|
@ -566,7 +566,7 @@ class SessionStore:
|
||||||
entry.updated_at = datetime.now()
|
entry.updated_at = datetime.now()
|
||||||
entry.input_tokens += input_tokens
|
entry.input_tokens += input_tokens
|
||||||
entry.output_tokens += output_tokens
|
entry.output_tokens += output_tokens
|
||||||
if last_prompt_tokens > 0:
|
if last_prompt_tokens is not None:
|
||||||
entry.last_prompt_tokens = last_prompt_tokens
|
entry.last_prompt_tokens = last_prompt_tokens
|
||||||
entry.total_tokens = entry.input_tokens + entry.output_tokens
|
entry.total_tokens = entry.input_tokens + entry.output_tokens
|
||||||
self._save()
|
self._save()
|
||||||
|
|
|
||||||
|
|
@ -429,3 +429,119 @@ class TestHasAnySessions:
|
||||||
|
|
||||||
store._entries = {"key1": MagicMock()}
|
store._entries = {"key1": MagicMock()}
|
||||||
assert store.has_any_sessions() is False
|
assert store.has_any_sessions() is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestLastPromptTokens:
|
||||||
|
"""Tests for the last_prompt_tokens field — actual API token tracking."""
|
||||||
|
|
||||||
|
def test_session_entry_default(self):
|
||||||
|
"""New sessions should have last_prompt_tokens=0."""
|
||||||
|
from gateway.session import SessionEntry
|
||||||
|
from datetime import datetime
|
||||||
|
entry = SessionEntry(
|
||||||
|
session_key="test",
|
||||||
|
session_id="s1",
|
||||||
|
created_at=datetime.now(),
|
||||||
|
updated_at=datetime.now(),
|
||||||
|
)
|
||||||
|
assert entry.last_prompt_tokens == 0
|
||||||
|
|
||||||
|
def test_session_entry_roundtrip(self):
|
||||||
|
"""last_prompt_tokens should survive serialization/deserialization."""
|
||||||
|
from gateway.session import SessionEntry
|
||||||
|
from datetime import datetime
|
||||||
|
entry = SessionEntry(
|
||||||
|
session_key="test",
|
||||||
|
session_id="s1",
|
||||||
|
created_at=datetime.now(),
|
||||||
|
updated_at=datetime.now(),
|
||||||
|
last_prompt_tokens=42000,
|
||||||
|
)
|
||||||
|
d = entry.to_dict()
|
||||||
|
assert d["last_prompt_tokens"] == 42000
|
||||||
|
restored = SessionEntry.from_dict(d)
|
||||||
|
assert restored.last_prompt_tokens == 42000
|
||||||
|
|
||||||
|
def test_session_entry_from_old_data(self):
|
||||||
|
"""Old session data without last_prompt_tokens should default to 0."""
|
||||||
|
from gateway.session import SessionEntry
|
||||||
|
data = {
|
||||||
|
"session_key": "test",
|
||||||
|
"session_id": "s1",
|
||||||
|
"created_at": "2025-01-01T00:00:00",
|
||||||
|
"updated_at": "2025-01-01T00:00:00",
|
||||||
|
"input_tokens": 100,
|
||||||
|
"output_tokens": 50,
|
||||||
|
"total_tokens": 150,
|
||||||
|
# No last_prompt_tokens — old format
|
||||||
|
}
|
||||||
|
entry = SessionEntry.from_dict(data)
|
||||||
|
assert entry.last_prompt_tokens == 0
|
||||||
|
|
||||||
|
def test_update_session_sets_last_prompt_tokens(self, tmp_path):
|
||||||
|
"""update_session should store the actual prompt token count."""
|
||||||
|
config = GatewayConfig()
|
||||||
|
with patch("gateway.session.SessionStore._ensure_loaded"):
|
||||||
|
store = SessionStore(sessions_dir=tmp_path, config=config)
|
||||||
|
store._loaded = True
|
||||||
|
store._db = None
|
||||||
|
store._save = MagicMock()
|
||||||
|
|
||||||
|
from gateway.session import SessionEntry
|
||||||
|
from datetime import datetime
|
||||||
|
entry = SessionEntry(
|
||||||
|
session_key="k1",
|
||||||
|
session_id="s1",
|
||||||
|
created_at=datetime.now(),
|
||||||
|
updated_at=datetime.now(),
|
||||||
|
)
|
||||||
|
store._entries = {"k1": entry}
|
||||||
|
|
||||||
|
store.update_session("k1", last_prompt_tokens=85000)
|
||||||
|
assert entry.last_prompt_tokens == 85000
|
||||||
|
|
||||||
|
def test_update_session_none_does_not_change(self, tmp_path):
|
||||||
|
"""update_session with default (None) should not change last_prompt_tokens."""
|
||||||
|
config = GatewayConfig()
|
||||||
|
with patch("gateway.session.SessionStore._ensure_loaded"):
|
||||||
|
store = SessionStore(sessions_dir=tmp_path, config=config)
|
||||||
|
store._loaded = True
|
||||||
|
store._db = None
|
||||||
|
store._save = MagicMock()
|
||||||
|
|
||||||
|
from gateway.session import SessionEntry
|
||||||
|
from datetime import datetime
|
||||||
|
entry = SessionEntry(
|
||||||
|
session_key="k1",
|
||||||
|
session_id="s1",
|
||||||
|
created_at=datetime.now(),
|
||||||
|
updated_at=datetime.now(),
|
||||||
|
last_prompt_tokens=50000,
|
||||||
|
)
|
||||||
|
store._entries = {"k1": entry}
|
||||||
|
|
||||||
|
store.update_session("k1") # No last_prompt_tokens arg
|
||||||
|
assert entry.last_prompt_tokens == 50000 # unchanged
|
||||||
|
|
||||||
|
def test_update_session_zero_resets(self, tmp_path):
|
||||||
|
"""update_session with last_prompt_tokens=0 should reset the field."""
|
||||||
|
config = GatewayConfig()
|
||||||
|
with patch("gateway.session.SessionStore._ensure_loaded"):
|
||||||
|
store = SessionStore(sessions_dir=tmp_path, config=config)
|
||||||
|
store._loaded = True
|
||||||
|
store._db = None
|
||||||
|
store._save = MagicMock()
|
||||||
|
|
||||||
|
from gateway.session import SessionEntry
|
||||||
|
from datetime import datetime
|
||||||
|
entry = SessionEntry(
|
||||||
|
session_key="k1",
|
||||||
|
session_id="s1",
|
||||||
|
created_at=datetime.now(),
|
||||||
|
updated_at=datetime.now(),
|
||||||
|
last_prompt_tokens=85000,
|
||||||
|
)
|
||||||
|
store._entries = {"k1": entry}
|
||||||
|
|
||||||
|
store.update_session("k1", last_prompt_tokens=0)
|
||||||
|
assert entry.last_prompt_tokens == 0
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue