Merge pull request #1538 from NousResearch/hermes/hermes-a098c323
feat: unified streaming infrastructure — real-time token delivery for CLI + gateway
This commit is contained in:
commit
6c84e26e70
11 changed files with 1413 additions and 155 deletions
|
|
@ -355,6 +355,19 @@ session_reset:
|
|||
# explicitly want one shared "room brain" per group/channel.
|
||||
group_sessions_per_user: true
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# Gateway Streaming
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# Stream tokens to messaging platforms in real-time. The bot sends a message
|
||||
# on first token, then progressively edits it as more tokens arrive.
|
||||
# Disabled by default — enable to try the streaming UX on Telegram/Discord/Slack.
|
||||
streaming:
|
||||
enabled: false
|
||||
# transport: edit # "edit" = progressive editMessageText
|
||||
# edit_interval: 0.3 # seconds between message edits
|
||||
# buffer_threshold: 40 # chars before forcing an edit flush
|
||||
# cursor: " ▉" # cursor shown during streaming
|
||||
|
||||
# =============================================================================
|
||||
# Skills Configuration
|
||||
# =============================================================================
|
||||
|
|
@ -716,6 +729,12 @@ display:
|
|||
# Toggle at runtime with /reasoning show or /reasoning hide.
|
||||
show_reasoning: false
|
||||
|
||||
# Stream tokens to the terminal as they arrive instead of waiting for the
|
||||
# full response. The response box opens on first token and text appears
|
||||
# line-by-line. Tool calls are still captured silently.
|
||||
# Disabled by default — enable to try the streaming UX.
|
||||
streaming: false
|
||||
|
||||
# ───────────────────────────────────────────────────────────────────────────
|
||||
# Skin / Theme
|
||||
# ───────────────────────────────────────────────────────────────────────────
|
||||
|
|
|
|||
203
cli.py
203
cli.py
|
|
@ -210,6 +210,7 @@ def load_cli_config() -> Dict[str, Any]:
|
|||
"compact": False,
|
||||
"resume_display": "full",
|
||||
"show_reasoning": False,
|
||||
"streaming": False,
|
||||
"show_cost": False,
|
||||
"skin": "default",
|
||||
},
|
||||
|
|
@ -1034,6 +1035,14 @@ class HermesCLI:
|
|||
self.show_cost = CLI_CONFIG["display"].get("show_cost", False)
|
||||
self.verbose = verbose if verbose is not None else (self.tool_progress_mode == "verbose")
|
||||
|
||||
# streaming: stream tokens to the terminal as they arrive (display.streaming in config.yaml)
|
||||
self.streaming_enabled = CLI_CONFIG["display"].get("streaming", False)
|
||||
|
||||
# Streaming display state
|
||||
self._stream_buf = "" # Partial line buffer for line-buffered rendering
|
||||
self._stream_started = False # True once first delta arrives
|
||||
self._stream_box_opened = False # True once the response box header is printed
|
||||
|
||||
# Configuration - priority: CLI args > env vars > config file
|
||||
# Model comes from: CLI arg or config.yaml (single source of truth).
|
||||
# LLM_MODEL/OPENAI_MODEL env vars are NOT checked — config.yaml is
|
||||
|
|
@ -1454,6 +1463,177 @@ class HermesCLI:
|
|||
self._spinner_text = text or ""
|
||||
self._invalidate()
|
||||
|
||||
# ── Streaming display ────────────────────────────────────────────────
|
||||
|
||||
def _stream_reasoning_delta(self, text: str) -> None:
    """Render reasoning/thinking tokens inside a dim live box.

    The dim box header is printed on the first reasoning token; after
    that, buffered text is flushed one complete line at a time. The box
    is closed automatically once normal content tokens start arriving
    (via _stream_delta → _emit_stream_text → _close_reasoning_box).
    """
    if not text:
        return

    # First reasoning token: draw the dim box header exactly once.
    if not getattr(self, "_reasoning_box_opened", False):
        self._reasoning_box_opened = True
        cols = shutil.get_terminal_size().columns
        header = " Reasoning "
        pad = cols - 2 - len(header)
        _cprint(f"\n{_DIM}┌─{header}{'─' * max(pad - 1, 0)}┐{_RST}")

    # Accumulate, then flush every complete line in dim styling; the
    # trailing partial line stays buffered until its newline arrives.
    buffered = getattr(self, "_reasoning_buf", "") + text
    while "\n" in buffered:
        complete, _, buffered = buffered.partition("\n")
        _cprint(f"{_DIM}{complete}{_RST}")
    self._reasoning_buf = buffered
|
||||
|
||||
def _close_reasoning_box(self) -> None:
    """Flush any pending reasoning text and draw the box footer.

    No-op when the reasoning box was never opened.
    """
    if not getattr(self, "_reasoning_box_opened", False):
        return

    # Emit whatever partial line is still buffered before closing.
    pending = getattr(self, "_reasoning_buf", "")
    if pending:
        _cprint(f"{_DIM}{pending}{_RST}")
    self._reasoning_buf = ""

    cols = shutil.get_terminal_size().columns
    _cprint(f"{_DIM}└{'─' * (cols - 2)}┘{_RST}")
    self._reasoning_box_opened = False
|
||||
|
||||
def _stream_delta(self, text: str) -> None:
    """Line-buffered streaming callback for real-time token rendering.

    Receives text deltas from the agent as tokens arrive. Buffers
    partial lines and emits complete lines via _cprint to work
    reliably with prompt_toolkit's patch_stdout.

    Reasoning/thinking blocks (<REASONING_SCRATCHPAD>, <think>, etc.)
    are suppressed during streaming since they'd display raw XML tags.
    The agent strips them from the final response anyway.

    State used (reset by _reset_stream_state):
      _stream_prefilt     — pre-filter buffer of text not yet classified
      _in_reasoning_block — True while between an open and a close tag
      _stream_started     — set as soon as any delta arrives
    """
    if not text:
        return

    # Mark that streaming produced output this turn; later display code
    # checks this flag to avoid re-printing the final response.
    self._stream_started = True

    # ── Tag-based reasoning suppression ──
    # Track whether we're inside a reasoning/thinking block.
    # These tags are model-generated (system prompt tells the model
    # to use them) and get stripped from final_response. We must
    # suppress them during streaming too.
    _OPEN_TAGS = ("<REASONING_SCRATCHPAD>", "<think>", "<reasoning>", "<THINKING>")
    _CLOSE_TAGS = ("</REASONING_SCRATCHPAD>", "</think>", "</reasoning>", "</THINKING>")

    # Append to a pre-filter buffer first. All decisions below operate
    # on the accumulated buffer, never the raw delta alone, because a
    # tag can arrive split across several deltas.
    self._stream_prefilt = getattr(self, "_stream_prefilt", "") + text

    # Check if we're entering a reasoning block
    if not getattr(self, "_in_reasoning_block", False):
        for tag in _OPEN_TAGS:
            idx = self._stream_prefilt.find(tag)
            if idx != -1:
                # Emit everything before the tag
                before = self._stream_prefilt[:idx]
                if before:
                    self._emit_stream_text(before)
                self._in_reasoning_block = True
                self._stream_prefilt = self._stream_prefilt[idx + len(tag):]
                break

    # Could also be a partial open tag at the end — hold it back
    if not getattr(self, "_in_reasoning_block", False):
        # Check for partial tag match at the end: if the buffer ends
        # with a proper prefix of any open tag (e.g. "<REASONING_SC"),
        # emit only the text before it and keep that suffix buffered
        # until the next delta resolves whether it's a real tag.
        safe = self._stream_prefilt
        for tag in _OPEN_TAGS:
            for i in range(1, len(tag)):
                if self._stream_prefilt.endswith(tag[:i]):
                    safe = self._stream_prefilt[:-i]
                    break
        if safe:
            self._emit_stream_text(safe)
            self._stream_prefilt = self._stream_prefilt[len(safe):]
        return

    # Inside a reasoning block — look for close tag.
    # Keep accumulating _stream_prefilt because close tags can arrive
    # split across multiple tokens (e.g. "</REASONING_SCRATCH" + "PAD>...").
    if getattr(self, "_in_reasoning_block", False):
        for tag in _CLOSE_TAGS:
            idx = self._stream_prefilt.find(tag)
            if idx != -1:
                self._in_reasoning_block = False
                after = self._stream_prefilt[idx + len(tag):]
                self._stream_prefilt = ""
                # Process remaining text after close tag through full
                # filtering (it could contain another open tag)
                if after:
                    self._stream_delta(after)
                return
        # Still inside reasoning block — keep only the tail that could
        # be a partial close tag prefix (save memory on long blocks).
        max_tag_len = max(len(t) for t in _CLOSE_TAGS)
        if len(self._stream_prefilt) > max_tag_len:
            self._stream_prefilt = self._stream_prefilt[-max_tag_len:]
        return
|
||||
|
||||
def _emit_stream_text(self, text: str) -> None:
    """Write reasoning-filtered response text to the terminal, line-buffered.

    Closes the live reasoning box (if open), opens the gold response box
    on the first visible character, then prints every complete line and
    keeps the trailing partial line buffered in _stream_buf.
    """
    if not text:
        return

    # A reasoning box must never stay open once content streams.
    self._close_reasoning_box()

    if not self._stream_box_opened:
        # Drop leading newlines so the box doesn't open on blank output.
        text = text.lstrip("\n")
        if not text:
            return
        self._stream_box_opened = True

        label = "⚕ Hermes"
        try:
            from hermes_cli.skin_engine import get_active_skin
            label = get_active_skin().get_branding("response_label", "⚕ Hermes")
        except Exception:
            pass  # Skin engine unavailable — keep the default label.

        cols = shutil.get_terminal_size().columns
        dashes = max(cols - 2 - len(label) - 1, 0)
        _cprint(f"\n{_GOLD}╭─{label}{'─' * dashes}╮{_RST}")

    # Line-buffered emit: print complete lines, retain the remainder.
    buffered = self._stream_buf + text
    while "\n" in buffered:
        complete, _, buffered = buffered.partition("\n")
        _cprint(complete)
    self._stream_buf = buffered
