fix(gateway): persist watcher metadata in checkpoint for crash recovery (#1706)
Salvaged from PR #1573 by @eren-karakus0. Cherry-picked with authorship preserved.

Fixes #1143: background process notifications now resume after a gateway restart.

Co-authored-by: Muhammet Eren Karakuş <erenkar950@gmail.com>
This commit is contained in:
parent
ce7418e274
commit
d87655afff
5 changed files with 151 additions and 5 deletions
|
|
@ -984,6 +984,16 @@ class GatewayRunner:
|
||||||
):
|
):
|
||||||
self._schedule_update_notification_watch()
|
self._schedule_update_notification_watch()
|
||||||
|
|
||||||
|
# Drain any recovered process watchers (from crash recovery checkpoint)
|
||||||
|
try:
|
||||||
|
from tools.process_registry import process_registry
|
||||||
|
while process_registry.pending_watchers:
|
||||||
|
watcher = process_registry.pending_watchers.pop(0)
|
||||||
|
asyncio.create_task(self._run_process_watcher(watcher))
|
||||||
|
logger.info("Resumed watcher for recovered process %s", watcher.get("session_id"))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Recovered watcher setup error: %s", e)
|
||||||
|
|
||||||
# Start background session expiry watcher for proactive memory flushing
|
# Start background session expiry watcher for proactive memory flushing
|
||||||
asyncio.create_task(self._session_expiry_watcher())
|
asyncio.create_task(self._session_expiry_watcher())
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -50,13 +50,16 @@ def _build_runner(monkeypatch, tmp_path, mode: str) -> GatewayRunner:
|
||||||
return runner
|
return runner
|
||||||
|
|
||||||
|
|
||||||
def _watcher_dict(session_id="proc_test"):
|
def _watcher_dict(session_id="proc_test", thread_id=""):
|
||||||
return {
|
d = {
|
||||||
"session_id": session_id,
|
"session_id": session_id,
|
||||||
"check_interval": 0,
|
"check_interval": 0,
|
||||||
"platform": "telegram",
|
"platform": "telegram",
|
||||||
"chat_id": "123",
|
"chat_id": "123",
|
||||||
}
|
}
|
||||||
|
if thread_id:
|
||||||
|
d["thread_id"] = thread_id
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -196,3 +199,47 @@ async def test_run_process_watcher_respects_notification_mode(
|
||||||
if expected_fragment is not None:
|
if expected_fragment is not None:
|
||||||
sent_message = adapter.send.await_args.args[1]
|
sent_message = adapter.send.await_args.args[1]
|
||||||
assert expected_fragment in sent_message
|
assert expected_fragment in sent_message
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_thread_id_passed_to_send(monkeypatch, tmp_path):
|
||||||
|
"""thread_id from watcher dict is forwarded as metadata to adapter.send()."""
|
||||||
|
import tools.process_registry as pr_module
|
||||||
|
|
||||||
|
sessions = [SimpleNamespace(output_buffer="done\n", exited=True, exit_code=0)]
|
||||||
|
monkeypatch.setattr(pr_module, "process_registry", _FakeRegistry(sessions))
|
||||||
|
|
||||||
|
async def _instant_sleep(*_a, **_kw):
|
||||||
|
pass
|
||||||
|
monkeypatch.setattr(asyncio, "sleep", _instant_sleep)
|
||||||
|
|
||||||
|
runner = _build_runner(monkeypatch, tmp_path, "all")
|
||||||
|
adapter = runner.adapters[Platform.TELEGRAM]
|
||||||
|
|
||||||
|
await runner._run_process_watcher(_watcher_dict(thread_id="42"))
|
||||||
|
|
||||||
|
assert adapter.send.await_count == 1
|
||||||
|
_, kwargs = adapter.send.call_args
|
||||||
|
assert kwargs["metadata"] == {"thread_id": "42"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_thread_id_sends_no_metadata(monkeypatch, tmp_path):
|
||||||
|
"""When thread_id is empty, metadata should be None (general topic)."""
|
||||||
|
import tools.process_registry as pr_module
|
||||||
|
|
||||||
|
sessions = [SimpleNamespace(output_buffer="done\n", exited=True, exit_code=0)]
|
||||||
|
monkeypatch.setattr(pr_module, "process_registry", _FakeRegistry(sessions))
|
||||||
|
|
||||||
|
async def _instant_sleep(*_a, **_kw):
|
||||||
|
pass
|
||||||
|
monkeypatch.setattr(asyncio, "sleep", _instant_sleep)
|
||||||
|
|
||||||
|
runner = _build_runner(monkeypatch, tmp_path, "all")
|
||||||
|
adapter = runner.adapters[Platform.TELEGRAM]
|
||||||
|
|
||||||
|
await runner._run_process_watcher(_watcher_dict())
|
||||||
|
|
||||||
|
assert adapter.send.await_count == 1
|
||||||
|
_, kwargs = adapter.send.call_args
|
||||||
|
assert kwargs["metadata"] is None
|
||||||
|
|
|
||||||
|
|
@ -294,6 +294,61 @@ class TestCheckpoint:
|
||||||
recovered = registry.recover_from_checkpoint()
|
recovered = registry.recover_from_checkpoint()
|
||||||
assert recovered == 0
|
assert recovered == 0
|
||||||
|
|
||||||
|
def test_write_checkpoint_includes_watcher_metadata(self, registry, tmp_path):
|
||||||
|
with patch("tools.process_registry.CHECKPOINT_PATH", tmp_path / "procs.json"):
|
||||||
|
s = _make_session()
|
||||||
|
s.watcher_platform = "telegram"
|
||||||
|
s.watcher_chat_id = "999"
|
||||||
|
s.watcher_thread_id = "42"
|
||||||
|
s.watcher_interval = 60
|
||||||
|
registry._running[s.id] = s
|
||||||
|
registry._write_checkpoint()
|
||||||
|
|
||||||
|
data = json.loads((tmp_path / "procs.json").read_text())
|
||||||
|
assert len(data) == 1
|
||||||
|
assert data[0]["watcher_platform"] == "telegram"
|
||||||
|
assert data[0]["watcher_chat_id"] == "999"
|
||||||
|
assert data[0]["watcher_thread_id"] == "42"
|
||||||
|
assert data[0]["watcher_interval"] == 60
|
||||||
|
|
||||||
|
def test_recover_enqueues_watchers(self, registry, tmp_path):
|
||||||
|
checkpoint = tmp_path / "procs.json"
|
||||||
|
checkpoint.write_text(json.dumps([{
|
||||||
|
"session_id": "proc_live",
|
||||||
|
"command": "sleep 999",
|
||||||
|
"pid": os.getpid(), # current process — guaranteed alive
|
||||||
|
"task_id": "t1",
|
||||||
|
"session_key": "sk1",
|
||||||
|
"watcher_platform": "telegram",
|
||||||
|
"watcher_chat_id": "123",
|
||||||
|
"watcher_thread_id": "42",
|
||||||
|
"watcher_interval": 60,
|
||||||
|
}]))
|
||||||
|
with patch("tools.process_registry.CHECKPOINT_PATH", checkpoint):
|
||||||
|
recovered = registry.recover_from_checkpoint()
|
||||||
|
assert recovered == 1
|
||||||
|
assert len(registry.pending_watchers) == 1
|
||||||
|
w = registry.pending_watchers[0]
|
||||||
|
assert w["session_id"] == "proc_live"
|
||||||
|
assert w["platform"] == "telegram"
|
||||||
|
assert w["chat_id"] == "123"
|
||||||
|
assert w["thread_id"] == "42"
|
||||||
|
assert w["check_interval"] == 60
|
||||||
|
|
||||||
|
def test_recover_skips_watcher_when_no_interval(self, registry, tmp_path):
|
||||||
|
checkpoint = tmp_path / "procs.json"
|
||||||
|
checkpoint.write_text(json.dumps([{
|
||||||
|
"session_id": "proc_live",
|
||||||
|
"command": "sleep 999",
|
||||||
|
"pid": os.getpid(),
|
||||||
|
"task_id": "t1",
|
||||||
|
"watcher_interval": 0,
|
||||||
|
}]))
|
||||||
|
with patch("tools.process_registry.CHECKPOINT_PATH", checkpoint):
|
||||||
|
recovered = registry.recover_from_checkpoint()
|
||||||
|
assert recovered == 1
|
||||||
|
assert len(registry.pending_watchers) == 0
|
||||||
|
|
||||||
|
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# Kill process
|
# Kill process
|
||||||
|
|
|
||||||
|
|
@ -78,6 +78,11 @@ class ProcessSession:
|
||||||
output_buffer: str = "" # Rolling output (last MAX_OUTPUT_CHARS)
|
output_buffer: str = "" # Rolling output (last MAX_OUTPUT_CHARS)
|
||||||
max_output_chars: int = MAX_OUTPUT_CHARS
|
max_output_chars: int = MAX_OUTPUT_CHARS
|
||||||
detached: bool = False # True if recovered from crash (no pipe)
|
detached: bool = False # True if recovered from crash (no pipe)
|
||||||
|
# Watcher/notification metadata (persisted for crash recovery)
|
||||||
|
watcher_platform: str = ""
|
||||||
|
watcher_chat_id: str = ""
|
||||||
|
watcher_thread_id: str = ""
|
||||||
|
watcher_interval: int = 0 # 0 = no watcher configured
|
||||||
_lock: threading.Lock = field(default_factory=threading.Lock)
|
_lock: threading.Lock = field(default_factory=threading.Lock)
|
||||||
_reader_thread: Optional[threading.Thread] = field(default=None, repr=False)
|
_reader_thread: Optional[threading.Thread] = field(default=None, repr=False)
|
||||||
_pty: Any = field(default=None, repr=False) # ptyprocess handle (when use_pty=True)
|
_pty: Any = field(default=None, repr=False) # ptyprocess handle (when use_pty=True)
|
||||||
|
|
@ -709,6 +714,10 @@ class ProcessRegistry:
|
||||||
"started_at": s.started_at,
|
"started_at": s.started_at,
|
||||||
"task_id": s.task_id,
|
"task_id": s.task_id,
|
||||||
"session_key": s.session_key,
|
"session_key": s.session_key,
|
||||||
|
"watcher_platform": s.watcher_platform,
|
||||||
|
"watcher_chat_id": s.watcher_chat_id,
|
||||||
|
"watcher_thread_id": s.watcher_thread_id,
|
||||||
|
"watcher_interval": s.watcher_interval,
|
||||||
})
|
})
|
||||||
|
|
||||||
# Atomic write to avoid corruption on crash
|
# Atomic write to avoid corruption on crash
|
||||||
|
|
@ -755,12 +764,27 @@ class ProcessRegistry:
|
||||||
cwd=entry.get("cwd"),
|
cwd=entry.get("cwd"),
|
||||||
started_at=entry.get("started_at", time.time()),
|
started_at=entry.get("started_at", time.time()),
|
||||||
detached=True, # Can't read output, but can report status + kill
|
detached=True, # Can't read output, but can report status + kill
|
||||||
|
watcher_platform=entry.get("watcher_platform", ""),
|
||||||
|
watcher_chat_id=entry.get("watcher_chat_id", ""),
|
||||||
|
watcher_thread_id=entry.get("watcher_thread_id", ""),
|
||||||
|
watcher_interval=entry.get("watcher_interval", 0),
|
||||||
)
|
)
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._running[session.id] = session
|
self._running[session.id] = session
|
||||||
recovered += 1
|
recovered += 1
|
||||||
logger.info("Recovered detached process: %s (pid=%d)", session.command[:60], pid)
|
logger.info("Recovered detached process: %s (pid=%d)", session.command[:60], pid)
|
||||||
|
|
||||||
|
# Re-enqueue watcher so gateway can resume notifications
|
||||||
|
if session.watcher_interval > 0:
|
||||||
|
self.pending_watchers.append({
|
||||||
|
"session_id": session.id,
|
||||||
|
"check_interval": session.watcher_interval,
|
||||||
|
"session_key": session.session_key,
|
||||||
|
"platform": session.watcher_platform,
|
||||||
|
"chat_id": session.watcher_chat_id,
|
||||||
|
"thread_id": session.watcher_thread_id,
|
||||||
|
})
|
||||||
|
|
||||||
# Clear the checkpoint (will be rewritten as processes finish)
|
# Clear the checkpoint (will be rewritten as processes finish)
|
||||||
try:
|
try:
|
||||||
from utils import atomic_json_write
|
from utils import atomic_json_write
|
||||||
|
|
|
||||||
|
|
@ -1082,13 +1082,23 @@ def terminal_tool(
|
||||||
result_data["check_interval_note"] = (
|
result_data["check_interval_note"] = (
|
||||||
f"Requested {check_interval}s raised to minimum 30s"
|
f"Requested {check_interval}s raised to minimum 30s"
|
||||||
)
|
)
|
||||||
|
watcher_platform = os.getenv("HERMES_SESSION_PLATFORM", "")
|
||||||
|
watcher_chat_id = os.getenv("HERMES_SESSION_CHAT_ID", "")
|
||||||
|
watcher_thread_id = os.getenv("HERMES_SESSION_THREAD_ID", "")
|
||||||
|
|
||||||
|
# Store on session for checkpoint persistence
|
||||||
|
proc_session.watcher_platform = watcher_platform
|
||||||
|
proc_session.watcher_chat_id = watcher_chat_id
|
||||||
|
proc_session.watcher_thread_id = watcher_thread_id
|
||||||
|
proc_session.watcher_interval = effective_interval
|
||||||
|
|
||||||
process_registry.pending_watchers.append({
|
process_registry.pending_watchers.append({
|
||||||
"session_id": proc_session.id,
|
"session_id": proc_session.id,
|
||||||
"check_interval": effective_interval,
|
"check_interval": effective_interval,
|
||||||
"session_key": session_key,
|
"session_key": session_key,
|
||||||
"platform": os.getenv("HERMES_SESSION_PLATFORM", ""),
|
"platform": watcher_platform,
|
||||||
"chat_id": os.getenv("HERMES_SESSION_CHAT_ID", ""),
|
"chat_id": watcher_chat_id,
|
||||||
"thread_id": os.getenv("HERMES_SESSION_THREAD_ID", ""),
|
"thread_id": watcher_thread_id,
|
||||||
})
|
})
|
||||||
|
|
||||||
return json.dumps(result_data, ensure_ascii=False)
|
return json.dumps(result_data, ensure_ascii=False)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue