fix(gateway): honor stt.enabled false for voice transcription
- bridge stt.enabled from config.yaml into gateway runtime config
- preserve the flag in GatewayConfig serialization
- skip gateway voice transcription when STT is disabled
- add regression tests for config loading and disabled transcription flow
This commit is contained in:
parent
2119b68799
commit
c36136084a
3 changed files with 86 additions and 1 deletions
|
|
@ -21,6 +21,17 @@ from hermes_cli.config import get_hermes_home
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_bool(value: Any, default: bool = True) -> bool:
|
||||||
|
"""Coerce bool-ish config values, preserving a caller-provided default."""
|
||||||
|
if value is None:
|
||||||
|
return default
|
||||||
|
if isinstance(value, bool):
|
||||||
|
return value
|
||||||
|
if isinstance(value, str):
|
||||||
|
return value.strip().lower() in ("true", "1", "yes", "on")
|
||||||
|
return bool(value)
|
||||||
|
|
||||||
|
|
||||||
class Platform(Enum):
|
class Platform(Enum):
|
||||||
"""Supported messaging platforms."""
|
"""Supported messaging platforms."""
|
||||||
LOCAL = "local"
|
LOCAL = "local"
|
||||||
|
|
@ -160,6 +171,9 @@ class GatewayConfig:
|
||||||
|
|
||||||
# Delivery settings
|
# Delivery settings
|
||||||
always_log_local: bool = True # Always save cron outputs to local files
|
always_log_local: bool = True # Always save cron outputs to local files
|
||||||
|
|
||||||
|
# STT settings
|
||||||
|
stt_enabled: bool = True # Whether to auto-transcribe inbound voice messages
|
||||||
|
|
||||||
def get_connected_platforms(self) -> List[Platform]:
|
def get_connected_platforms(self) -> List[Platform]:
|
||||||
"""Return list of platforms that are enabled and configured."""
|
"""Return list of platforms that are enabled and configured."""
|
||||||
|
|
@ -224,6 +238,7 @@ class GatewayConfig:
|
||||||
"quick_commands": self.quick_commands,
|
"quick_commands": self.quick_commands,
|
||||||
"sessions_dir": str(self.sessions_dir),
|
"sessions_dir": str(self.sessions_dir),
|
||||||
"always_log_local": self.always_log_local,
|
"always_log_local": self.always_log_local,
|
||||||
|
"stt_enabled": self.stt_enabled,
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -260,6 +275,10 @@ class GatewayConfig:
|
||||||
if not isinstance(quick_commands, dict):
|
if not isinstance(quick_commands, dict):
|
||||||
quick_commands = {}
|
quick_commands = {}
|
||||||
|
|
||||||
|
stt_enabled = data.get("stt_enabled")
|
||||||
|
if stt_enabled is None:
|
||||||
|
stt_enabled = data.get("stt", {}).get("enabled") if isinstance(data.get("stt"), dict) else None
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
platforms=platforms,
|
platforms=platforms,
|
||||||
default_reset_policy=default_policy,
|
default_reset_policy=default_policy,
|
||||||
|
|
@ -269,6 +288,7 @@ class GatewayConfig:
|
||||||
quick_commands=quick_commands,
|
quick_commands=quick_commands,
|
||||||
sessions_dir=sessions_dir,
|
sessions_dir=sessions_dir,
|
||||||
always_log_local=data.get("always_log_local", True),
|
always_log_local=data.get("always_log_local", True),
|
||||||
|
stt_enabled=_coerce_bool(stt_enabled, True),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -318,6 +338,12 @@ def load_gateway_config() -> GatewayConfig:
|
||||||
else:
|
else:
|
||||||
logger.warning("Ignoring invalid quick_commands in config.yaml (expected mapping, got %s)", type(qc).__name__)
|
logger.warning("Ignoring invalid quick_commands in config.yaml (expected mapping, got %s)", type(qc).__name__)
|
||||||
|
|
||||||
|
# Bridge STT enable/disable from config.yaml into gateway runtime.
|
||||||
|
# This keeps the gateway aligned with the user-facing config source.
|
||||||
|
stt_cfg = yaml_cfg.get("stt")
|
||||||
|
if isinstance(stt_cfg, dict) and "enabled" in stt_cfg:
|
||||||
|
config.stt_enabled = _coerce_bool(stt_cfg.get("enabled"), True)
|
||||||
|
|
||||||
# Bridge discord settings from config.yaml to env vars
|
# Bridge discord settings from config.yaml to env vars
|
||||||
# (env vars take precedence — only set if not already defined)
|
# (env vars take precedence — only set if not already defined)
|
||||||
discord_cfg = yaml_cfg.get("discord", {})
|
discord_cfg = yaml_cfg.get("discord", {})
|
||||||
|
|
|
||||||
|
|
@ -3512,7 +3512,7 @@ class GatewayRunner:
|
||||||
audio_paths: List[str],
|
audio_paths: List[str],
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Auto-transcribe user voice/audio messages using OpenAI Whisper API
|
Auto-transcribe user voice/audio messages using the configured STT provider
|
||||||
and prepend the transcript to the message text.
|
and prepend the transcript to the message text.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -3522,6 +3522,12 @@ class GatewayRunner:
|
||||||
Returns:
|
Returns:
|
||||||
The enriched message string with transcriptions prepended.
|
The enriched message string with transcriptions prepended.
|
||||||
"""
|
"""
|
||||||
|
if not getattr(self.config, "stt_enabled", True):
|
||||||
|
disabled_note = "[The user sent voice message(s), but transcription is disabled in config.]"
|
||||||
|
if user_text:
|
||||||
|
return f"{disabled_note}\n\n{user_text}"
|
||||||
|
return disabled_note
|
||||||
|
|
||||||
from tools.transcription_tools import transcribe_audio, get_stt_model_from_config
|
from tools.transcription_tools import transcribe_audio, get_stt_model_from_config
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
|
|
|
||||||
53
tests/gateway/test_stt_config.py
Normal file
53
tests/gateway/test_stt_config.py
Normal file
|
|
@ -0,0 +1,53 @@
|
||||||
|
"""Gateway STT config tests — honor stt.enabled: false from config.yaml."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from gateway.config import GatewayConfig, load_gateway_config
|
||||||
|
|
||||||
|
|
||||||
|
def test_gateway_config_stt_disabled_from_dict_nested():
    """A nested ``{"stt": {"enabled": False}}`` mapping yields ``stt_enabled`` False."""
    raw = {"stt": {"enabled": False}}
    parsed = GatewayConfig.from_dict(raw)
    assert parsed.stt_enabled is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_gateway_config_bridges_stt_enabled_from_config_yaml(tmp_path, monkeypatch):
    """``load_gateway_config`` reflects ``stt.enabled: false`` written to config.yaml."""
    home_dir = tmp_path / ".hermes"
    home_dir.mkdir()

    # Minimal config.yaml whose only content disables STT.
    yaml_text = yaml.dump({"stt": {"enabled": False}})
    config_file = home_dir / "config.yaml"
    config_file.write_text(yaml_text, encoding="utf-8")

    # Redirect both the HERMES_HOME env var and Path.home() into the temp tree.
    monkeypatch.setenv("HERMES_HOME", str(home_dir))
    monkeypatch.setattr(Path, "home", lambda: tmp_path)

    loaded = load_gateway_config()

    assert loaded.stt_enabled is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_enrich_message_with_transcription_skips_when_stt_disabled():
    """With ``stt_enabled`` False, enrichment returns a notice and never transcribes."""
    from gateway.run import GatewayRunner

    # Build the runner without running __init__; only ``config`` is assigned.
    gateway = GatewayRunner.__new__(GatewayRunner)
    gateway.config = GatewayConfig(stt_enabled=False)

    # Any call into the transcription function raises and fails the test.
    transcribe_guard = patch(
        "tools.transcription_tools.transcribe_audio",
        side_effect=AssertionError("transcribe_audio should not be called when STT is disabled"),
    )
    model_stub = patch(
        "tools.transcription_tools.get_stt_model_from_config",
        return_value=None,
    )
    with transcribe_guard, model_stub:
        enriched = await gateway._enrich_message_with_transcription(
            "caption",
            ["/tmp/voice.ogg"],
        )

    assert "transcription is disabled" in enriched.lower()
    assert "caption" in enriched
|
||||||
Loading…
Add table
Add a link
Reference in a new issue