|
||||
|
||||
def _flush_stream(self) -> None:
    """Finish the streaming display: flush leftovers and draw the footer."""
    # The reasoning box may still be open if no content tokens arrived.
    self._close_reasoning_box()

    # Print the final partial line that never received its newline.
    remainder, self._stream_buf = self._stream_buf, ""
    if remainder:
        _cprint(remainder)

    # Draw the bottom border only if the header was ever printed.
    if self._stream_box_opened:
        cols = shutil.get_terminal_size().columns
        _cprint(f"{_GOLD}╰{'─' * (cols - 2)}╯{_RST}")
|
||||
|
||||
def _reset_stream_state(self) -> None:
|
||||
"""Reset streaming state before each agent invocation."""
|
||||
self._stream_buf = ""
|
||||
self._stream_started = False
|
||||
self._stream_box_opened = False
|
||||
self._stream_prefilt = ""
|
||||
self._in_reasoning_block = False
|
||||
self._reasoning_box_opened = False
|
||||
self._reasoning_buf = ""
|
||||
|
||||
def _slow_command_status(self, command: str) -> str:
|
||||
"""Return a user-facing status message for slower slash commands."""
|
||||
cmd_lower = command.lower().strip()
|
||||
|
|
@ -1657,7 +1837,11 @@ class HermesCLI:
|
|||
platform="cli",
|
||||
session_db=self._session_db,
|
||||
clarify_callback=self._clarify_callback,
|
||||
reasoning_callback=self._on_reasoning if (self.show_reasoning or self.verbose) else None,
|
||||
reasoning_callback=(
|
||||
self._stream_reasoning_delta if (self.streaming_enabled and self.show_reasoning)
|
||||
else self._on_reasoning if (self.show_reasoning or self.verbose)
|
||||
else None
|
||||
),
|
||||
honcho_session_key=None, # resolved by run_agent via config sessions map / title
|
||||
fallback_model=self._fallback_model,
|
||||
thinking_callback=self._on_thinking,
|
||||
|
|
@ -1665,6 +1849,7 @@ class HermesCLI:
|
|||
checkpoint_max_snapshots=self.checkpoint_max_snapshots,
|
||||
pass_session_id=self.pass_session_id,
|
||||
tool_progress_callback=self._on_tool_progress,
|
||||
stream_delta_callback=self._stream_delta if self.streaming_enabled else None,
|
||||
)
|
||||
self._active_agent_route_signature = (
|
||||
effective_model,
|
||||
|
|
@ -4958,6 +5143,9 @@ class HermesCLI:
|
|||
# Run the conversation with interrupt monitoring
|
||||
result = None
|
||||
|
||||
# Reset streaming display state for this turn
|
||||
self._reset_stream_state()
|
||||
|
||||
# --- Streaming TTS setup ---
|
||||
# When ElevenLabs is the TTS provider and sounddevice is available,
|
||||
# we stream audio sentence-by-sentence as the agent generates tokens
|
||||
|
|
@ -5084,6 +5272,9 @@ class HermesCLI:
|
|||
|
||||
agent_thread.join() # Ensure agent thread completes
|
||||
|
||||
# Flush any remaining streamed text and close the box
|
||||
self._flush_stream()
|
||||
|
||||
# Signal end-of-text to TTS consumer and wait for it to finish
|
||||
if use_streaming_tts and text_queue is not None:
|
||||
text_queue.put(None) # sentinel
|
||||
|
|
@ -5126,8 +5317,9 @@ class HermesCLI:
|
|||
|
||||
response_previewed = result.get("response_previewed", False) if result else False
|
||||
|
||||
# Display reasoning (thinking) box if enabled and available
|
||||
if self.show_reasoning and result:
|
||||
# Display reasoning (thinking) box if enabled and available.
|
||||
# Skip when streaming already showed reasoning live.
|
||||
if self.show_reasoning and result and not self._stream_started:
|
||||
reasoning = result.get("last_reasoning")
|
||||
if reasoning:
|
||||
w = shutil.get_terminal_size().columns
|
||||
|
|
@ -5158,10 +5350,15 @@ class HermesCLI:
|
|||
_resp_text = "#FFF8DC"
|
||||
|
||||
is_error_response = result and (result.get("failed") or result.get("partial"))
|
||||
already_streamed = self._stream_started and self._stream_box_opened and not is_error_response
|
||||
if use_streaming_tts and _streaming_box_opened and not is_error_response:
|
||||
# Text was already printed sentence-by-sentence; just close the box
|
||||
w = shutil.get_terminal_size().columns
|
||||
_cprint(f"\n{_GOLD}╰{'─' * (w - 2)}╯{_RST}")
|
||||
elif already_streamed:
|
||||
# Response was already streamed token-by-token with box framing;
|
||||
# _flush_stream() already closed the box. Skip Rich Panel.
|
||||
pass
|
||||
else:
|
||||
_chat_console = ChatConsole()
|
||||
_chat_console.print(Panel(
|
||||
|
|
|
|||
|
|
@ -146,6 +146,37 @@ class PlatformConfig:
|
|||
)
|
||||
|
||||
|
||||
@dataclass
class StreamingConfig:
    """Configuration for real-time token streaming to messaging platforms."""
    enabled: bool = False           # Master switch for gateway streaming
    transport: str = "edit"         # "edit" (progressive editMessageText) or "off"
    edit_interval: float = 0.3      # Seconds between message edits
    buffer_threshold: int = 40      # Chars before forcing an edit
    cursor: str = " ▉"              # Cursor shown during streaming

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict (inverse of from_dict)."""
        return dict(
            enabled=self.enabled,
            transport=self.transport,
            edit_interval=self.edit_interval,
            buffer_threshold=self.buffer_threshold,
            cursor=self.cursor,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "StreamingConfig":
        """Build a StreamingConfig from a raw mapping; falsy input → defaults.

        Numeric fields are coerced so YAML/string values like "0.5" work.
        """
        if not data:
            return cls()
        kwargs: Dict[str, Any] = {
            "enabled": data.get("enabled", False),
            "transport": data.get("transport", "edit"),
            "edit_interval": float(data.get("edit_interval", 0.3)),
            "buffer_threshold": int(data.get("buffer_threshold", 40)),
            "cursor": data.get("cursor", " ▉"),
        }
        return cls(**kwargs)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GatewayConfig:
|
||||
"""
|
||||
|
|
@ -179,6 +210,9 @@ class GatewayConfig:
|
|||
# Session isolation in shared chats
|
||||
group_sessions_per_user: bool = True # Isolate group/channel sessions per participant when user IDs are available
|
||||
|
||||
# Streaming configuration
|
||||
streaming: StreamingConfig = field(default_factory=StreamingConfig)
|
||||
|
||||
def get_connected_platforms(self) -> List[Platform]:
|
||||
"""Return list of platforms that are enabled and configured."""
|
||||
connected = []
|
||||
|
|
@ -244,6 +278,7 @@ class GatewayConfig:
|
|||
"always_log_local": self.always_log_local,
|
||||
"stt_enabled": self.stt_enabled,
|
||||
"group_sessions_per_user": self.group_sessions_per_user,
|
||||
"streaming": self.streaming.to_dict(),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
|
@ -297,6 +332,7 @@ class GatewayConfig:
|
|||
always_log_local=data.get("always_log_local", True),
|
||||
stt_enabled=_coerce_bool(stt_enabled, True),
|
||||
group_sessions_per_user=_coerce_bool(group_sessions_per_user, True),
|
||||
streaming=StreamingConfig.from_dict(data.get("streaming", {})),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1979,6 +1979,11 @@ class GatewayRunner:
|
|||
if self._should_send_voice_reply(event, response, agent_messages):
|
||||
await self._send_voice_reply(event, response)
|
||||
|
||||
# If streaming already delivered the response, return None so
|
||||
# _process_message_background doesn't send it again.
|
||||
if agent_result.get("already_sent"):
|
||||
return None
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -4135,6 +4140,7 @@ class GatewayRunner:
|
|||
agent_holder = [None] # Mutable container for the agent instance
|
||||
result_holder = [None] # Mutable container for the result
|
||||
tools_holder = [None] # Mutable container for the tool definitions
|
||||
stream_consumer_holder = [None] # Mutable container for stream consumer
|
||||
|
||||
# Bridge sync step_callback → async hooks.emit for agent:step events
|
||||
_loop_for_step = asyncio.get_event_loop()
|
||||
|
|
@ -4197,6 +4203,35 @@ class GatewayRunner:
|
|||
honcho_manager, honcho_config = self._get_or_create_gateway_honcho(session_key)
|
||||
reasoning_config = self._load_reasoning_config()
|
||||
self._reasoning_config = reasoning_config
|
||||
# Set up streaming consumer if enabled
|
||||
_stream_consumer = None
|
||||
_stream_delta_cb = None
|
||||
_scfg = getattr(getattr(self, 'config', None), 'streaming', None)
|
||||
if _scfg is None:
|
||||
from gateway.config import StreamingConfig
|
||||
_scfg = StreamingConfig()
|
||||
|
||||
if _scfg.enabled and _scfg.transport != "off":
|
||||
try:
|
||||
from gateway.stream_consumer import GatewayStreamConsumer, StreamConsumerConfig
|
||||
_adapter = self.adapters.get(source.platform)
|
||||
if _adapter:
|
||||
_consumer_cfg = StreamConsumerConfig(
|
||||
edit_interval=_scfg.edit_interval,
|
||||
buffer_threshold=_scfg.buffer_threshold,
|
||||
cursor=_scfg.cursor,
|
||||
)
|
||||
_stream_consumer = GatewayStreamConsumer(
|
||||
adapter=_adapter,
|
||||
chat_id=source.chat_id,
|
||||
config=_consumer_cfg,
|
||||
metadata={"thread_id": source.thread_id} if source.thread_id else None,
|
||||
)
|
||||
_stream_delta_cb = _stream_consumer.on_delta
|
||||
stream_consumer_holder[0] = _stream_consumer
|
||||
except Exception as _sc_err:
|
||||
logger.debug("Could not set up stream consumer: %s", _sc_err)
|
||||
|
||||
turn_route = self._resolve_turn_agent_config(message, model, runtime_kwargs)
|
||||
agent = AIAgent(
|
||||
model=turn_route["model"],
|
||||
|
|
@ -4217,6 +4252,7 @@ class GatewayRunner:
|
|||
session_id=session_id,
|
||||
tool_progress_callback=progress_callback if tool_progress_enabled else None,
|
||||
step_callback=_step_callback_sync if _hooks_ref.loaded_hooks else None,
|
||||
stream_delta_callback=_stream_delta_cb,
|
||||
platform=platform_key,
|
||||
honcho_session_key=session_key,
|
||||
honcho_manager=honcho_manager,
|
||||
|
|
@ -4287,6 +4323,10 @@ class GatewayRunner:
|
|||
|
||||
result = agent.run_conversation(message, conversation_history=agent_history, task_id=session_id)
|
||||
result_holder[0] = result
|
||||
|
||||
# Signal the stream consumer that the agent is done
|
||||
if _stream_consumer is not None:
|
||||
_stream_consumer.finish()
|
||||
|
||||
# Return final response, or a message if something went wrong
|
||||
final_response = result.get("final_response")
|
||||
|
|
@ -4386,6 +4426,20 @@ class GatewayRunner:
|
|||
progress_task = None
|
||||
if tool_progress_enabled:
|
||||
progress_task = asyncio.create_task(send_progress_messages())
|
||||
|
||||
# Start stream consumer task — polls for consumer creation since it
|
||||
# happens inside run_sync (thread pool) after the agent is constructed.
|
||||
stream_task = None
|
||||
|
||||
async def _start_stream_consumer():
|
||||
"""Wait for the stream consumer to be created, then run it."""
|
||||
for _ in range(200): # Up to 10s wait
|
||||
if stream_consumer_holder[0] is not None:
|
||||
await stream_consumer_holder[0].run()
|
||||
return
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
stream_task = asyncio.create_task(_start_stream_consumer())
|
||||
|
||||
# Track this agent as running for this session (for interrupt support)
|
||||
# We do this in a callback after the agent is created
|
||||
|
|
@ -4468,6 +4522,17 @@ class GatewayRunner:
|
|||
if progress_task:
|
||||
progress_task.cancel()
|
||||
interrupt_monitor.cancel()
|
||||
|
||||
# Wait for stream consumer to finish its final edit
|
||||
if stream_task:
|
||||
try:
|
||||
await asyncio.wait_for(stream_task, timeout=5.0)
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
stream_task.cancel()
|
||||
try:
|
||||
await stream_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Clean up tracking
|
||||
tracking_task.cancel()
|
||||
|
|
@ -4481,6 +4546,12 @@ class GatewayRunner:
|
|||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# If streaming already delivered the response, mark it so the
|
||||
# caller's send() is skipped (avoiding duplicate messages).
|
||||
_sc = stream_consumer_holder[0]
|
||||
if _sc and _sc.already_sent and isinstance(response, dict):
|
||||
response["already_sent"] = True
|
||||
|
||||
return response
|
||||
|
||||
|
|
|
|||
177
gateway/stream_consumer.py
Normal file
177
gateway/stream_consumer.py
Normal file
|
|
@ -0,0 +1,177 @@
|
|||
"""Gateway streaming consumer — bridges sync agent callbacks to async platform delivery.
|
||||
|
||||
The agent fires stream_delta_callback(text) synchronously from its worker thread.
|
||||
GatewayStreamConsumer:
|
||||
1. Receives deltas via on_delta() (thread-safe, sync)
|
||||
2. Queues them to an asyncio task via queue.Queue
|
||||
3. The async run() task buffers, rate-limits, and progressively edits
|
||||
a single message on the target platform
|
||||
|
||||
Design: Uses the edit transport (send initial message, then editMessageText).
|
||||
This is universally supported across Telegram, Discord, and Slack.
|
||||
|
||||
Credit: jobless0x (#774, #1312), OutThisLife (#798), clicksingh (#697).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import queue
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
logger = logging.getLogger("gateway.stream_consumer")
|
||||
|
||||
# Sentinel to signal the stream is complete
|
||||
_DONE = object()
|
||||
|
||||
|
||||
@dataclass
class StreamConsumerConfig:
    """Runtime config for a single stream consumer instance.

    Mirrors the edit-transport knobs of the gateway StreamingConfig
    (edit_interval / buffer_threshold / cursor) for one consumer.
    """
    edit_interval: float = 0.3   # Minimum seconds between successive message edits
    buffer_threshold: int = 40   # Accumulated chars that force an immediate flush
    cursor: str = " ▉"           # Typing indicator appended to intermediate edits
|
||||
|
||||
|
||||
class GatewayStreamConsumer:
|
||||
"""Async consumer that progressively edits a platform message with streamed tokens.
|
||||
|
||||
Usage::
|
||||
|
||||
consumer = GatewayStreamConsumer(adapter, chat_id, config, metadata=metadata)
|
||||
# Pass consumer.on_delta as stream_delta_callback to AIAgent
|
||||
agent = AIAgent(..., stream_delta_callback=consumer.on_delta)
|
||||
# Start the consumer as an asyncio task
|
||||
task = asyncio.create_task(consumer.run())
|
||||
# ... run agent in thread pool ...
|
||||
consumer.finish() # signal completion
|
||||
await task # wait for final edit
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
adapter: Any,
|
||||
chat_id: str,
|
||||
config: Optional[StreamConsumerConfig] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
):
|
||||
self.adapter = adapter
|
||||
self.chat_id = chat_id
|
||||
self.cfg = config or StreamConsumerConfig()
|
||||
self.metadata = metadata
|
||||
self._queue: queue.Queue = queue.Queue()
|
||||
self._accumulated = ""
|
||||
self._message_id: Optional[str] = None
|
||||
self._already_sent = False
|
||||
self._edit_supported = True # Disabled on first edit failure (Signal/Email/HA)
|
||||
self._last_edit_time = 0.0
|
||||
|
||||
@property
|
||||
def already_sent(self) -> bool:
|
||||
"""True if at least one message was sent/edited — signals the base
|
||||
adapter to skip re-sending the final response."""
|
||||
return self._already_sent
|
||||
|
||||
def on_delta(self, text: str) -> None:
|
||||
"""Thread-safe callback — called from the agent's worker thread."""
|
||||
if text:
|
||||
self._queue.put(text)
|
||||
|
||||
def finish(self) -> None:
|
||||
"""Signal that the stream is complete."""
|
||||
self._queue.put(_DONE)
|
||||
|
||||
async def run(self) -> None:
|
||||
"""Async task that drains the queue and edits the platform message."""
|
||||
try:
|
||||
while True:
|
||||
# Drain all available items from the queue
|
||||
got_done = False
|
||||
while True:
|
||||
try:
|
||||
item = self._queue.get_nowait()
|
||||
if item is _DONE:
|
||||
got_done = True
|
||||
break
|
||||
self._accumulated += item
|
||||
except queue.Empty:
|
||||
break
|
||||
|
||||
# Decide whether to flush an edit
|
||||
now = time.monotonic()
|
||||
elapsed = now - self._last_edit_time
|
||||
should_edit = (
|
||||
got_done
|
||||
or (elapsed >= self.cfg.edit_interval
|
||||
and len(self._accumulated) > 0)
|
||||
or len(self._accumulated) >= self.cfg.buffer_threshold
|
||||
)
|
||||
|
||||
if should_edit and self._accumulated:
|
||||
display_text = self._accumulated
|
||||
if not got_done:
|
||||
display_text += self.cfg.cursor
|
||||
|
||||
await self._send_or_edit(display_text)
|
||||
self._last_edit_time = time.monotonic()
|
||||
|
||||
if got_done:
|
||||
# Final edit without cursor
|
||||
if self._accumulated and self._message_id:
|
||||
await self._send_or_edit(self._accumulated)
|
||||
return
|
||||
|
||||
await asyncio.sleep(0.05) # Small yield to not busy-loop
|
||||
|
||||
except asyncio.CancelledError:
|
||||
# Best-effort final edit on cancellation
|
||||
if self._accumulated and self._message_id:
|
||||
try:
|
||||
await self._send_or_edit(self._accumulated)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error("Stream consumer error: %s", e)
|
||||
|
||||
async def _send_or_edit(self, text: str) -> None:
|
||||
"""Send or edit the streaming message."""
|
||||
try:
|
||||
if self._message_id is not None:
|
||||
if self._edit_supported:
|
||||
# Edit existing message
|
||||
result = await self.adapter.edit_message(
|
||||
chat_id=self.chat_id,
|
||||
message_id=self._message_id,
|
||||
content=text,
|
||||
)
|
||||
if result.success:
|
||||
self._already_sent = True
|
||||
else:
|
||||
# Edit not supported by this adapter — stop streaming,
|
||||
# let the normal send path handle the final response.
|
||||
# Without this guard, adapters like Signal/Email would
|
||||
# flood the chat with a new message every edit_interval.
|
||||
logger.debug("Edit failed, disabling streaming for this adapter")
|
||||
self._edit_supported = False
|
||||
else:
|
||||
# Editing not supported — skip intermediate updates.
|
||||
# The final response will be sent by the normal path.
|
||||
pass
|
||||
else:
|
||||
# First message — send new
|
||||
result = await self.adapter.send(
|
||||
chat_id=self.chat_id,
|
||||
content=text,
|
||||
metadata=self.metadata,
|
||||
)
|
||||
if result.success and result.message_id:
|
||||
self._message_id = result.message_id
|
||||
self._already_sent = True
|
||||
else:
|
||||
# Initial send failed — disable streaming for this session
|
||||
self._edit_supported = False
|
||||
except Exception as e:
|
||||
logger.error("Stream send/edit error: %s", e)
|
||||
|
|
@ -217,6 +217,7 @@ DEFAULT_CONFIG = {
|
|||
"resume_display": "full",
|
||||
"bell_on_complete": False,
|
||||
"show_reasoning": False,
|
||||
"streaming": False,
|
||||
"show_cost": False, # Show $ cost in the status bar (off by default)
|
||||
"skin": "default",
|
||||
},
|
||||
|
|
|
|||
416
run_agent.py
416
run_agent.py
|
|
@ -296,6 +296,7 @@ class AIAgent:
|
|||
reasoning_callback: callable = None,
|
||||
clarify_callback: callable = None,
|
||||
step_callback: callable = None,
|
||||
stream_delta_callback: callable = None,
|
||||
max_tokens: int = None,
|
||||
reasoning_config: Dict[str, Any] = None,
|
||||
prefill_messages: List[Dict[str, Any]] = None,
|
||||
|
|
@ -395,6 +396,7 @@ class AIAgent:
|
|||
self.reasoning_callback = reasoning_callback
|
||||
self.clarify_callback = clarify_callback
|
||||
self.step_callback = step_callback
|
||||
self.stream_delta_callback = stream_delta_callback
|
||||
self._last_reported_tool = None # Track for "new tool" mode
|
||||
|
||||
# Interrupt mechanism for breaking out of tool loops
|
||||
|
|
@ -856,9 +858,9 @@ class AIAgent:
|
|||
"""Verbose print — suppressed when streaming TTS is active.
|
||||
|
||||
Pass ``force=True`` for error/warning messages that should always be
|
||||
shown even during streaming TTS playback.
|
||||
shown even during streaming playback (TTS or display).
|
||||
"""
|
||||
if not force and getattr(self, "_stream_callback", None) is not None:
|
||||
if not force and self._has_stream_consumers():
|
||||
return
|
||||
print(*args, **kwargs)
|
||||
|
||||
|
|
@ -2602,15 +2604,39 @@ class AIAgent:
|
|||
def _close_request_openai_client(self, client: Any, *, reason: str) -> None:
|
||||
self._close_openai_client(client, reason=reason, shared=False)
|
||||
|
||||
def _run_codex_stream(self, api_kwargs: dict, client: Any = None):
|
||||
def _run_codex_stream(self, api_kwargs: dict, client: Any = None, on_first_delta: callable = None):
|
||||
"""Execute one streaming Responses API request and return the final response."""
|
||||
active_client = client or self._ensure_primary_openai_client(reason="codex_stream_direct")
|
||||
max_stream_retries = 1
|
||||
has_tool_calls = False
|
||||
first_delta_fired = False
|
||||
for attempt in range(max_stream_retries + 1):
|
||||
try:
|
||||
with active_client.responses.stream(**api_kwargs) as stream:
|
||||
for _ in stream:
|
||||
pass
|
||||
for event in stream:
|
||||
if self._interrupt_requested:
|
||||
break
|
||||
event_type = getattr(event, "type", "")
|
||||
# Fire callbacks on text content deltas (suppress during tool calls)
|
||||
if "output_text.delta" in event_type or event_type == "response.output_text.delta":
|
||||
delta_text = getattr(event, "delta", "")
|
||||
if delta_text and not has_tool_calls:
|
||||
if not first_delta_fired:
|
||||
first_delta_fired = True
|
||||
if on_first_delta:
|
||||
try:
|
||||
on_first_delta()
|
||||
except Exception:
|
||||
pass
|
||||
self._fire_stream_delta(delta_text)
|
||||
# Track tool calls to suppress text streaming
|
||||
elif "function_call" in event_type:
|
||||
has_tool_calls = True
|
||||
# Fire reasoning callbacks
|
||||
elif "reasoning" in event_type and "delta" in event_type:
|
||||
reasoning_text = getattr(event, "delta", "")
|
||||
if reasoning_text:
|
||||
self._fire_reasoning_delta(reasoning_text)
|
||||
return stream.get_final_response()
|
||||
except RuntimeError as exc:
|
||||
err_text = str(exc)
|
||||
|
|
@ -2791,6 +2817,7 @@ class AIAgent:
|
|||
result["response"] = self._run_codex_stream(
|
||||
api_kwargs,
|
||||
client=request_client_holder["client"],
|
||||
on_first_delta=getattr(self, "_codex_on_first_delta", None),
|
||||
)
|
||||
elif self.api_mode == "anthropic_messages":
|
||||
result["response"] = self._anthropic_messages_create(api_kwargs)
|
||||
|
|
@ -2832,116 +2859,246 @@ class AIAgent:
|
|||
raise result["error"]
|
||||
return result["response"]
|
||||
|
||||
def _streaming_api_call(self, api_kwargs: dict, stream_callback):
|
||||
"""Streaming variant of _interruptible_api_call for voice TTS pipeline.
|
||||
# ── Unified streaming API call ─────────────────────────────────────────
|
||||
|
||||
Uses ``stream=True`` and forwards content deltas to *stream_callback*
|
||||
in real-time. Returns a ``SimpleNamespace`` that mimics a normal
|
||||
``ChatCompletion`` so the rest of the agent loop works unchanged.
|
||||
def _fire_stream_delta(self, text: str) -> None:
|
||||
"""Fire all registered stream delta callbacks (display + TTS)."""
|
||||
for cb in (self.stream_delta_callback, self._stream_callback):
|
||||
if cb is not None:
|
||||
try:
|
||||
cb(text)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
This method is separate from ``_interruptible_api_call`` to keep the
|
||||
core agent loop untouched for non-voice users.
|
||||
def _fire_reasoning_delta(self, text: str) -> None:
|
||||
"""Fire reasoning callback if registered."""
|
||||
cb = self.reasoning_callback
|
||||
if cb is not None:
|
||||
try:
|
||||
cb(text)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _has_stream_consumers(self) -> bool:
|
||||
"""Return True if any streaming consumer is registered."""
|
||||
return (
|
||||
self.stream_delta_callback is not None
|
||||
or getattr(self, "_stream_callback", None) is not None
|
||||
)
|
||||
|
||||
def _interruptible_streaming_api_call(
|
||||
self, api_kwargs: dict, *, on_first_delta: callable = None
|
||||
):
|
||||
"""Streaming variant of _interruptible_api_call for real-time token delivery.
|
||||
|
||||
Handles all three api_modes:
|
||||
- chat_completions: stream=True on OpenAI-compatible endpoints
|
||||
- anthropic_messages: client.messages.stream() via Anthropic SDK
|
||||
- codex_responses: delegates to _run_codex_stream (already streaming)
|
||||
|
||||
Fires stream_delta_callback and _stream_callback for each text token.
|
||||
Tool-call turns suppress the callback — only text-only final responses
|
||||
stream to the consumer. Returns a SimpleNamespace that mimics the
|
||||
non-streaming response shape so the rest of the agent loop is unchanged.
|
||||
|
||||
Falls back to _interruptible_api_call on provider errors indicating
|
||||
streaming is not supported.
|
||||
"""
|
||||
if self.api_mode == "codex_responses":
|
||||
# Codex streams internally via _run_codex_stream. The main dispatch
|
||||
# in _interruptible_api_call already calls it; we just need to
|
||||
# ensure on_first_delta reaches it. Store it on the instance
|
||||
# temporarily so _run_codex_stream can pick it up.
|
||||
self._codex_on_first_delta = on_first_delta
|
||||
try:
|
||||
return self._interruptible_api_call(api_kwargs)
|
||||
finally:
|
||||
self._codex_on_first_delta = None
|
||||
|
||||
result = {"response": None, "error": None}
|
||||
request_client_holder = {"client": None}
|
||||
first_delta_fired = {"done": False}
|
||||
deltas_were_sent = {"yes": False} # Track if any deltas were fired (for fallback)
|
||||
|
||||
def _fire_first_delta():
|
||||
if not first_delta_fired["done"] and on_first_delta:
|
||||
first_delta_fired["done"] = True
|
||||
try:
|
||||
on_first_delta()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _call_chat_completions():
|
||||
"""Stream a chat completions response."""
|
||||
stream_kwargs = {**api_kwargs, "stream": True, "stream_options": {"include_usage": True}}
|
||||
request_client_holder["client"] = self._create_request_openai_client(
|
||||
reason="chat_completion_stream_request"
|
||||
)
|
||||
stream = request_client_holder["client"].chat.completions.create(**stream_kwargs)
|
||||
|
||||
content_parts: list = []
|
||||
tool_calls_acc: dict = {}
|
||||
finish_reason = None
|
||||
model_name = None
|
||||
role = "assistant"
|
||||
reasoning_parts: list = []
|
||||
usage_obj = None
|
||||
|
||||
for chunk in stream:
|
||||
if self._interrupt_requested:
|
||||
break
|
||||
|
||||
if not chunk.choices:
|
||||
if hasattr(chunk, "model") and chunk.model:
|
||||
model_name = chunk.model
|
||||
# Usage comes in the final chunk with empty choices
|
||||
if hasattr(chunk, "usage") and chunk.usage:
|
||||
usage_obj = chunk.usage
|
||||
continue
|
||||
|
||||
delta = chunk.choices[0].delta
|
||||
if hasattr(chunk, "model") and chunk.model:
|
||||
model_name = chunk.model
|
||||
|
||||
# Accumulate reasoning content
|
||||
reasoning_text = getattr(delta, "reasoning_content", None) or getattr(delta, "reasoning", None)
|
||||
if reasoning_text:
|
||||
reasoning_parts.append(reasoning_text)
|
||||
self._fire_reasoning_delta(reasoning_text)
|
||||
|
||||
# Accumulate text content — fire callback only when no tool calls
|
||||
if delta and delta.content:
|
||||
content_parts.append(delta.content)
|
||||
if not tool_calls_acc:
|
||||
_fire_first_delta()
|
||||
self._fire_stream_delta(delta.content)
|
||||
deltas_were_sent["yes"] = True
|
||||
|
||||
# Accumulate tool call deltas (silently, no callback)
|
||||
if delta and delta.tool_calls:
|
||||
for tc_delta in delta.tool_calls:
|
||||
idx = tc_delta.index if tc_delta.index is not None else 0
|
||||
if idx not in tool_calls_acc:
|
||||
tool_calls_acc[idx] = {
|
||||
"id": tc_delta.id or "",
|
||||
"type": "function",
|
||||
"function": {"name": "", "arguments": ""},
|
||||
}
|
||||
entry = tool_calls_acc[idx]
|
||||
if tc_delta.id:
|
||||
entry["id"] = tc_delta.id
|
||||
if tc_delta.function:
|
||||
if tc_delta.function.name:
|
||||
entry["function"]["name"] += tc_delta.function.name
|
||||
if tc_delta.function.arguments:
|
||||
entry["function"]["arguments"] += tc_delta.function.arguments
|
||||
|
||||
if chunk.choices[0].finish_reason:
|
||||
finish_reason = chunk.choices[0].finish_reason
|
||||
|
||||
# Usage in the final chunk
|
||||
if hasattr(chunk, "usage") and chunk.usage:
|
||||
usage_obj = chunk.usage
|
||||
|
||||
# Build mock response matching non-streaming shape
|
||||
full_content = "".join(content_parts) or None
|
||||
mock_tool_calls = None
|
||||
if tool_calls_acc:
|
||||
mock_tool_calls = []
|
||||
for idx in sorted(tool_calls_acc):
|
||||
tc = tool_calls_acc[idx]
|
||||
mock_tool_calls.append(SimpleNamespace(
|
||||
id=tc["id"],
|
||||
type=tc["type"],
|
||||
function=SimpleNamespace(
|
||||
name=tc["function"]["name"],
|
||||
arguments=tc["function"]["arguments"],
|
||||
),
|
||||
))
|
||||
|
||||
full_reasoning = "".join(reasoning_parts) or None
|
||||
mock_message = SimpleNamespace(
|
||||
role=role,
|
||||
content=full_content,
|
||||
tool_calls=mock_tool_calls,
|
||||
reasoning_content=full_reasoning,
|
||||
)
|
||||
mock_choice = SimpleNamespace(
|
||||
index=0,
|
||||
message=mock_message,
|
||||
finish_reason=finish_reason or "stop",
|
||||
)
|
||||
return SimpleNamespace(
|
||||
id="stream-" + str(uuid.uuid4()),
|
||||
model=model_name,
|
||||
choices=[mock_choice],
|
||||
usage=usage_obj,
|
||||
)
|
||||
|
||||
def _call_anthropic():
|
||||
"""Stream an Anthropic Messages API response.
|
||||
|
||||
Fires delta callbacks for real-time token delivery, but returns
|
||||
the native Anthropic Message object from get_final_message() so
|
||||
the rest of the agent loop (validation, tool extraction, etc.)
|
||||
works unchanged.
|
||||
"""
|
||||
has_tool_use = False
|
||||
|
||||
# Use the Anthropic SDK's streaming context manager
|
||||
with self._anthropic_client.messages.stream(**api_kwargs) as stream:
|
||||
for event in stream:
|
||||
if self._interrupt_requested:
|
||||
break
|
||||
|
||||
event_type = getattr(event, "type", None)
|
||||
|
||||
if event_type == "content_block_start":
|
||||
block = getattr(event, "content_block", None)
|
||||
if block and getattr(block, "type", None) == "tool_use":
|
||||
has_tool_use = True
|
||||
|
||||
elif event_type == "content_block_delta":
|
||||
delta = getattr(event, "delta", None)
|
||||
if delta:
|
||||
delta_type = getattr(delta, "type", None)
|
||||
if delta_type == "text_delta":
|
||||
text = getattr(delta, "text", "")
|
||||
if text and not has_tool_use:
|
||||
_fire_first_delta()
|
||||
self._fire_stream_delta(text)
|
||||
elif delta_type == "thinking_delta":
|
||||
thinking_text = getattr(delta, "thinking", "")
|
||||
if thinking_text:
|
||||
self._fire_reasoning_delta(thinking_text)
|
||||
|
||||
# Return the native Anthropic Message for downstream processing
|
||||
return stream.get_final_message()
|
||||
|
||||
def _call():
|
||||
try:
|
||||
stream_kwargs = {**api_kwargs, "stream": True}
|
||||
request_client_holder["client"] = self._create_request_openai_client(
|
||||
reason="chat_completion_stream_request"
|
||||
)
|
||||
stream = request_client_holder["client"].chat.completions.create(**stream_kwargs)
|
||||
|
||||
content_parts: list[str] = []
|
||||
tool_calls_acc: dict[int, dict] = {}
|
||||
finish_reason = None
|
||||
model_name = None
|
||||
role = "assistant"
|
||||
|
||||
for chunk in stream:
|
||||
if not chunk.choices:
|
||||
if hasattr(chunk, "model") and chunk.model:
|
||||
model_name = chunk.model
|
||||
continue
|
||||
|
||||
delta = chunk.choices[0].delta
|
||||
if hasattr(chunk, "model") and chunk.model:
|
||||
model_name = chunk.model
|
||||
|
||||
if delta and delta.content:
|
||||
content_parts.append(delta.content)
|
||||
try:
|
||||
stream_callback(delta.content)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if delta and delta.tool_calls:
|
||||
for tc_delta in delta.tool_calls:
|
||||
idx = tc_delta.index if tc_delta.index is not None else 0
|
||||
if idx in tool_calls_acc and tc_delta.id and tc_delta.id != tool_calls_acc[idx]["id"]:
|
||||
matched = False
|
||||
for eidx, eentry in tool_calls_acc.items():
|
||||
if eentry["id"] == tc_delta.id:
|
||||
idx = eidx
|
||||
matched = True
|
||||
break
|
||||
if not matched:
|
||||
idx = (max(k for k in tool_calls_acc if isinstance(k, int)) + 1) if tool_calls_acc else 0
|
||||
if idx not in tool_calls_acc:
|
||||
tool_calls_acc[idx] = {
|
||||
"id": tc_delta.id or "",
|
||||
"type": "function",
|
||||
"function": {"name": "", "arguments": ""},
|
||||
}
|
||||
entry = tool_calls_acc[idx]
|
||||
if tc_delta.id:
|
||||
entry["id"] = tc_delta.id
|
||||
if tc_delta.function:
|
||||
if tc_delta.function.name:
|
||||
entry["function"]["name"] += tc_delta.function.name
|
||||
if tc_delta.function.arguments:
|
||||
entry["function"]["arguments"] += tc_delta.function.arguments
|
||||
|
||||
if chunk.choices[0].finish_reason:
|
||||
finish_reason = chunk.choices[0].finish_reason
|
||||
|
||||
full_content = "".join(content_parts) or None
|
||||
mock_tool_calls = None
|
||||
if tool_calls_acc:
|
||||
mock_tool_calls = []
|
||||
for idx in sorted(tool_calls_acc):
|
||||
tc = tool_calls_acc[idx]
|
||||
mock_tool_calls.append(SimpleNamespace(
|
||||
id=tc["id"],
|
||||
type=tc["type"],
|
||||
function=SimpleNamespace(
|
||||
name=tc["function"]["name"],
|
||||
arguments=tc["function"]["arguments"],
|
||||
),
|
||||
))
|
||||
|
||||
mock_message = SimpleNamespace(
|
||||
role=role,
|
||||
content=full_content,
|
||||
tool_calls=mock_tool_calls,
|
||||
reasoning_content=None,
|
||||
)
|
||||
mock_choice = SimpleNamespace(
|
||||
index=0,
|
||||
message=mock_message,
|
||||
finish_reason=finish_reason or "stop",
|
||||
)
|
||||
mock_response = SimpleNamespace(
|
||||
id="stream-" + str(uuid.uuid4()),
|
||||
model=model_name,
|
||||
choices=[mock_choice],
|
||||
usage=None,
|
||||
)
|
||||
result["response"] = mock_response
|
||||
|
||||
if self.api_mode == "anthropic_messages":
|
||||
self._try_refresh_anthropic_client_credentials()
|
||||
result["response"] = _call_anthropic()
|
||||
else:
|
||||
result["response"] = _call_chat_completions()
|
||||
except Exception as e:
|
||||
result["error"] = e
|
||||
if deltas_were_sent["yes"]:
|
||||
# Streaming failed AFTER some tokens were already delivered
|
||||
# to consumers. Don't fall back — that would cause
|
||||
# double-delivery (partial streamed + full non-streamed).
|
||||
# Let the error propagate; the partial content already
|
||||
# reached the user via the stream.
|
||||
logger.warning("Streaming failed after partial delivery, not falling back: %s", e)
|
||||
result["error"] = e
|
||||
else:
|
||||
# Streaming failed before any tokens reached consumers.
|
||||
# Safe to fall back to the standard non-streaming path.
|
||||
logger.info("Streaming failed before delivery, falling back to non-streaming: %s", e)
|
||||
try:
|
||||
result["response"] = self._interruptible_api_call(api_kwargs)
|
||||
except Exception as fallback_err:
|
||||
result["error"] = fallback_err
|
||||
finally:
|
||||
request_client = request_client_holder.get("client")
|
||||
if request_client is not None:
|
||||
|
|
@ -2967,7 +3124,7 @@ class AIAgent:
|
|||
self._close_request_openai_client(request_client, reason="stream_interrupt_abort")
|
||||
except Exception:
|
||||
pass
|
||||
raise InterruptedError("Agent interrupted during API call")
|
||||
raise InterruptedError("Agent interrupted during streaming API call")
|
||||
if result["error"] is not None:
|
||||
raise result["error"]
|
||||
return result["response"]
|
||||
|
|
@ -4173,7 +4330,7 @@ class AIAgent:
|
|||
spinner.stop(cute_msg)
|
||||
elif self.quiet_mode:
|
||||
self._vprint(f" {cute_msg}")
|
||||
elif self.quiet_mode and self._stream_callback is None:
|
||||
elif self.quiet_mode and not self._has_stream_consumers():
|
||||
face = random.choice(KawaiiSpinner.KAWAII_WAITING)
|
||||
emoji = _get_tool_emoji(function_name)
|
||||
preview = _build_tool_preview(function_name, function_args) or function_name
|
||||
|
|
@ -4810,8 +4967,8 @@ class AIAgent:
|
|||
self._vprint(f"\n{self.log_prefix}🔄 Making API call #{api_call_count}/{self.max_iterations}...")
|
||||
self._vprint(f"{self.log_prefix} 📊 Request size: {len(api_messages)} messages, ~{approx_tokens:,} tokens (~{total_chars:,} chars)")
|
||||
self._vprint(f"{self.log_prefix} 🔧 Available tools: {len(self.tools) if self.tools else 0}")
|
||||
elif self._stream_callback is None:
|
||||
# Animated thinking spinner in quiet mode (skip during streaming TTS)
|
||||
elif not self._has_stream_consumers():
|
||||
# Animated thinking spinner in quiet mode (skip during streaming)
|
||||
face = random.choice(KawaiiSpinner.KAWAII_THINKING)
|
||||
verb = random.choice(KawaiiSpinner.THINKING_VERBS)
|
||||
if self.thinking_callback:
|
||||
|
|
@ -4851,33 +5008,22 @@ class AIAgent:
|
|||
if os.getenv("HERMES_DUMP_REQUESTS", "").strip().lower() in {"1", "true", "yes", "on"}:
|
||||
self._dump_api_request_debug(api_kwargs, reason="preflight")
|
||||
|
||||
cb = getattr(self, "_stream_callback", None)
|
||||
if cb is not None and self.api_mode == "chat_completions":
|
||||
response = self._streaming_api_call(api_kwargs, cb)
|
||||
if self._has_stream_consumers():
|
||||
# Streaming path: fire delta callbacks for real-time
|
||||
# token delivery to CLI display, gateway, or TTS.
|
||||
def _stop_spinner():
|
||||
nonlocal thinking_spinner
|
||||
if thinking_spinner:
|
||||
thinking_spinner.stop("")
|
||||
thinking_spinner = None
|
||||
if self.thinking_callback:
|
||||
self.thinking_callback("")
|
||||
|
||||
response = self._interruptible_streaming_api_call(
|
||||
api_kwargs, on_first_delta=_stop_spinner
|
||||
)
|
||||
else:
|
||||
response = self._interruptible_api_call(api_kwargs)
|
||||
# Forward full response to TTS callback for non-streaming providers
|
||||
# (e.g. Anthropic) so voice TTS still works via batch delivery.
|
||||
if cb is not None and response:
|
||||
try:
|
||||
content = None
|
||||
# Try choices first — _interruptible_api_call converts all
|
||||
# providers (including Anthropic) to this format.
|
||||
try:
|
||||
content = response.choices[0].message.content
|
||||
except (AttributeError, IndexError):
|
||||
pass
|
||||
# Fallback: Anthropic native content blocks
|
||||
if not content and self.api_mode == "anthropic_messages":
|
||||
text_parts = [
|
||||
block.text for block in getattr(response, "content", [])
|
||||
if getattr(block, "type", None) == "text" and getattr(block, "text", None)
|
||||
]
|
||||
content = " ".join(text_parts) if text_parts else None
|
||||
if content:
|
||||
cb(content)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
api_duration = time.time() - api_start_time
|
||||
|
||||
|
|
|
|||
|
|
@ -59,8 +59,11 @@ def _build_agent(shared_client=None):
|
|||
agent._interrupt_requested = False
|
||||
agent._interrupt_message = None
|
||||
agent._client_lock = threading.RLock()
|
||||
agent._client_kwargs = {"api_key": "test-key", "base_url": agent.base_url}
|
||||
agent._client_kwargs = {"api_key": "***", "base_url": agent.base_url}
|
||||
agent.client = shared_client or FakeSharedClient(lambda **kwargs: {"shared": True})
|
||||
agent.stream_delta_callback = None
|
||||
agent._stream_callback = None
|
||||
agent.reasoning_callback = None
|
||||
return agent
|
||||
|
||||
|
||||
|
|
@ -173,7 +176,11 @@ def test_streaming_call_recreates_closed_shared_client_before_request(monkeypatc
|
|||
monkeypatch.setattr(run_agent, "OpenAI", factory)
|
||||
|
||||
agent = _build_agent(shared_client=stale_shared)
|
||||
response = agent._streaming_api_call({"model": agent.model, "messages": []}, lambda _delta: None)
|
||||
agent.stream_delta_callback = lambda _delta: None
|
||||
# Force chat_completions mode so the streaming path uses
|
||||
# chat.completions.create(stream=True) instead of Codex responses.stream()
|
||||
agent.api_mode = "chat_completions"
|
||||
response = agent._interruptible_streaming_api_call({"model": agent.model, "messages": []})
|
||||
|
||||
assert response.choices[0].message.content == "Hello world"
|
||||
assert agent.client is replacement_shared
|
||||
|
|
|
|||
|
|
@ -2329,8 +2329,9 @@ class TestStreamingApiCall:
|
|||
]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
callback = MagicMock()
|
||||
agent.stream_delta_callback = callback
|
||||
|
||||
resp = agent._streaming_api_call({"messages": []}, callback)
|
||||
resp = agent._interruptible_streaming_api_call({"messages": []})
|
||||
|
||||
assert resp.choices[0].message.content == "Hello World"
|
||||
assert resp.choices[0].finish_reason == "stop"
|
||||
|
|
@ -2347,7 +2348,7 @@ class TestStreamingApiCall:
|
|||
]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
resp = agent._streaming_api_call({"messages": []}, MagicMock())
|
||||
resp = agent._interruptible_streaming_api_call({"messages": []})
|
||||
|
||||
tc = resp.choices[0].message.tool_calls
|
||||
assert len(tc) == 1
|
||||
|
|
@ -2363,7 +2364,7 @@ class TestStreamingApiCall:
|
|||
]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
resp = agent._streaming_api_call({"messages": []}, MagicMock())
|
||||
resp = agent._interruptible_streaming_api_call({"messages": []})
|
||||
|
||||
tc = resp.choices[0].message.tool_calls
|
||||
assert len(tc) == 2
|
||||
|
|
@ -2378,7 +2379,7 @@ class TestStreamingApiCall:
|
|||
]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
resp = agent._streaming_api_call({"messages": []}, MagicMock())
|
||||
resp = agent._interruptible_streaming_api_call({"messages": []})
|
||||
|
||||
assert resp.choices[0].message.content == "I'll search"
|
||||
assert len(resp.choices[0].message.tool_calls) == 1
|
||||
|
|
@ -2387,7 +2388,7 @@ class TestStreamingApiCall:
|
|||
chunks = [_make_chunk(finish_reason="stop")]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
resp = agent._streaming_api_call({"messages": []}, MagicMock())
|
||||
resp = agent._interruptible_streaming_api_call({"messages": []})
|
||||
|
||||
assert resp.choices[0].message.content is None
|
||||
assert resp.choices[0].message.tool_calls is None
|
||||
|
|
@ -2399,9 +2400,9 @@ class TestStreamingApiCall:
|
|||
_make_chunk(finish_reason="stop"),
|
||||
]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
callback = MagicMock(side_effect=ValueError("boom"))
|
||||
agent.stream_delta_callback = MagicMock(side_effect=ValueError("boom"))
|
||||
|
||||
resp = agent._streaming_api_call({"messages": []}, callback)
|
||||
resp = agent._interruptible_streaming_api_call({"messages": []})
|
||||
|
||||
assert resp.choices[0].message.content == "Hello World"
|
||||
|
||||
|
|
@ -2412,7 +2413,7 @@ class TestStreamingApiCall:
|
|||
]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
resp = agent._streaming_api_call({"messages": []}, MagicMock())
|
||||
resp = agent._interruptible_streaming_api_call({"messages": []})
|
||||
|
||||
assert resp.model == "gpt-4o"
|
||||
|
||||
|
|
@ -2420,22 +2421,23 @@ class TestStreamingApiCall:
|
|||
chunks = [_make_chunk(content="x"), _make_chunk(finish_reason="stop")]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
agent._streaming_api_call({"messages": [], "model": "test"}, MagicMock())
|
||||
agent._interruptible_streaming_api_call({"messages": [], "model": "test"})
|
||||
|
||||
call_kwargs = agent.client.chat.completions.create.call_args
|
||||
assert call_kwargs[1].get("stream") is True or call_kwargs.kwargs.get("stream") is True
|
||||
|
||||
def test_api_exception_propagated(self, agent):
|
||||
def test_api_exception_falls_back_to_non_streaming(self, agent):
|
||||
"""When streaming fails before any deltas, fallback to non-streaming is attempted."""
|
||||
agent.client.chat.completions.create.side_effect = ConnectionError("fail")
|
||||
|
||||
# The fallback also uses the same client, so it'll fail too
|
||||
with pytest.raises(ConnectionError, match="fail"):
|
||||
agent._streaming_api_call({"messages": []}, MagicMock())
|
||||
agent._interruptible_streaming_api_call({"messages": []})
|
||||
|
||||
def test_response_has_uuid_id(self, agent):
|
||||
chunks = [_make_chunk(content="x"), _make_chunk(finish_reason="stop")]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
resp = agent._streaming_api_call({"messages": []}, MagicMock())
|
||||
resp = agent._interruptible_streaming_api_call({"messages": []})
|
||||
|
||||
assert resp.id.startswith("stream-")
|
||||
assert len(resp.id) > len("stream-")
|
||||
|
|
@ -2449,7 +2451,7 @@ class TestStreamingApiCall:
|
|||
]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
resp = agent._streaming_api_call({"messages": []}, MagicMock())
|
||||
resp = agent._interruptible_streaming_api_call({"messages": []})
|
||||
|
||||
assert resp.choices[0].message.content == "Hello"
|
||||
assert resp.model == "gpt-4"
|
||||
|
|
@ -2505,7 +2507,7 @@ class TestAnthropicInterruptHandler:
|
|||
def test_streaming_has_anthropic_branch(self):
|
||||
"""_streaming_api_call must also handle Anthropic interrupt."""
|
||||
import inspect
|
||||
source = inspect.getsource(AIAgent._streaming_api_call)
|
||||
source = inspect.getsource(AIAgent._interruptible_streaming_api_call)
|
||||
assert "anthropic_messages" in source, \
|
||||
"_streaming_api_call must handle Anthropic interrupt"
|
||||
|
||||
|
|
|
|||
571
tests/test_streaming.py
Normal file
571
tests/test_streaming.py
Normal file
|
|
@ -0,0 +1,571 @@
|
|||
"""Tests for streaming token delivery infrastructure.
|
||||
|
||||
Tests the unified streaming API call, delta callbacks, tool-call
|
||||
suppression, provider fallback, and CLI streaming display.
|
||||
"""
|
||||
import json
|
||||
import threading
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch, PropertyMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _make_stream_chunk(
|
||||
content=None, tool_calls=None, finish_reason=None,
|
||||
model=None, reasoning_content=None, usage=None,
|
||||
):
|
||||
"""Build a mock streaming chunk matching OpenAI's ChatCompletionChunk shape."""
|
||||
delta = SimpleNamespace(
|
||||
content=content,
|
||||
tool_calls=tool_calls,
|
||||
reasoning_content=reasoning_content,
|
||||
reasoning=None,
|
||||
)
|
||||
choice = SimpleNamespace(
|
||||
index=0,
|
||||
delta=delta,
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
chunk = SimpleNamespace(
|
||||
choices=[choice],
|
||||
model=model,
|
||||
usage=usage,
|
||||
)
|
||||
return chunk
|
||||
|
||||
|
||||
def _make_tool_call_delta(index=0, tc_id=None, name=None, arguments=None):
|
||||
"""Build a mock tool call delta."""
|
||||
func = SimpleNamespace(name=name, arguments=arguments)
|
||||
return SimpleNamespace(index=index, id=tc_id, function=func)
|
||||
|
||||
|
||||
def _make_empty_chunk(model=None, usage=None):
|
||||
"""Build a chunk with no choices (usage-only final chunk)."""
|
||||
return SimpleNamespace(choices=[], model=model, usage=usage)
|
||||
|
||||
|
||||
# ── Test: Streaming Accumulator ──────────────────────────────────────────
|
||||
|
||||
|
||||
class TestStreamingAccumulator:
|
||||
"""Verify that _interruptible_streaming_api_call accumulates content
|
||||
and tool calls into a response matching the non-streaming shape."""
|
||||
|
||||
@patch("run_agent.AIAgent._create_request_openai_client")
|
||||
@patch("run_agent.AIAgent._close_request_openai_client")
|
||||
def test_text_only_response(self, mock_close, mock_create):
|
||||
"""Text-only stream produces correct response shape."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
chunks = [
|
||||
_make_stream_chunk(content="Hello"),
|
||||
_make_stream_chunk(content=" world"),
|
||||
_make_stream_chunk(content="!", finish_reason="stop", model="test-model"),
|
||||
_make_empty_chunk(usage=SimpleNamespace(prompt_tokens=10, completion_tokens=3)),
|
||||
]
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create.return_value = iter(chunks)
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
agent.api_mode = "chat_completions"
|
||||
agent._interrupt_requested = False
|
||||
|
||||
response = agent._interruptible_streaming_api_call({})
|
||||
|
||||
assert response.choices[0].message.content == "Hello world!"
|
||||
assert response.choices[0].message.tool_calls is None
|
||||
assert response.choices[0].finish_reason == "stop"
|
||||
assert response.usage is not None
|
||||
assert response.usage.completion_tokens == 3
|
||||
|
||||
@patch("run_agent.AIAgent._create_request_openai_client")
|
||||
@patch("run_agent.AIAgent._close_request_openai_client")
|
||||
def test_tool_call_response(self, mock_close, mock_create):
|
||||
"""Tool call stream accumulates ID, name, and arguments."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
chunks = [
|
||||
_make_stream_chunk(tool_calls=[
|
||||
_make_tool_call_delta(index=0, tc_id="call_123", name="terminal")
|
||||
]),
|
||||
_make_stream_chunk(tool_calls=[
|
||||
_make_tool_call_delta(index=0, arguments='{"command":')
|
||||
]),
|
||||
_make_stream_chunk(tool_calls=[
|
||||
_make_tool_call_delta(index=0, arguments=' "ls"}')
|
||||
]),
|
||||
_make_stream_chunk(finish_reason="tool_calls"),
|
||||
]
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create.return_value = iter(chunks)
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
agent.api_mode = "chat_completions"
|
||||
agent._interrupt_requested = False
|
||||
|
||||
response = agent._interruptible_streaming_api_call({})
|
||||
|
||||
tc = response.choices[0].message.tool_calls
|
||||
assert tc is not None
|
||||
assert len(tc) == 1
|
||||
assert tc[0].id == "call_123"
|
||||
assert tc[0].function.name == "terminal"
|
||||
assert tc[0].function.arguments == '{"command": "ls"}'
|
||||
|
||||
@patch("run_agent.AIAgent._create_request_openai_client")
|
||||
@patch("run_agent.AIAgent._close_request_openai_client")
|
||||
def test_mixed_content_and_tool_calls(self, mock_close, mock_create):
|
||||
"""Stream with both text and tool calls accumulates both."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
chunks = [
|
||||
_make_stream_chunk(content="Let me check"),
|
||||
_make_stream_chunk(tool_calls=[
|
||||
_make_tool_call_delta(index=0, tc_id="call_456", name="web_search")
|
||||
]),
|
||||
_make_stream_chunk(tool_calls=[
|
||||
_make_tool_call_delta(index=0, arguments='{"query": "test"}')
|
||||
]),
|
||||
_make_stream_chunk(finish_reason="tool_calls"),
|
||||
]
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create.return_value = iter(chunks)
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
agent.api_mode = "chat_completions"
|
||||
agent._interrupt_requested = False
|
||||
|
||||
response = agent._interruptible_streaming_api_call({})
|
||||
|
||||
assert response.choices[0].message.content == "Let me check"
|
||||
assert len(response.choices[0].message.tool_calls) == 1
|
||||
|
||||
|
||||
# ── Test: Streaming Callbacks ────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestStreamingCallbacks:
|
||||
"""Verify that delta callbacks fire correctly."""
|
||||
|
||||
@patch("run_agent.AIAgent._create_request_openai_client")
|
||||
@patch("run_agent.AIAgent._close_request_openai_client")
|
||||
def test_deltas_fire_in_order(self, mock_close, mock_create):
|
||||
"""Callbacks receive text deltas in order."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
chunks = [
|
||||
_make_stream_chunk(content="a"),
|
||||
_make_stream_chunk(content="b"),
|
||||
_make_stream_chunk(content="c"),
|
||||
_make_stream_chunk(finish_reason="stop"),
|
||||
]
|
||||
|
||||
deltas = []
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create.return_value = iter(chunks)
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
stream_delta_callback=lambda t: deltas.append(t),
|
||||
)
|
||||
agent.api_mode = "chat_completions"
|
||||
agent._interrupt_requested = False
|
||||
|
||||
agent._interruptible_streaming_api_call({})
|
||||
|
||||
assert deltas == ["a", "b", "c"]
|
||||
|
||||
@patch("run_agent.AIAgent._create_request_openai_client")
|
||||
@patch("run_agent.AIAgent._close_request_openai_client")
|
||||
def test_on_first_delta_fires_once(self, mock_close, mock_create):
|
||||
"""on_first_delta callback fires exactly once."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
chunks = [
|
||||
_make_stream_chunk(content="a"),
|
||||
_make_stream_chunk(content="b"),
|
||||
_make_stream_chunk(finish_reason="stop"),
|
||||
]
|
||||
|
||||
first_delta_calls = []
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create.return_value = iter(chunks)
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
agent.api_mode = "chat_completions"
|
||||
agent._interrupt_requested = False
|
||||
|
||||
agent._interruptible_streaming_api_call(
|
||||
{}, on_first_delta=lambda: first_delta_calls.append(True)
|
||||
)
|
||||
|
||||
assert len(first_delta_calls) == 1
|
||||
|
||||
@patch("run_agent.AIAgent._create_request_openai_client")
|
||||
@patch("run_agent.AIAgent._close_request_openai_client")
|
||||
def test_tool_only_does_not_fire_callback(self, mock_close, mock_create):
|
||||
"""Tool-call-only stream does not fire the delta callback."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
chunks = [
|
||||
_make_stream_chunk(tool_calls=[
|
||||
_make_tool_call_delta(index=0, tc_id="call_789", name="terminal")
|
||||
]),
|
||||
_make_stream_chunk(tool_calls=[
|
||||
_make_tool_call_delta(index=0, arguments='{"command": "ls"}')
|
||||
]),
|
||||
_make_stream_chunk(finish_reason="tool_calls"),
|
||||
]
|
||||
|
||||
deltas = []
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create.return_value = iter(chunks)
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
stream_delta_callback=lambda t: deltas.append(t),
|
||||
)
|
||||
agent.api_mode = "chat_completions"
|
||||
agent._interrupt_requested = False
|
||||
|
||||
agent._interruptible_streaming_api_call({})
|
||||
|
||||
assert deltas == []
|
||||
|
||||
@patch("run_agent.AIAgent._create_request_openai_client")
@patch("run_agent.AIAgent._close_request_openai_client")
def test_text_suppressed_when_tool_calls_present(self, mock_close, mock_create):
    """Text deltas stop firing once a tool call appears in the stream."""
    from run_agent import AIAgent

    fired = []

    # Stream shape: text, then a tool call, then more text, then finish.
    stream_chunks = iter([
        _make_stream_chunk(content="thinking..."),
        _make_stream_chunk(
            tool_calls=[_make_tool_call_delta(index=0, tc_id="call_abc", name="read_file")]
        ),
        _make_stream_chunk(content=" more text"),
        _make_stream_chunk(finish_reason="tool_calls"),
    ])

    fake_client = MagicMock()
    fake_client.chat.completions.create.return_value = stream_chunks
    mock_create.return_value = fake_client

    agent = AIAgent(
        model="test/model",
        quiet_mode=True,
        skip_context_files=True,
        skip_memory=True,
        stream_delta_callback=fired.append,
    )
    agent.api_mode = "chat_completions"
    agent._interrupt_requested = False

    response = agent._interruptible_streaming_api_call({})

    # Text before the tool call IS fired (we don't know yet it will have tools)
    assert "thinking..." in fired
    # Text after the tool call is NOT fired
    assert " more text" not in fired
    # But content is still accumulated in the response
    assert response.choices[0].message.content == "thinking... more text"
|
||||
|
||||
|
||||
# ── Test: Streaming Fallback ────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestStreamingFallback:
    """Verify fallback to non-streaming on ANY streaming error.

    Each test forces the streaming endpoint to raise, then checks that
    ``_interruptible_streaming_api_call`` retries through the non-streaming
    ``_interruptible_api_call`` path — or, when that also fails, that the
    fallback's error propagates to the caller.
    """

    @staticmethod
    def _fallback_response(content):
        """Build a minimal non-streaming chat-completion response stub.

        Factored out because the three tests previously duplicated this
        ~20-line ``SimpleNamespace`` construction verbatim.
        """
        return SimpleNamespace(
            id="fallback",
            model="test",
            choices=[SimpleNamespace(
                index=0,
                message=SimpleNamespace(
                    role="assistant",
                    content=content,
                    tool_calls=None,
                    reasoning_content=None,
                ),
                finish_reason="stop",
            )],
            usage=None,
        )

    @staticmethod
    def _make_agent():
        """Create a minimal chat-completions agent ready for a streaming call."""
        from run_agent import AIAgent

        agent = AIAgent(
            model="test/model",
            quiet_mode=True,
            skip_context_files=True,
            skip_memory=True,
        )
        agent.api_mode = "chat_completions"
        agent._interrupt_requested = False
        return agent

    @patch("run_agent.AIAgent._interruptible_api_call")
    @patch("run_agent.AIAgent._create_request_openai_client")
    @patch("run_agent.AIAgent._close_request_openai_client")
    def test_stream_error_falls_back(self, mock_close, mock_create, mock_non_stream):
        """'not supported' error triggers fallback to non-streaming."""
        mock_client = MagicMock()
        mock_client.chat.completions.create.side_effect = Exception(
            "Streaming is not supported for this model"
        )
        mock_create.return_value = mock_client
        mock_non_stream.return_value = self._fallback_response("fallback response")

        agent = self._make_agent()
        response = agent._interruptible_streaming_api_call({})

        assert response.choices[0].message.content == "fallback response"
        mock_non_stream.assert_called_once()

    @patch("run_agent.AIAgent._interruptible_api_call")
    @patch("run_agent.AIAgent._create_request_openai_client")
    @patch("run_agent.AIAgent._close_request_openai_client")
    def test_any_stream_error_falls_back(self, mock_close, mock_create, mock_non_stream):
        """ANY streaming error triggers fallback — not just specific messages."""
        mock_client = MagicMock()
        mock_client.chat.completions.create.side_effect = Exception(
            "Connection reset by peer"
        )
        mock_create.return_value = mock_client
        mock_non_stream.return_value = self._fallback_response(
            "fallback after connection error"
        )

        agent = self._make_agent()
        response = agent._interruptible_streaming_api_call({})

        assert response.choices[0].message.content == "fallback after connection error"
        mock_non_stream.assert_called_once()

    @patch("run_agent.AIAgent._interruptible_api_call")
    @patch("run_agent.AIAgent._create_request_openai_client")
    @patch("run_agent.AIAgent._close_request_openai_client")
    def test_fallback_error_propagates(self, mock_close, mock_create, mock_non_stream):
        """When both streaming AND fallback fail, the fallback error propagates."""
        mock_client = MagicMock()
        mock_client.chat.completions.create.side_effect = Exception("stream broke")
        mock_create.return_value = mock_client

        # The non-streaming retry also fails — its error must surface.
        mock_non_stream.side_effect = Exception("Rate limit exceeded")

        agent = self._make_agent()

        with pytest.raises(Exception, match="Rate limit exceeded"):
            agent._interruptible_streaming_api_call({})
|
||||
|
||||
|
||||
# ── Test: Reasoning Streaming ────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestReasoningStreaming:
    """Verify reasoning content is accumulated and callback fires."""

    @patch("run_agent.AIAgent._create_request_openai_client")
    @patch("run_agent.AIAgent._close_request_openai_client")
    def test_reasoning_callback_fires(self, mock_close, mock_create):
        """Reasoning deltas fire the reasoning_callback."""
        from run_agent import AIAgent

        captured_reasoning = []
        captured_text = []

        # Stream shape: two reasoning deltas, one text delta, then stop.
        stream_chunks = iter([
            _make_stream_chunk(reasoning_content="Let me think"),
            _make_stream_chunk(reasoning_content=" about this"),
            _make_stream_chunk(content="The answer is 42"),
            _make_stream_chunk(finish_reason="stop"),
        ])

        fake_client = MagicMock()
        fake_client.chat.completions.create.return_value = stream_chunks
        mock_create.return_value = fake_client

        agent = AIAgent(
            model="test/model",
            quiet_mode=True,
            skip_context_files=True,
            skip_memory=True,
            stream_delta_callback=captured_text.append,
            reasoning_callback=captured_reasoning.append,
        )
        agent.api_mode = "chat_completions"
        agent._interrupt_requested = False

        response = agent._interruptible_streaming_api_call({})

        # Each delta kind is routed to its own callback...
        assert captured_reasoning == ["Let me think", " about this"]
        assert captured_text == ["The answer is 42"]
        # ...and both are accumulated on the final response message.
        final_message = response.choices[0].message
        assert final_message.reasoning_content == "Let me think about this"
        assert final_message.content == "The answer is 42"
|
||||
|
||||
|
||||
# ── Test: _has_stream_consumers ──────────────────────────────────────────
|
||||
|
||||
|
||||
class TestHasStreamConsumers:
    """Verify _has_stream_consumers() detects registered callbacks."""

    @staticmethod
    def _agent(**kwargs):
        """Build a minimal AIAgent, forwarding extra constructor kwargs.

        Factored out because all three tests previously duplicated the
        same four-argument constructor call.
        """
        from run_agent import AIAgent

        return AIAgent(
            model="test/model",
            quiet_mode=True,
            skip_context_files=True,
            skip_memory=True,
            **kwargs,
        )

    def test_no_consumers(self):
        # No callbacks registered anywhere -> no stream consumers.
        agent = self._agent()
        assert agent._has_stream_consumers() is False

    def test_delta_callback_set(self):
        # A constructor-registered delta callback counts as a consumer.
        agent = self._agent(stream_delta_callback=lambda t: None)
        assert agent._has_stream_consumers() is True

    def test_stream_callback_set(self):
        # A callback attached after construction also counts.
        agent = self._agent()
        agent._stream_callback = lambda t: None
        assert agent._has_stream_consumers() is True
|
||||
|
||||
|
||||
# ── Test: Codex stream fires callbacks ────────────────────────────────
|
||||
|
||||
|
||||
class TestCodexStreamCallbacks:
    """Verify _run_codex_stream fires delta callbacks."""

    def test_codex_text_delta_fires_callback(self):
        """A Codex `response.output_text.delta` event reaches the delta callback."""
        from run_agent import AIAgent

        deltas = []

        agent = AIAgent(
            model="test/model",
            quiet_mode=True,
            skip_context_files=True,
            skip_memory=True,
            stream_delta_callback=deltas.append,
        )
        agent.api_mode = "codex_responses"
        agent._interrupt_requested = False

        # Two events: one text delta, then stream completion.
        mock_event_text = SimpleNamespace(
            type="response.output_text.delta",
            delta="Hello from Codex!",
        )
        mock_event_done = SimpleNamespace(
            type="response.completed",
            delta="",
        )

        # The Codex stream is used as a context manager and iterated,
        # then asked for the final assembled response.
        mock_stream = MagicMock()
        mock_stream.__enter__ = MagicMock(return_value=mock_stream)
        mock_stream.__exit__ = MagicMock(return_value=False)
        mock_stream.__iter__ = MagicMock(return_value=iter([mock_event_text, mock_event_done]))
        mock_stream.get_final_response.return_value = SimpleNamespace(
            output=[SimpleNamespace(
                type="message",
                content=[SimpleNamespace(type="output_text", text="Hello from Codex!")],
            )],
            status="completed",
        )

        mock_client = MagicMock()
        mock_client.responses.stream.return_value = mock_stream

        # Return value is irrelevant here (previously bound to an unused
        # local); the assertion is on the callback side effect only.
        agent._run_codex_stream({}, client=mock_client)
        assert "Hello from Codex!" in deltas
|
||||
|
|
@ -855,6 +855,7 @@ display:
|
|||
resume_display: full # full (show previous messages on resume) | minimal (one-liner only)
|
||||
bell_on_complete: false # Play terminal bell when agent finishes (great for long tasks)
|
||||
show_reasoning: false # Show model reasoning/thinking above each response (toggle with /reasoning show|hide)
|
||||
streaming: false # Stream tokens to terminal as they arrive (real-time output)
|
||||
background_process_notifications: all # all | result | error | off (gateway only)
|
||||
```
|
||||
|
||||
|
|
@ -928,6 +929,36 @@ voice:
|
|||
|
||||
Use `/voice on` in the CLI to enable microphone mode, `record_key` to start/stop recording, and `/voice tts` to toggle spoken replies. See [Voice Mode](/docs/user-guide/features/voice-mode) for end-to-end setup and platform-specific behavior.
|
||||
|
||||
## Streaming
|
||||
|
||||
Stream tokens to the terminal or messaging platforms as they arrive, instead of waiting for the full response.
|
||||
|
||||
### CLI Streaming
|
||||
|
||||
```yaml
|
||||
display:
|
||||
streaming: true # Stream tokens to terminal in real-time
|
||||
show_reasoning: true # Also stream reasoning/thinking tokens (optional)
|
||||
```
|
||||
|
||||
When enabled, responses appear token-by-token inside a streaming box. Tool calls are still captured silently. If the provider doesn't support streaming, it falls back to the normal display automatically.
|
||||
|
||||
### Gateway Streaming (Telegram, Discord, Slack)
|
||||
|
||||
```yaml
|
||||
streaming:
|
||||
enabled: true # Enable progressive message editing
|
||||
edit_interval: 0.3 # Seconds between message edits
|
||||
buffer_threshold: 40 # Characters before forcing an edit flush
|
||||
cursor: " ▉" # Cursor shown during streaming
|
||||
```
|
||||
|
||||
When enabled, the bot sends a message on the first token, then progressively edits it as more tokens arrive. Platforms that don't support message editing (Signal, Email) gracefully skip streaming and deliver the final response normally.
|
||||
|
||||
:::note
|
||||
Streaming is disabled by default. Enable it in `~/.hermes/config.yaml` to try the streaming UX.
|
||||
:::
|
||||
|
||||
## Group Chat Session Isolation
|
||||
|
||||
Control whether shared chats keep one conversation per room or one conversation per participant:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue