Merge pull request #1427 from NousResearch/fix/1414-gateway-shutdown-restart
fix(gateway): cancel active runs during shutdown
This commit is contained in:
commit
5254d0bba1
3 changed files with 149 additions and 2 deletions
|
|
@ -356,6 +356,10 @@ class BasePlatformAdapter(ABC):
|
||||||
# Key: session_key (e.g., chat_id), Value: (event, asyncio.Event for interrupt)
|
# Key: session_key (e.g., chat_id), Value: (event, asyncio.Event for interrupt)
|
||||||
self._active_sessions: Dict[str, asyncio.Event] = {}
|
self._active_sessions: Dict[str, asyncio.Event] = {}
|
||||||
self._pending_messages: Dict[str, MessageEvent] = {}
|
self._pending_messages: Dict[str, MessageEvent] = {}
|
||||||
|
# Background message-processing tasks spawned by handle_message().
|
||||||
|
# Gateway shutdown cancels these so an old gateway instance doesn't keep
|
||||||
|
# working on a task after --replace or manual restarts.
|
||||||
|
self._background_tasks: set[asyncio.Task] = set()
|
||||||
# Chats where auto-TTS on voice input is disabled (set by /voice off)
|
# Chats where auto-TTS on voice input is disabled (set by /voice off)
|
||||||
self._auto_tts_disabled_chats: set = set()
|
self._auto_tts_disabled_chats: set = set()
|
||||||
|
|
||||||
|
|
@ -778,7 +782,15 @@ class BasePlatformAdapter(ABC):
|
||||||
return # Don't process now - will be handled after current task finishes
|
return # Don't process now - will be handled after current task finishes
|
||||||
|
|
||||||
# Spawn background task to process this message
|
# Spawn background task to process this message
|
||||||
asyncio.create_task(self._process_message_background(event, session_key))
|
task = asyncio.create_task(self._process_message_background(event, session_key))
|
||||||
|
try:
|
||||||
|
self._background_tasks.add(task)
|
||||||
|
except TypeError:
|
||||||
|
# Some tests stub create_task() with lightweight sentinels that are not
|
||||||
|
# hashable and do not support lifecycle callbacks.
|
||||||
|
return
|
||||||
|
if hasattr(task, "add_done_callback"):
|
||||||
|
task.add_done_callback(self._background_tasks.discard)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_human_delay() -> float:
|
def _get_human_delay() -> float:
|
||||||
|
|
@ -988,6 +1000,21 @@ class BasePlatformAdapter(ABC):
|
||||||
if session_key in self._active_sessions:
|
if session_key in self._active_sessions:
|
||||||
del self._active_sessions[session_key]
|
del self._active_sessions[session_key]
|
||||||
|
|
||||||
|
async def cancel_background_tasks(self) -> None:
    """Cancel any in-flight background message-processing tasks.

    Used during gateway shutdown/replacement so active sessions from the old
    process do not keep running after adapters are being torn down.

    The task executing this coroutine is never cancelled or awaited, so a
    shutdown that is triggered from inside a tracked background task does
    not cancel itself halfway through. Tasks registered while earlier ones
    are being awaited are picked up by a further pass instead of being
    silently dropped from the tracking set.

    Session bookkeeping (``_background_tasks``, ``_pending_messages``,
    ``_active_sessions``) is always cleared, even when the await is
    interrupted by an outer cancellation.
    """
    current = asyncio.current_task()
    try:
        while True:
            # Snapshot first: done-callbacks mutate the set while we await.
            tasks = [
                task
                for task in self._background_tasks
                if task is not current and not task.done()
            ]
            if not tasks:
                break
            for task in tasks:
                task.cancel()
            # return_exceptions=True keeps one task's CancelledError (or any
            # other failure) from aborting the wait for the remaining tasks.
            await asyncio.gather(*tasks, return_exceptions=True)
    finally:
        self._background_tasks.clear()
        self._pending_messages.clear()
        self._active_sessions.clear()
|
||||||
|
|
||||||
def has_pending_interrupt(self, session_key: str) -> bool:
    """Tell whether the interrupt event tracked for *session_key* is set.

    Returns False when the session has no entry in ``_active_sessions``;
    otherwise returns the ``is_set()`` state of that session's event.
    """
    if session_key not in self._active_sessions:
        return False
    return self._active_sessions[session_key].is_set()
|
||||||
|
|
|
||||||
|
|
@ -900,8 +900,19 @@ class GatewayRunner:
|
||||||
"""Stop the gateway and disconnect all adapters."""
|
"""Stop the gateway and disconnect all adapters."""
|
||||||
logger.info("Stopping gateway...")
|
logger.info("Stopping gateway...")
|
||||||
self._running = False
|
self._running = False
|
||||||
|
|
||||||
|
for session_key, agent in list(self._running_agents.items()):
|
||||||
|
try:
|
||||||
|
agent.interrupt("Gateway shutting down")
|
||||||
|
logger.debug("Interrupted running agent for session %s during shutdown", session_key[:20])
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug("Failed interrupting agent during shutdown: %s", e)
|
||||||
|
|
||||||
for platform, adapter in list(self.adapters.items()):
|
for platform, adapter in list(self.adapters.items()):
|
||||||
|
try:
|
||||||
|
await adapter.cancel_background_tasks()
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug("✗ %s background-task cancel error: %s", platform.value, e)
|
||||||
try:
|
try:
|
||||||
await adapter.disconnect()
|
await adapter.disconnect()
|
||||||
logger.info("✓ %s disconnected", platform.value)
|
logger.info("✓ %s disconnected", platform.value)
|
||||||
|
|
@ -909,6 +920,9 @@ class GatewayRunner:
|
||||||
logger.error("✗ %s disconnect error: %s", platform.value, e)
|
logger.error("✗ %s disconnect error: %s", platform.value, e)
|
||||||
|
|
||||||
self.adapters.clear()
|
self.adapters.clear()
|
||||||
|
self._running_agents.clear()
|
||||||
|
self._pending_messages.clear()
|
||||||
|
self._pending_approvals.clear()
|
||||||
self._shutdown_all_gateway_honcho()
|
self._shutdown_all_gateway_honcho()
|
||||||
self._shutdown_event.set()
|
self._shutdown_event.set()
|
||||||
|
|
||||||
|
|
|
||||||
106
tests/gateway/test_gateway_shutdown.py
Normal file
106
tests/gateway/test_gateway_shutdown.py
Normal file
|
|
@ -0,0 +1,106 @@
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||||
|
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, SendResult
|
||||||
|
from gateway.run import GatewayRunner
|
||||||
|
from gateway.session import SessionSource, build_session_key
|
||||||
|
|
||||||
|
|
||||||
|
class StubAdapter(BasePlatformAdapter):
    """Do-nothing adapter used to exercise the base-class task bookkeeping.

    Every transport hook returns a fixed value, so the tests only observe
    behavior implemented by ``BasePlatformAdapter`` itself.
    """

    def __init__(self):
        stub_config = PlatformConfig(enabled=True, token="***")
        super().__init__(stub_config, Platform.TELEGRAM)

    async def connect(self):
        """Report a successful connection without doing any I/O."""
        return True

    async def disconnect(self):
        """Nothing to tear down for the stub."""
        return None

    async def send(self, chat_id, content, reply_to=None, metadata=None):
        """Accept the message and report success with a fixed message id."""
        return SendResult(success=True, message_id="1")

    async def send_typing(self, chat_id, metadata=None):
        """Typing indicators are ignored by the stub."""
        return None

    async def get_chat_info(self, chat_id):
        """Echo the chat id back as the only chat metadata."""
        return {"id": chat_id}
|
||||||
|
|
||||||
|
|
||||||
|
def _source(chat_id="123456", chat_type="dm"):
    """Build a Telegram ``SessionSource`` for the given chat id and type."""
    source_fields = {
        "platform": Platform.TELEGRAM,
        "chat_id": chat_id,
        "chat_type": chat_type,
    }
    return SessionSource(**source_fields)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_cancel_background_tasks_cancels_inflight_message_processing():
    """cancel_background_tasks() must stop a handler that is still running.

    The handler blocks on an event that is never set, so the spawned
    background task can only finish by being cancelled.
    """
    adapter = StubAdapter()
    release = asyncio.Event()  # intentionally never set

    async def block_forever(_event):
        await release.wait()
        return None

    adapter.set_message_handler(block_forever)
    event = MessageEvent(text="work", source=_source(), message_id="1")

    await adapter.handle_message(event)
    # Yield once so the spawned background task gets a chance to start.
    await asyncio.sleep(0)

    session_key = build_session_key(event.source)
    assert session_key in adapter._active_sessions
    assert adapter._background_tasks
    # Keep our own references: the adapter clears its set during cancellation,
    # and an empty set alone would not prove the tasks actually stopped.
    inflight = list(adapter._background_tasks)

    await adapter.cancel_background_tasks()

    assert all(task.done() for task in inflight)
    assert adapter._background_tasks == set()
    assert adapter._active_sessions == {}
    assert adapter._pending_messages == {}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks():
    """GatewayRunner.stop() interrupts agents, cancels adapter tasks, and resets state."""
    # Bypass __init__ so no real gateway resources are created; only the
    # attributes this test assigns below exist on the runner.
    runner = object.__new__(GatewayRunner)
    runner.config = GatewayConfig(platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")})
    runner._running = True
    runner._shutdown_event = asyncio.Event()
    runner._exit_reason = None
    runner._pending_messages = {"session": "pending text"}
    runner._pending_approvals = {"session": {"command": "rm -rf /tmp/x"}}
    runner._shutdown_all_gateway_honcho = lambda: None

    adapter = StubAdapter()
    release = asyncio.Event()  # intentionally never set

    async def block_forever(_event):
        await release.wait()
        return None

    adapter.set_message_handler(block_forever)
    event = MessageEvent(text="work", source=_source(), message_id="1")
    await adapter.handle_message(event)
    # Yield once so the spawned background task gets a chance to start.
    await asyncio.sleep(0)

    disconnect_mock = AsyncMock()
    adapter.disconnect = disconnect_mock

    session_key = build_session_key(event.source)
    running_agent = MagicMock()
    runner._running_agents = {session_key: running_agent}
    runner.adapters = {Platform.TELEGRAM: adapter}

    # Hold references to the in-flight tasks so their final state can be
    # checked after the adapter has cleared its own tracking set.
    inflight = list(adapter._background_tasks)
    assert inflight

    with patch("gateway.status.remove_pid_file"), patch("gateway.status.write_runtime_status"):
        await runner.stop()

    running_agent.interrupt.assert_called_once_with("Gateway shutting down")
    disconnect_mock.assert_awaited_once()
    # The test name promises adapter-task cancellation, so verify it directly.
    assert all(task.done() for task in inflight)
    assert adapter._background_tasks == set()
    assert adapter._active_sessions == {}
    assert runner.adapters == {}
    assert runner._running_agents == {}
    assert runner._pending_messages == {}
    assert runner._pending_approvals == {}
    assert runner._shutdown_event.is_set() is True
|
||||||
Loading…
Add table
Add a link
Reference in a new issue