fix(gateway): honor stt.enabled false for voice transcription
- bridge stt.enabled from config.yaml into gateway runtime config - preserve the flag in GatewayConfig serialization - skip gateway voice transcription when STT is disabled - add regression tests for config loading and disabled transcription flow
This commit is contained in:
parent
2119b68799
commit
c36136084a
3 changed files with 86 additions and 1 deletions
|
|
@ -21,6 +21,17 @@ from hermes_cli.config import get_hermes_home
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _coerce_bool(value: Any, default: bool = True) -> bool:
|
||||
"""Coerce bool-ish config values, preserving a caller-provided default."""
|
||||
if value is None:
|
||||
return default
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
return value.strip().lower() in ("true", "1", "yes", "on")
|
||||
return bool(value)
|
||||
|
||||
|
||||
class Platform(Enum):
|
||||
"""Supported messaging platforms."""
|
||||
LOCAL = "local"
|
||||
|
|
@ -160,6 +171,9 @@ class GatewayConfig:
|
|||
|
||||
# Delivery settings
|
||||
always_log_local: bool = True # Always save cron outputs to local files
|
||||
|
||||
# STT settings
|
||||
stt_enabled: bool = True # Whether to auto-transcribe inbound voice messages
|
||||
|
||||
def get_connected_platforms(self) -> List[Platform]:
|
||||
"""Return list of platforms that are enabled and configured."""
|
||||
|
|
@ -224,6 +238,7 @@ class GatewayConfig:
|
|||
"quick_commands": self.quick_commands,
|
||||
"sessions_dir": str(self.sessions_dir),
|
||||
"always_log_local": self.always_log_local,
|
||||
"stt_enabled": self.stt_enabled,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
|
@ -260,6 +275,10 @@ class GatewayConfig:
|
|||
if not isinstance(quick_commands, dict):
|
||||
quick_commands = {}
|
||||
|
||||
stt_enabled = data.get("stt_enabled")
|
||||
if stt_enabled is None:
|
||||
stt_enabled = data.get("stt", {}).get("enabled") if isinstance(data.get("stt"), dict) else None
|
||||
|
||||
return cls(
|
||||
platforms=platforms,
|
||||
default_reset_policy=default_policy,
|
||||
|
|
@ -269,6 +288,7 @@ class GatewayConfig:
|
|||
quick_commands=quick_commands,
|
||||
sessions_dir=sessions_dir,
|
||||
always_log_local=data.get("always_log_local", True),
|
||||
stt_enabled=_coerce_bool(stt_enabled, True),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -318,6 +338,12 @@ def load_gateway_config() -> GatewayConfig:
|
|||
else:
|
||||
logger.warning("Ignoring invalid quick_commands in config.yaml (expected mapping, got %s)", type(qc).__name__)
|
||||
|
||||
# Bridge STT enable/disable from config.yaml into gateway runtime.
|
||||
# This keeps the gateway aligned with the user-facing config source.
|
||||
stt_cfg = yaml_cfg.get("stt")
|
||||
if isinstance(stt_cfg, dict) and "enabled" in stt_cfg:
|
||||
config.stt_enabled = _coerce_bool(stt_cfg.get("enabled"), True)
|
||||
|
||||
# Bridge discord settings from config.yaml to env vars
|
||||
# (env vars take precedence — only set if not already defined)
|
||||
discord_cfg = yaml_cfg.get("discord", {})
|
||||
|
|
|
|||
|
|
@ -3512,7 +3512,7 @@ class GatewayRunner:
|
|||
audio_paths: List[str],
|
||||
) -> str:
|
||||
"""
|
||||
Auto-transcribe user voice/audio messages using OpenAI Whisper API
|
||||
Auto-transcribe user voice/audio messages using the configured STT provider
|
||||
and prepend the transcript to the message text.
|
||||
|
||||
Args:
|
||||
|
|
@ -3522,6 +3522,12 @@ class GatewayRunner:
|
|||
Returns:
|
||||
The enriched message string with transcriptions prepended.
|
||||
"""
|
||||
if not getattr(self.config, "stt_enabled", True):
|
||||
disabled_note = "[The user sent voice message(s), but transcription is disabled in config.]"
|
||||
if user_text:
|
||||
return f"{disabled_note}\n\n{user_text}"
|
||||
return disabled_note
|
||||
|
||||
from tools.transcription_tools import transcribe_audio, get_stt_model_from_config
|
||||
import asyncio
|
||||
|
||||
|
|
|
|||
53
tests/gateway/test_stt_config.py
Normal file
53
tests/gateway/test_stt_config.py
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
"""Gateway STT config tests — honor stt.enabled: false from config.yaml."""
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from gateway.config import GatewayConfig, load_gateway_config
|
||||
|
||||
|
||||
def test_gateway_config_stt_disabled_from_dict_nested():
|
||||
config = GatewayConfig.from_dict({"stt": {"enabled": False}})
|
||||
assert config.stt_enabled is False
|
||||
|
||||
|
||||
def test_load_gateway_config_bridges_stt_enabled_from_config_yaml(tmp_path, monkeypatch):
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
(hermes_home / "config.yaml").write_text(
|
||||
yaml.dump({"stt": {"enabled": False}}),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
|
||||
config = load_gateway_config()
|
||||
|
||||
assert config.stt_enabled is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enrich_message_with_transcription_skips_when_stt_disabled():
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = GatewayRunner.__new__(GatewayRunner)
|
||||
runner.config = GatewayConfig(stt_enabled=False)
|
||||
|
||||
with patch(
|
||||
"tools.transcription_tools.transcribe_audio",
|
||||
side_effect=AssertionError("transcribe_audio should not be called when STT is disabled"),
|
||||
), patch(
|
||||
"tools.transcription_tools.get_stt_model_from_config",
|
||||
return_value=None,
|
||||
):
|
||||
result = await runner._enrich_message_with_transcription(
|
||||
"caption",
|
||||
["/tmp/voice.ogg"],
|
||||
)
|
||||
|
||||
assert "transcription is disabled" in result.lower()
|
||||
assert "caption" in result
|
||||
Loading…
Add table
Add a link
Reference in a new issue