fix(gateway): make group session isolation configurable

default group and channel sessions to per-user isolation, allow opting back into shared room sessions via config.yaml, and document Discord gateway routing and session behavior.
This commit is contained in:
teknium1 2026-03-16 00:22:23 -07:00
parent 06a7d19f98
commit 38b4fd3737
11 changed files with 246 additions and 27 deletions

View file

@ -174,7 +174,10 @@ class GatewayConfig:
# STT settings
stt_enabled: bool = True # Whether to auto-transcribe inbound voice messages
# Session isolation in shared chats
group_sessions_per_user: bool = True # Isolate group/channel sessions per participant when user IDs are available
def get_connected_platforms(self) -> List[Platform]:
"""Return list of platforms that are enabled and configured."""
connected = []
@ -239,6 +242,7 @@ class GatewayConfig:
"sessions_dir": str(self.sessions_dir),
"always_log_local": self.always_log_local,
"stt_enabled": self.stt_enabled,
"group_sessions_per_user": self.group_sessions_per_user,
}
@classmethod
@ -279,6 +283,8 @@ class GatewayConfig:
if stt_enabled is None:
stt_enabled = data.get("stt", {}).get("enabled") if isinstance(data.get("stt"), dict) else None
group_sessions_per_user = data.get("group_sessions_per_user")
return cls(
platforms=platforms,
default_reset_policy=default_policy,
@ -289,6 +295,7 @@ class GatewayConfig:
sessions_dir=sessions_dir,
always_log_local=data.get("always_log_local", True),
stt_enabled=_coerce_bool(stt_enabled, True),
group_sessions_per_user=_coerce_bool(group_sessions_per_user, True),
)
@ -344,6 +351,14 @@ def load_gateway_config() -> GatewayConfig:
if isinstance(stt_cfg, dict) and "enabled" in stt_cfg:
config.stt_enabled = _coerce_bool(stt_cfg.get("enabled"), True)
# Bridge group session isolation from config.yaml into gateway runtime.
# Secure default is per-user isolation in shared chats.
if "group_sessions_per_user" in yaml_cfg:
config.group_sessions_per_user = _coerce_bool(
yaml_cfg.get("group_sessions_per_user"),
True,
)
# Bridge discord settings from config.yaml to env vars
# (env vars take precedence — only set if not already defined)
discord_cfg = yaml_cfg.get("discord", {})

View file

@ -752,7 +752,10 @@ class BasePlatformAdapter(ABC):
if not self._message_handler:
return
session_key = build_session_key(event.source)
session_key = build_session_key(
event.source,
group_sessions_per_user=self.config.extra.get("group_sessions_per_user", True),
)
# Check if there's already an active handler for this session
if session_key in self._active_sessions:

View file

@ -829,7 +829,10 @@ class TelegramAdapter(BasePlatformAdapter):
def _photo_batch_key(self, event: MessageEvent, msg: Message) -> str:
"""Return a batching key for Telegram photos/albums."""
from gateway.session import build_session_key
session_key = build_session_key(event.source)
session_key = build_session_key(
event.source,
group_sessions_per_user=self.config.extra.get("group_sessions_per_user", True),
)
media_group_id = getattr(msg, "media_group_id", None)
if media_group_id:
return f"{session_key}:album:{media_group_id}"

View file

