Merge pull request #1306 from NousResearch/hermes/hermes-2ba57c8a
fix: backfill model on gateway sessions after agent runs
commit 917adcbaf4
5 changed files with 54 additions and 7 deletions
@@ -1652,10 +1652,11 @@ class GatewayRunner:
                 skip_db=agent_persisted,
             )

-            # Update session with actual prompt token count from the agent
+            # Update session with actual prompt token count and model from the agent
             self.session_store.update_session(
                 session_entry.session_key,
                 last_prompt_tokens=agent_result.get("last_prompt_tokens", 0),
+                model=agent_result.get("model"),
             )

            # Auto voice reply: send TTS audio before the text response
@@ -3988,6 +3989,7 @@ class GatewayRunner:
         _agent = agent_holder[0]
         if _agent and hasattr(_agent, "context_compressor"):
             _last_prompt_toks = getattr(_agent.context_compressor, "last_prompt_tokens", 0)
+        _resolved_model = getattr(_agent, "model", None) if _agent else None

         if not final_response:
             error_msg = f"⚠️ {result['error']}" if result.get("error") else "(No response generated)"
@@ -3998,6 +4000,7 @@ class GatewayRunner:
             "tools": tools_holder[0] or [],
             "history_offset": len(agent_history),
             "last_prompt_tokens": _last_prompt_toks,
+            "model": _resolved_model,
         }

         # Scan tool results for MEDIA:<path> tags that need to be delivered
@@ -4060,6 +4063,7 @@ class GatewayRunner:
             "tools": tools_holder[0] or [],
             "history_offset": len(agent_history),
             "last_prompt_tokens": _last_prompt_toks,
+            "model": _resolved_model,
             "session_id": effective_session_id,
         }

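Condensed, the runner-side change reads roughly as in the sketch below. This is only an illustration of how the resolved model travels from the agent into the session store: `agent_holder`, `_resolved_model`, `agent_result`, and `update_session` come from the diff above, while the two helper functions and the simplified surrounding plumbing are hypothetical stand-ins for the real GatewayRunner methods.

```python
# Illustrative sketch, not the actual GatewayRunner code: it condenses how the
# resolved model travels from the agent into the gateway session store.

def build_agent_result(agent_holder, tools_holder, agent_history):
    _agent = agent_holder[0]
    _last_prompt_toks = 0
    if _agent and hasattr(_agent, "context_compressor"):
        _last_prompt_toks = getattr(_agent.context_compressor, "last_prompt_tokens", 0)
    # New in this change: read the model off the agent when one is available.
    _resolved_model = getattr(_agent, "model", None) if _agent else None
    return {
        "tools": tools_holder[0] or [],
        "history_offset": len(agent_history),
        "last_prompt_tokens": _last_prompt_toks,
        "model": _resolved_model,  # carried alongside the prompt token count
    }


def persist_session_metadata(session_store, session_entry, agent_result):
    # New in this change: forward the model so the session row can be backfilled.
    session_store.update_session(
        session_entry.session_key,
        last_prompt_tokens=agent_result.get("last_prompt_tokens", 0),
        model=agent_result.get("model"),
    )
```
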
@@ -594,6 +594,7 @@ class SessionStore:
         input_tokens: int = 0,
         output_tokens: int = 0,
         last_prompt_tokens: int = None,
+        model: str = None,
     ) -> None:
         """Update a session's metadata after an interaction."""
         self._ensure_loaded()
@@ -611,7 +612,8 @@ class SessionStore:
         if self._db:
             try:
                 self._db.update_token_counts(
-                    entry.session_id, input_tokens, output_tokens
+                    entry.session_id, input_tokens, output_tokens,
+                    model=model,
                 )
             except Exception as e:
                 logger.debug("Session DB operation failed: %s", e)

@@ -227,15 +227,17 @@ class SessionDB:
         self._conn.commit()

     def update_token_counts(
-        self, session_id: str, input_tokens: int = 0, output_tokens: int = 0
+        self, session_id: str, input_tokens: int = 0, output_tokens: int = 0,
+        model: str = None,
     ) -> None:
-        """Increment token counters on a session."""
+        """Increment token counters and backfill model if not already set."""
         self._conn.execute(
             """UPDATE sessions SET
                    input_tokens = input_tokens + ?,
-                   output_tokens = output_tokens + ?
+                   output_tokens = output_tokens + ?,
+                   model = COALESCE(model, ?)
                WHERE id = ?""",
-            (input_tokens, output_tokens, session_id),
+            (input_tokens, output_tokens, model, session_id),
         )
         self._conn.commit()

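The key SQL idiom above is `COALESCE(model, ?)`: the column keeps its existing value when it is already set and only takes the new value when it is still NULL, which is what makes this a backfill rather than an overwrite. A minimal, self-contained illustration of that behaviour follows; it uses an in-memory SQLite connection and a toy `sessions` table reduced to the relevant columns, not the project's real schema.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE sessions (id TEXT PRIMARY KEY, model TEXT)")
conn.execute("INSERT INTO sessions VALUES ('s1', NULL)")
conn.execute("INSERT INTO sessions VALUES ('s2', 'anthropic/claude-opus-4.6')")

def backfill_model(session_id: str, model: str) -> None:
    # COALESCE keeps the stored value when it is non-NULL and only
    # falls back to the new value when nothing was recorded yet.
    conn.execute(
        "UPDATE sessions SET model = COALESCE(model, ?) WHERE id = ?",
        (model, session_id),
    )
    conn.commit()

backfill_model("s1", "openai/gpt-5.4")  # NULL -> backfilled
backfill_model("s2", "openai/gpt-5.4")  # already set -> unchanged

print(conn.execute("SELECT id, model FROM sessions ORDER BY id").fetchall())
# [('s1', 'openai/gpt-5.4'), ('s2', 'anthropic/claude-opus-4.6')]
```
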
@@ -577,3 +577,28 @@ class TestLastPromptTokens:

         store.update_session("k1", last_prompt_tokens=0)
         assert entry.last_prompt_tokens == 0
+
+    def test_update_session_passes_model_to_db(self, tmp_path):
+        """Gateway session updates should forward the resolved model to SQLite."""
+        config = GatewayConfig()
+        with patch("gateway.session.SessionStore._ensure_loaded"):
+            store = SessionStore(sessions_dir=tmp_path, config=config)
+        store._loaded = True
+        store._save = MagicMock()
+        store._db = MagicMock()
+
+        from gateway.session import SessionEntry
+        from datetime import datetime
+        entry = SessionEntry(
+            session_key="k1",
+            session_id="s1",
+            created_at=datetime.now(),
+            updated_at=datetime.now(),
+        )
+        store._entries = {"k1": entry}
+
+        store.update_session("k1", model="openai/gpt-5.4")
+
+        store._db.update_token_counts.assert_called_once_with(
+            "s1", 0, 0, model="openai/gpt-5.4"
+        )

@@ -55,13 +55,27 @@ class TestSessionLifecycle:

     def test_update_token_counts(self, db):
         db.create_session(session_id="s1", source="cli")
-        db.update_token_counts("s1", input_tokens=100, output_tokens=50)
         db.update_token_counts("s1", input_tokens=200, output_tokens=100)
+        db.update_token_counts("s1", input_tokens=100, output_tokens=50)

         session = db.get_session("s1")
         assert session["input_tokens"] == 300
         assert session["output_tokens"] == 150

+    def test_update_token_counts_backfills_model_when_null(self, db):
+        db.create_session(session_id="s1", source="telegram")
+        db.update_token_counts("s1", input_tokens=10, output_tokens=5, model="openai/gpt-5.4")
+
+        session = db.get_session("s1")
+        assert session["model"] == "openai/gpt-5.4"
+
+    def test_update_token_counts_preserves_existing_model(self, db):
+        db.create_session(session_id="s1", source="cli", model="anthropic/claude-opus-4.6")
+        db.update_token_counts("s1", input_tokens=10, output_tokens=5, model="openai/gpt-5.4")
+
+        session = db.get_session("s1")
+        assert session["model"] == "anthropic/claude-opus-4.6"
+
     def test_parent_session(self, db):
         db.create_session(session_id="parent", source="cli")
         db.create_session(session_id="child", source="cli", parent_session_id="parent")