diff --git a/gateway/run.py b/gateway/run.py index 6795610a..5ab74972 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -1652,10 +1652,11 @@ class GatewayRunner: skip_db=agent_persisted, ) - # Update session with actual prompt token count from the agent + # Update session with actual prompt token count and model from the agent self.session_store.update_session( session_entry.session_key, last_prompt_tokens=agent_result.get("last_prompt_tokens", 0), + model=agent_result.get("model"), ) # Auto voice reply: send TTS audio before the text response @@ -3988,6 +3989,7 @@ class GatewayRunner: _agent = agent_holder[0] if _agent and hasattr(_agent, "context_compressor"): _last_prompt_toks = getattr(_agent.context_compressor, "last_prompt_tokens", 0) + _resolved_model = getattr(_agent, "model", None) if _agent else None if not final_response: error_msg = f"⚠️ {result['error']}" if result.get("error") else "(No response generated)" @@ -3998,6 +4000,7 @@ class GatewayRunner: "tools": tools_holder[0] or [], "history_offset": len(agent_history), "last_prompt_tokens": _last_prompt_toks, + "model": _resolved_model, } # Scan tool results for MEDIA: tags that need to be delivered @@ -4060,6 +4063,7 @@ class GatewayRunner: "tools": tools_holder[0] or [], "history_offset": len(agent_history), "last_prompt_tokens": _last_prompt_toks, + "model": _resolved_model, "session_id": effective_session_id, } diff --git a/gateway/session.py b/gateway/session.py index 86e42b59..2f74d454 100644 --- a/gateway/session.py +++ b/gateway/session.py @@ -594,6 +594,7 @@ class SessionStore: input_tokens: int = 0, output_tokens: int = 0, last_prompt_tokens: int = None, + model: str = None, ) -> None: """Update a session's metadata after an interaction.""" self._ensure_loaded() @@ -611,7 +612,8 @@ class SessionStore: if self._db: try: self._db.update_token_counts( - entry.session_id, input_tokens, output_tokens + entry.session_id, input_tokens, output_tokens, + model=model, ) except Exception as e: logger.debug("Session DB operation failed: %s", e) diff --git a/hermes_state.py b/hermes_state.py index 5e29321e..8945e195 100644 --- a/hermes_state.py +++ b/hermes_state.py @@ -227,15 +227,17 @@ class SessionDB: self._conn.commit() def update_token_counts( - self, session_id: str, input_tokens: int = 0, output_tokens: int = 0 + self, session_id: str, input_tokens: int = 0, output_tokens: int = 0, + model: str = None, ) -> None: - """Increment token counters on a session.""" + """Increment token counters and backfill model if not already set.""" self._conn.execute( """UPDATE sessions SET input_tokens = input_tokens + ?, - output_tokens = output_tokens + ? + output_tokens = output_tokens + ?, + model = COALESCE(model, ?) WHERE id = ?""", - (input_tokens, output_tokens, session_id), + (input_tokens, output_tokens, model, session_id), ) self._conn.commit() diff --git a/tests/gateway/test_session.py b/tests/gateway/test_session.py index b5808a99..0737f18d 100644 --- a/tests/gateway/test_session.py +++ b/tests/gateway/test_session.py @@ -577,3 +577,28 @@ class TestLastPromptTokens: store.update_session("k1", last_prompt_tokens=0) assert entry.last_prompt_tokens == 0 + + def test_update_session_passes_model_to_db(self, tmp_path): + """Gateway session updates should forward the resolved model to SQLite.""" + config = GatewayConfig() + with patch("gateway.session.SessionStore._ensure_loaded"): + store = SessionStore(sessions_dir=tmp_path, config=config) + store._loaded = True + store._save = MagicMock() + store._db = MagicMock() + + from gateway.session import SessionEntry + from datetime import datetime + entry = SessionEntry( + session_key="k1", + session_id="s1", + created_at=datetime.now(), + updated_at=datetime.now(), + ) + store._entries = {"k1": entry} + + store.update_session("k1", model="openai/gpt-5.4") + + store._db.update_token_counts.assert_called_once_with( + "s1", 0, 0, model="openai/gpt-5.4" + ) diff --git a/tests/test_hermes_state.py b/tests/test_hermes_state.py index 329ae6f4..81e922c7 100644 --- a/tests/test_hermes_state.py +++ b/tests/test_hermes_state.py @@ -55,13 +55,27 @@ class TestSessionLifecycle: def test_update_token_counts(self, db): db.create_session(session_id="s1", source="cli") - db.update_token_counts("s1", input_tokens=100, output_tokens=50) db.update_token_counts("s1", input_tokens=200, output_tokens=100) + db.update_token_counts("s1", input_tokens=100, output_tokens=50) session = db.get_session("s1") assert session["input_tokens"] == 300 assert session["output_tokens"] == 150 + def test_update_token_counts_backfills_model_when_null(self, db): + db.create_session(session_id="s1", source="telegram") + db.update_token_counts("s1", input_tokens=10, output_tokens=5, model="openai/gpt-5.4") + + session = db.get_session("s1") + assert session["model"] == "openai/gpt-5.4" + + def test_update_token_counts_preserves_existing_model(self, db): + db.create_session(session_id="s1", source="cli", model="anthropic/claude-opus-4.6") + db.update_token_counts("s1", input_tokens=10, output_tokens=5, model="openai/gpt-5.4") + + session = db.get_session("s1") + assert session["model"] == "anthropic/claude-opus-4.6" + def test_parent_session(self, db): db.create_session(session_id="parent", source="cli") db.create_session(session_id="child", source="cli", parent_session_id="parent")