@ -514,6 +514,21 @@ class GatewayRunner:
def exit_reason(self) -> Optional[str]:
return self._exit_reason
def _session_key_for_source(self, source: SessionSource) -> str:
"""Resolve the current session key for a source, honoring gateway config when available."""
if hasattr(self, "session_store") and self.session_store is not None:
try:
session_key = self.session_store._generate_session_key(source)
if isinstance(session_key, str) and session_key:
return session_key
except Exception:
pass
config = getattr(self, "config", None)
return build_session_key(
source,
group_sessions_per_user=getattr(config, "group_sessions_per_user", True),
)
async def _handle_adapter_fatal_error(self, adapter: BasePlatformAdapter) -> None:
"""React to a non-retryable adapter failure after startup."""
logger.error(
@ -942,6 +957,12 @@ class GatewayRunner:
config: Any
) -> Optional[BasePlatformAdapter]:
"""Create the appropriate adapter for a platform."""
if hasattr(config, "extra") and isinstance(config.extra, dict):
config.extra.setdefault(
"group_sessions_per_user",
self.config.group_sessions_per_user,
)
if platform == Platform.TELEGRAM:
from gateway.platforms.telegram import TelegramAdapter, check_telegram_requirements
if not check_telegram_requirements():
@ -1113,7 +1134,7 @@ class GatewayRunner:
# Special case: Telegram/photo bursts often arrive as multiple near-
# simultaneous updates. Do NOT interrupt for photo-only follow-ups here;
# let the adapter-level batching/queueing logic absorb them.
_quick_key = build_session_key(source)
_quick_key = self._session_key_for_source(source)
if _quick_key in self._running_agents:
if event.get_command() == "status":
return await self._handle_status_command(event)
@ -1302,7 +1323,7 @@ class GatewayRunner:
logger.debug("Skill command check failed (non-fatal): %s", e)
# Check for pending exec approval responses
session_key_preview = build_session_key(source)
session_key_preview = self._session_key_for_source(source)
if session_key_preview in self._pending_approvals:
user_text = event.text.strip().lower()
if user_text in ("yes", "y", "approve", "ok", "go", "do it"):
@ -1854,7 +1875,7 @@ class GatewayRunner:
source = event.source
# Get existing session key
session_key = self.session_store._generate_session_key(source)
session_key = self._session_key_for_source(source)
# Flush memories in the background (fire-and-forget) so the user
# gets the "Session reset!" response immediately.
@ -3084,7 +3105,7 @@ class GatewayRunner:
return "Session database not available."
source = event.source
session_key = build_session_key(source)
session_key = self._session_key_for_source(source)
name = event.get_command_args().strip()
if not name:
@ -3156,7 +3177,7 @@ class GatewayRunner:
async def _handle_usage_command(self, event: MessageEvent) -> str:
"""Handle /usage command -- show token usage for the session's last agent run."""
source = event.source
session_key = build_session_key(source)
session_key = self._session_key_for_source(source)
agent = self._running_agents.get(session_key)
if agent and hasattr(agent, "session_total_tokens") and agent.session_api_calls > 0:

View file

@ -315,7 +315,7 @@ class SessionEntry:
)
def build_session_key(source: SessionSource) -> str:
def build_session_key(source: SessionSource, group_sessions_per_user: bool = True) -> str:
"""Build a deterministic session key from a message source.
This is the single source of truth for session key construction.
@ -328,9 +328,11 @@ def build_session_key(source: SessionSource) -> str:
Group/channel rules:
- chat_id identifies the parent group/channel.
- user_id/user_id_alt isolates participants within that parent chat when available.
- user_id/user_id_alt isolates participants within that parent chat when available when
``group_sessions_per_user`` is enabled.
- thread_id differentiates threads within that parent chat.
- Without participant identifiers, messages fall back to one shared session per chat.
- Without participant identifiers, or when isolation is disabled, messages fall back to one
shared session per chat.
- Without identifiers, messages fall back to one session per platform/chat_type.
"""
platform = source.platform.value
@ -350,7 +352,7 @@ def build_session_key(source: SessionSource) -> str:
key_parts.append(source.chat_id)
if source.thread_id:
key_parts.append(source.thread_id)
if participant_id:
if group_sessions_per_user and participant_id:
key_parts.append(str(participant_id))
return ":".join(key_parts)
@ -432,7 +434,10 @@ class SessionStore:
def _generate_session_key(self, source: SessionSource) -> str:
"""Generate a session key from a source."""
return build_session_key(source)
return build_session_key(
source,
group_sessions_per_user=getattr(self.config, "group_sessions_per_user", True),
)
def _is_session_expired(self, entry: SessionEntry) -> bool:
"""Check if a session has expired based on its reset policy.