diff --git a/gateway/config.py b/gateway/config.py index 47c739e9..2b187c52 100644 --- a/gateway/config.py +++ b/gateway/config.py @@ -21,6 +21,17 @@ from hermes_cli.config import get_hermes_home logger = logging.getLogger(__name__) +def _coerce_bool(value: Any, default: bool = True) -> bool: + """Coerce bool-ish config values, preserving a caller-provided default.""" + if value is None: + return default + if isinstance(value, bool): + return value + if isinstance(value, str): + return value.strip().lower() in ("true", "1", "yes", "on") + return bool(value) + + class Platform(Enum): """Supported messaging platforms.""" LOCAL = "local" @@ -160,6 +171,9 @@ class GatewayConfig: # Delivery settings always_log_local: bool = True # Always save cron outputs to local files + + # STT settings + stt_enabled: bool = True # Whether to auto-transcribe inbound voice messages def get_connected_platforms(self) -> List[Platform]: """Return list of platforms that are enabled and configured.""" @@ -224,6 +238,7 @@ class GatewayConfig: "quick_commands": self.quick_commands, "sessions_dir": str(self.sessions_dir), "always_log_local": self.always_log_local, + "stt_enabled": self.stt_enabled, } @classmethod @@ -260,6 +275,10 @@ class GatewayConfig: if not isinstance(quick_commands, dict): quick_commands = {} + stt_enabled = data.get("stt_enabled") + if stt_enabled is None: + stt_enabled = data.get("stt", {}).get("enabled") if isinstance(data.get("stt"), dict) else None + return cls( platforms=platforms, default_reset_policy=default_policy, @@ -269,6 +288,7 @@ class GatewayConfig: quick_commands=quick_commands, sessions_dir=sessions_dir, always_log_local=data.get("always_log_local", True), + stt_enabled=_coerce_bool(stt_enabled, True), ) @@ -318,6 +338,12 @@ def load_gateway_config() -> GatewayConfig: else: logger.warning("Ignoring invalid quick_commands in config.yaml (expected mapping, got %s)", type(qc).__name__) + # Bridge STT enable/disable from config.yaml into gateway runtime. + # This keeps the gateway aligned with the user-facing config source. + stt_cfg = yaml_cfg.get("stt") + if isinstance(stt_cfg, dict) and "enabled" in stt_cfg: + config.stt_enabled = _coerce_bool(stt_cfg.get("enabled"), True) + # Bridge discord settings from config.yaml to env vars # (env vars take precedence — only set if not already defined) discord_cfg = yaml_cfg.get("discord", {}) diff --git a/gateway/run.py b/gateway/run.py index e973852b..f955573c 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -3512,7 +3512,7 @@ class GatewayRunner: audio_paths: List[str], ) -> str: """ - Auto-transcribe user voice/audio messages using OpenAI Whisper API + Auto-transcribe user voice/audio messages using the configured STT provider and prepend the transcript to the message text. Args: @@ -3522,6 +3522,12 @@ class GatewayRunner: Returns: The enriched message string with transcriptions prepended. """ + if not getattr(self.config, "stt_enabled", True): + disabled_note = "[The user sent voice message(s), but transcription is disabled in config.]" + if user_text: + return f"{disabled_note}\n\n{user_text}" + return disabled_note + from tools.transcription_tools import transcribe_audio, get_stt_model_from_config import asyncio diff --git a/tests/gateway/test_stt_config.py b/tests/gateway/test_stt_config.py new file mode 100644 index 00000000..d5a9fc55 --- /dev/null +++ b/tests/gateway/test_stt_config.py @@ -0,0 +1,53 @@ +"""Gateway STT config tests — honor stt.enabled: false from config.yaml.""" + +from pathlib import Path +from unittest.mock import AsyncMock, patch + +import pytest +import yaml + +from gateway.config import GatewayConfig, load_gateway_config + + +def test_gateway_config_stt_disabled_from_dict_nested(): + config = GatewayConfig.from_dict({"stt": {"enabled": False}}) + assert config.stt_enabled is False + + +def test_load_gateway_config_bridges_stt_enabled_from_config_yaml(tmp_path, monkeypatch): + hermes_home = tmp_path / ".hermes" + hermes_home.mkdir() + (hermes_home / "config.yaml").write_text( + yaml.dump({"stt": {"enabled": False}}), + encoding="utf-8", + ) + + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + monkeypatch.setattr(Path, "home", lambda: tmp_path) + + config = load_gateway_config() + + assert config.stt_enabled is False + + +@pytest.mark.asyncio +async def test_enrich_message_with_transcription_skips_when_stt_disabled(): + from gateway.run import GatewayRunner + + runner = GatewayRunner.__new__(GatewayRunner) + runner.config = GatewayConfig(stt_enabled=False) + + with patch( + "tools.transcription_tools.transcribe_audio", + side_effect=AssertionError("transcribe_audio should not be called when STT is disabled"), + ), patch( + "tools.transcription_tools.get_stt_model_from_config", + return_value=None, + ): + result = await runner._enrich_message_with_transcription( + "caption", + ["/tmp/voice.ogg"], + ) + + assert "transcription is disabled" in result.lower() + assert "caption" in result