fix: address PR review round 5 — streaming guard, VC auth, history prefix, auto-TTS control
1. Gate _streaming_api_call to chat_completions mode only — Anthropic and
Codex fall back to _interruptible_api_call. Preserve Anthropic base_url
across all client rebuild paths (interrupt, fallback, 401 refresh).
2. Discord VC synthetic events now use chat_type="channel" instead of
defaulting to "dm" — prevents session bleed into DM context.
Authorization now runs before the transcript is echoed to the text
channel. @everyone/@here mentions in voice transcripts are sanitized.
3. CLI voice prefix ("[Voice input...]") is now API-call-local only —
stripped from returned history so it never persists to session DB or
resumed sessions.
4. /voice off now disables the base adapter's auto-TTS by adding the chat to
its _auto_tts_disabled_chats set — voice input no longer triggers TTS when
voice mode is off.
This commit is contained in:
parent
35748a2fb0
commit
cc0a453476
5 changed files with 59 additions and 22 deletions
21
cli.py
21
cli.py
|
|
@ -4213,20 +4213,20 @@ class HermesCLI:
|
||||||
if text_queue is not None:
|
if text_queue is not None:
|
||||||
text_queue.put(delta)
|
text_queue.put(delta)
|
||||||
|
|
||||||
# When voice mode is active, prepend a brief instruction to the
|
# When voice mode is active, prepend a brief instruction so the
|
||||||
# user message so the model responds concisely. This avoids
|
# model responds concisely. The prefix is API-call-local only —
|
||||||
# modifying the system prompt (which would invalidate the prompt
|
# we strip it from the returned history so it never persists to
|
||||||
# cache). The original message in conversation_history stays clean.
|
# session DB or resumed sessions.
|
||||||
agent_message = message
|
_voice_prefix = ""
|
||||||
if self._voice_mode and isinstance(message, str):
|
if self._voice_mode and isinstance(message, str):
|
||||||
agent_message = (
|
_voice_prefix = (
|
||||||
"[Voice input — respond concisely and conversationally, "
|
"[Voice input — respond concisely and conversationally, "
|
||||||
"2-3 sentences max. No code blocks or markdown.] "
|
"2-3 sentences max. No code blocks or markdown.] "
|
||||||
+ message
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def run_agent():
|
def run_agent():
|
||||||
nonlocal result
|
nonlocal result
|
||||||
|
agent_message = _voice_prefix + message if _voice_prefix else message
|
||||||
result = self.agent.run_conversation(
|
result = self.agent.run_conversation(
|
||||||
user_message=agent_message,
|
user_message=agent_message,
|
||||||
conversation_history=self.conversation_history[:-1], # Exclude the message we just added
|
conversation_history=self.conversation_history[:-1], # Exclude the message we just added
|
||||||
|
|
@ -4298,6 +4298,13 @@ class HermesCLI:
|
||||||
# Update history with full conversation
|
# Update history with full conversation
|
||||||
self.conversation_history = result.get("messages", self.conversation_history) if result else self.conversation_history
|
self.conversation_history = result.get("messages", self.conversation_history) if result else self.conversation_history
|
||||||
|
|
||||||
|
# Strip voice prefix from history so it never persists
|
||||||
|
if _voice_prefix and self.conversation_history:
|
||||||
|
for msg in self.conversation_history:
|
||||||
|
if msg.get("role") == "user" and isinstance(msg.get("content"), str):
|
||||||
|
if msg["content"].startswith(_voice_prefix):
|
||||||
|
msg["content"] = msg["content"][len(_voice_prefix):]
|
||||||
|
|
||||||
# Get the final response
|
# Get the final response
|
||||||
response = result.get("final_response", "") if result else ""
|
response = result.get("final_response", "") if result else ""
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -351,6 +351,8 @@ class BasePlatformAdapter(ABC):
|
||||||
# Key: session_key (e.g., chat_id), Value: (event, asyncio.Event for interrupt)
|
# Key: session_key (e.g., chat_id), Value: (event, asyncio.Event for interrupt)
|
||||||
self._active_sessions: Dict[str, asyncio.Event] = {}
|
self._active_sessions: Dict[str, asyncio.Event] = {}
|
||||||
self._pending_messages: Dict[str, MessageEvent] = {}
|
self._pending_messages: Dict[str, MessageEvent] = {}
|
||||||
|
# Chats where auto-TTS on voice input is disabled (set by /voice off)
|
||||||
|
self._auto_tts_disabled_chats: set = set()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
|
|
@ -733,8 +735,12 @@ class BasePlatformAdapter(ABC):
|
||||||
logger.info("[%s] extract_images found %d image(s) in response (%d chars)", self.name, len(images), len(response))
|
logger.info("[%s] extract_images found %d image(s) in response (%d chars)", self.name, len(images), len(response))
|
||||||
|
|
||||||
# Auto-TTS: if voice message, generate audio FIRST (before sending text)
|
# Auto-TTS: if voice message, generate audio FIRST (before sending text)
|
||||||
|
# Skipped when the chat has voice mode disabled (/voice off)
|
||||||
_tts_path = None
|
_tts_path = None
|
||||||
if event.message_type == MessageType.VOICE and text_content and not media_files:
|
if (event.message_type == MessageType.VOICE
|
||||||
|
and text_content
|
||||||
|
and not media_files
|
||||||
|
and event.source.chat_id not in self._auto_tts_disabled_chats):
|
||||||
try:
|
try:
|
||||||
from tools.tts_tool import text_to_speech_tool, check_tts_requirements
|
from tools.tts_tool import text_to_speech_tool, check_tts_requirements
|
||||||
if check_tts_requirements():
|
if check_tts_requirements():
|
||||||
|
|
|
||||||
|
|
@ -2119,9 +2119,13 @@ class GatewayRunner:
|
||||||
args = event.get_command_args().strip().lower()
|
args = event.get_command_args().strip().lower()
|
||||||
chat_id = event.source.chat_id
|
chat_id = event.source.chat_id
|
||||||
|
|
||||||
|
adapter = self.adapters.get(event.source.platform)
|
||||||
|
|
||||||
if args in ("on", "enable"):
|
if args in ("on", "enable"):
|
||||||
self._voice_mode[chat_id] = "voice_only"
|
self._voice_mode[chat_id] = "voice_only"
|
||||||
self._save_voice_modes()
|
self._save_voice_modes()
|
||||||
|
if adapter:
|
||||||
|
adapter._auto_tts_disabled_chats.discard(chat_id)
|
||||||
return (
|
return (
|
||||||
"Voice mode enabled.\n"
|
"Voice mode enabled.\n"
|
||||||
"I'll reply with voice when you send voice messages.\n"
|
"I'll reply with voice when you send voice messages.\n"
|
||||||
|
|
@ -2130,10 +2134,14 @@ class GatewayRunner:
|
||||||
elif args in ("off", "disable"):
|
elif args in ("off", "disable"):
|
||||||
self._voice_mode.pop(chat_id, None)
|
self._voice_mode.pop(chat_id, None)
|
||||||
self._save_voice_modes()
|
self._save_voice_modes()
|
||||||
|
if adapter:
|
||||||
|
adapter._auto_tts_disabled_chats.add(chat_id)
|
||||||
return "Voice mode disabled. Text-only replies."
|
return "Voice mode disabled. Text-only replies."
|
||||||
elif args == "tts":
|
elif args == "tts":
|
||||||
self._voice_mode[chat_id] = "all"
|
self._voice_mode[chat_id] = "all"
|
||||||
self._save_voice_modes()
|
self._save_voice_modes()
|
||||||
|
if adapter:
|
||||||
|
adapter._auto_tts_disabled_chats.discard(chat_id)
|
||||||
return (
|
return (
|
||||||
"Auto-TTS enabled.\n"
|
"Auto-TTS enabled.\n"
|
||||||
"All replies will include a voice message."
|
"All replies will include a voice message."
|
||||||
|
|
@ -2171,10 +2179,14 @@ class GatewayRunner:
|
||||||
if current == "off":
|
if current == "off":
|
||||||
self._voice_mode[chat_id] = "voice_only"
|
self._voice_mode[chat_id] = "voice_only"
|
||||||
self._save_voice_modes()
|
self._save_voice_modes()
|
||||||
|
if adapter:
|
||||||
|
adapter._auto_tts_disabled_chats.discard(chat_id)
|
||||||
return "Voice mode enabled."
|
return "Voice mode enabled."
|
||||||
else:
|
else:
|
||||||
self._voice_mode.pop(chat_id, None)
|
self._voice_mode.pop(chat_id, None)
|
||||||
self._save_voice_modes()
|
self._save_voice_modes()
|
||||||
|
if adapter:
|
||||||
|
adapter._auto_tts_disabled_chats.add(chat_id)
|
||||||
return "Voice mode disabled."
|
return "Voice mode disabled."
|
||||||
|
|
||||||
async def _handle_voice_channel_join(self, event: MessageEvent) -> str:
|
async def _handle_voice_channel_join(self, event: MessageEvent) -> str:
|
||||||
|
|
@ -2211,6 +2223,7 @@ class GatewayRunner:
|
||||||
adapter._voice_text_channels[guild_id] = int(event.source.chat_id)
|
adapter._voice_text_channels[guild_id] = int(event.source.chat_id)
|
||||||
self._voice_mode[event.source.chat_id] = "all"
|
self._voice_mode[event.source.chat_id] = "all"
|
||||||
self._save_voice_modes()
|
self._save_voice_modes()
|
||||||
|
adapter._auto_tts_disabled_chats.discard(event.source.chat_id)
|
||||||
return (
|
return (
|
||||||
f"Joined voice channel **{voice_channel.name}**.\n"
|
f"Joined voice channel **{voice_channel.name}**.\n"
|
||||||
f"I'll speak my replies and listen to you. Use /voice leave to disconnect."
|
f"I'll speak my replies and listen to you. Use /voice leave to disconnect."
|
||||||
|
|
@ -2265,21 +2278,28 @@ class GatewayRunner:
|
||||||
if not text_ch_id:
|
if not text_ch_id:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Show transcript in text channel
|
# Check authorization before processing voice input
|
||||||
try:
|
|
||||||
channel = adapter._client.get_channel(text_ch_id)
|
|
||||||
if channel:
|
|
||||||
await channel.send(f"**[Voice]** <@{user_id}>: {transcript}")
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Build a synthetic MessageEvent and feed through the normal pipeline
|
|
||||||
source = SessionSource(
|
source = SessionSource(
|
||||||
platform=Platform.DISCORD,
|
platform=Platform.DISCORD,
|
||||||
chat_id=str(text_ch_id),
|
chat_id=str(text_ch_id),
|
||||||
user_id=str(user_id),
|
user_id=str(user_id),
|
||||||
user_name=str(user_id),
|
user_name=str(user_id),
|
||||||
|
chat_type="channel",
|
||||||
)
|
)
|
||||||
|
if not self._is_user_authorized(source):
|
||||||
|
logger.debug("Unauthorized voice input from user %d, ignoring", user_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Show transcript in text channel (after auth, with mention sanitization)
|
||||||
|
try:
|
||||||
|
channel = adapter._client.get_channel(text_ch_id)
|
||||||
|
if channel:
|
||||||
|
safe_text = transcript[:2000].replace("@everyone", "@\u200beveryone").replace("@here", "@\u200bhere")
|
||||||
|
await channel.send(f"**[Voice]** <@{user_id}>: {safe_text}")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Build a synthetic MessageEvent and feed through the normal pipeline
|
||||||
# Use SimpleNamespace as raw_message so _get_guild_id() can extract
|
# Use SimpleNamespace as raw_message so _get_guild_id() can extract
|
||||||
# guild_id and _send_voice_reply() plays audio in the voice channel.
|
# guild_id and _send_voice_reply() plays audio in the voice channel.
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
|
|
||||||
12
run_agent.py
12
run_agent.py
|
|
@ -508,6 +508,7 @@ class AIAgent:
|
||||||
from agent.anthropic_adapter import build_anthropic_client, resolve_anthropic_token
|
from agent.anthropic_adapter import build_anthropic_client, resolve_anthropic_token
|
||||||
effective_key = api_key or resolve_anthropic_token() or ""
|
effective_key = api_key or resolve_anthropic_token() or ""
|
||||||
self._anthropic_api_key = effective_key
|
self._anthropic_api_key = effective_key
|
||||||
|
self._anthropic_base_url = base_url
|
||||||
self._anthropic_client = build_anthropic_client(effective_key, base_url)
|
self._anthropic_client = build_anthropic_client(effective_key, base_url)
|
||||||
# No OpenAI client needed for Anthropic mode
|
# No OpenAI client needed for Anthropic mode
|
||||||
self.client = None
|
self.client = None
|
||||||
|
|
@ -2625,7 +2626,7 @@ class AIAgent:
|
||||||
try:
|
try:
|
||||||
if self.api_mode == "anthropic_messages":
|
if self.api_mode == "anthropic_messages":
|
||||||
from agent.anthropic_adapter import build_anthropic_client
|
from agent.anthropic_adapter import build_anthropic_client
|
||||||
self._anthropic_client = build_anthropic_client(self._anthropic_api_key)
|
self._anthropic_client = build_anthropic_client(self._anthropic_api_key, getattr(self, "_anthropic_base_url", None))
|
||||||
else:
|
else:
|
||||||
self.client = OpenAI(**self._client_kwargs)
|
self.client = OpenAI(**self._client_kwargs)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
@ -2757,7 +2758,7 @@ class AIAgent:
|
||||||
try:
|
try:
|
||||||
if self.api_mode == "anthropic_messages":
|
if self.api_mode == "anthropic_messages":
|
||||||
from agent.anthropic_adapter import build_anthropic_client
|
from agent.anthropic_adapter import build_anthropic_client
|
||||||
self._anthropic_client = build_anthropic_client(self._anthropic_api_key)
|
self._anthropic_client = build_anthropic_client(self._anthropic_api_key, getattr(self, "_anthropic_base_url", None))
|
||||||
else:
|
else:
|
||||||
self.client = OpenAI(**self._client_kwargs)
|
self.client = OpenAI(**self._client_kwargs)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
@ -2823,7 +2824,8 @@ class AIAgent:
|
||||||
from agent.anthropic_adapter import build_anthropic_client, resolve_anthropic_token
|
from agent.anthropic_adapter import build_anthropic_client, resolve_anthropic_token
|
||||||
effective_key = fb_client.api_key or resolve_anthropic_token() or ""
|
effective_key = fb_client.api_key or resolve_anthropic_token() or ""
|
||||||
self._anthropic_api_key = effective_key
|
self._anthropic_api_key = effective_key
|
||||||
self._anthropic_client = build_anthropic_client(effective_key)
|
self._anthropic_base_url = getattr(fb_client, "base_url", None)
|
||||||
|
self._anthropic_client = build_anthropic_client(effective_key, self._anthropic_base_url)
|
||||||
self.client = None
|
self.client = None
|
||||||
self._client_kwargs = {}
|
self._client_kwargs = {}
|
||||||
else:
|
else:
|
||||||
|
|
@ -4436,7 +4438,7 @@ class AIAgent:
|
||||||
self._dump_api_request_debug(api_kwargs, reason="preflight")
|
self._dump_api_request_debug(api_kwargs, reason="preflight")
|
||||||
|
|
||||||
cb = getattr(self, "_stream_callback", None)
|
cb = getattr(self, "_stream_callback", None)
|
||||||
if cb is not None:
|
if cb is not None and self.api_mode == "chat_completions":
|
||||||
response = self._streaming_api_call(api_kwargs, cb)
|
response = self._streaming_api_call(api_kwargs, cb)
|
||||||
else:
|
else:
|
||||||
response = self._interruptible_api_call(api_kwargs)
|
response = self._interruptible_api_call(api_kwargs)
|
||||||
|
|
@ -4770,7 +4772,7 @@ class AIAgent:
|
||||||
new_token = resolve_anthropic_token()
|
new_token = resolve_anthropic_token()
|
||||||
if new_token and new_token != self._anthropic_api_key:
|
if new_token and new_token != self._anthropic_api_key:
|
||||||
self._anthropic_api_key = new_token
|
self._anthropic_api_key = new_token
|
||||||
self._anthropic_client = build_anthropic_client(new_token)
|
self._anthropic_client = build_anthropic_client(new_token, getattr(self, "_anthropic_base_url", None))
|
||||||
print(f"{self.log_prefix}🔐 Anthropic credentials refreshed after 401. Retrying request...")
|
print(f"{self.log_prefix}🔐 Anthropic credentials refreshed after 401. Retrying request...")
|
||||||
continue
|
continue
|
||||||
# Credential refresh didn't help — show diagnostic info
|
# Credential refresh didn't help — show diagnostic info
|
||||||
|
|
|
||||||
|
|
@ -38,6 +38,7 @@ def _make_runner(tmp_path):
|
||||||
runner._VOICE_MODE_PATH = tmp_path / "gateway_voice_mode.json"
|
runner._VOICE_MODE_PATH = tmp_path / "gateway_voice_mode.json"
|
||||||
runner._session_db = None
|
runner._session_db = None
|
||||||
runner.session_store = MagicMock()
|
runner.session_store = MagicMock()
|
||||||
|
runner._is_user_authorized = lambda source: True
|
||||||
return runner
|
return runner
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -731,6 +732,7 @@ class TestVoiceChannelCommands:
|
||||||
assert event.text == "Hello from VC"
|
assert event.text == "Hello from VC"
|
||||||
assert event.message_type == MessageType.VOICE
|
assert event.message_type == MessageType.VOICE
|
||||||
assert event.source.chat_id == "123"
|
assert event.source.chat_id == "123"
|
||||||
|
assert event.source.chat_type == "channel"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_input_posts_transcript_in_text_channel(self, runner):
|
async def test_input_posts_transcript_in_text_channel(self, runner):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue