feat: unified streaming infrastructure — core delta callbacks for all providers
Stage 1 of streaming support. Adds:

- stream_delta_callback parameter on AIAgent.__init__ for real-time token delivery
- _interruptible_streaming_api_call() handling chat_completions + anthropic_messages
- Enhanced _run_codex_stream() to fire delta callbacks during Codex streaming
- _fire_stream_delta() fires both display and TTS callbacks
- _fire_reasoning_delta() for reasoning content streaming
- Tool-call suppression: callbacks only fire on text-only responses
- on_first_delta callback for spinner control on first token
- Provider fallback: graceful degradation to non-streaming
- _has_stream_consumers() unifies stream_delta_callback and _stream_callback checks
- Anthropic streaming returns native Message for downstream compatibility

A consumer-side sketch of the new callback contract is included below for review context.

Drawing from PRs #922 (unified streaming), #1312 (gateway consumer), #774 (Telegram streaming), #798 (CLI streaming), #1214 (reasoning modes).

Credit: jobless0x, OutThisLife, clicksingh, raulvidis.
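Reviewer note: a minimal consumer-side sketch of the new callback contract. This is not code from the repo; the import path and the bare constructor call are assumptions, and only the stream_delta_callback and reasoning_callback parameter names come from this diff.

    # Hypothetical wiring for the Stage 1 callbacks (sketch, not repo code).
    import sys

    from run_agent import AIAgent  # assumed import path; run_agent.py is the file changed here

    def on_delta(text: str) -> None:
        # Receives each text token in real time; turns containing tool calls
        # are suppressed, so partial tool-call JSON never reaches the display.
        sys.stdout.write(text)
        sys.stdout.flush()

    def on_reasoning(text: str) -> None:
        # Reasoning/thinking tokens arrive on the separate reasoning channel.
        sys.stderr.write(text)

    agent = AIAgent(  # any other required constructor arguments omitted for brevity
        stream_delta_callback=on_delta,   # new in this commit
        reasoning_callback=on_reasoning,  # existing parameter, now also fed per-delta
    )

If the provider rejects the streaming request, _interruptible_streaming_api_call degrades to the existing non-streaming path, so a consumer still receives the final text, just not token by token.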
parent 3543b755af
commit c1ac32737d

1 changed file with 301 additions and 32 deletions

run_agent.py
@@ -296,6 +296,7 @@ class AIAgent:
         reasoning_callback: callable = None,
         clarify_callback: callable = None,
         step_callback: callable = None,
+        stream_delta_callback: callable = None,
         max_tokens: int = None,
         reasoning_config: Dict[str, Any] = None,
         prefill_messages: List[Dict[str, Any]] = None,
@@ -395,6 +396,7 @@ class AIAgent:
         self.reasoning_callback = reasoning_callback
         self.clarify_callback = clarify_callback
         self.step_callback = step_callback
+        self.stream_delta_callback = stream_delta_callback
         self._last_reported_tool = None  # Track for "new tool" mode

         # Interrupt mechanism for breaking out of tool loops
@@ -856,9 +858,9 @@ class AIAgent:
         """Verbose print — suppressed when streaming TTS is active.

         Pass ``force=True`` for error/warning messages that should always be
-        shown even during streaming TTS playback.
+        shown even during streaming playback (TTS or display).
         """
-        if not force and getattr(self, "_stream_callback", None) is not None:
+        if not force and self._has_stream_consumers():
             return
         print(*args, **kwargs)

@@ -2606,11 +2608,30 @@ class AIAgent:
         """Execute one streaming Responses API request and return the final response."""
         active_client = client or self._ensure_primary_openai_client(reason="codex_stream_direct")
         max_stream_retries = 1
+        has_tool_calls = False
+        first_delta_fired = False
         for attempt in range(max_stream_retries + 1):
             try:
                 with active_client.responses.stream(**api_kwargs) as stream:
-                    for _ in stream:
-                        pass
+                    for event in stream:
+                        if self._interrupt_requested:
+                            break
+                        event_type = getattr(event, "type", "")
+                        # Fire callbacks on text content deltas (suppress during tool calls)
+                        if "output_text.delta" in event_type or event_type == "response.output_text.delta":
+                            delta_text = getattr(event, "delta", "")
+                            if delta_text and not has_tool_calls:
+                                if not first_delta_fired:
+                                    first_delta_fired = True
+                                self._fire_stream_delta(delta_text)
+                        # Track tool calls to suppress text streaming
+                        elif "function_call" in event_type:
+                            has_tool_calls = True
+                        # Fire reasoning callbacks
+                        elif "reasoning" in event_type and "delta" in event_type:
+                            reasoning_text = getattr(event, "delta", "")
+                            if reasoning_text:
+                                self._fire_reasoning_delta(reasoning_text)
                     return stream.get_final_response()
             except RuntimeError as exc:
                 err_text = str(exc)
@@ -2972,6 +2993,265 @@ class AIAgent:
             raise result["error"]
         return result["response"]

+    # ── Unified streaming API call ─────────────────────────────────────────
+
+    def _fire_stream_delta(self, text: str) -> None:
+        """Fire all registered stream delta callbacks (display + TTS)."""
+        for cb in (self.stream_delta_callback, self._stream_callback):
+            if cb is not None:
+                try:
+                    cb(text)
+                except Exception:
+                    pass
+
+    def _fire_reasoning_delta(self, text: str) -> None:
+        """Fire reasoning callback if registered."""
+        cb = self.reasoning_callback
+        if cb is not None:
+            try:
+                cb(text)
+            except Exception:
+                pass
+
+    def _has_stream_consumers(self) -> bool:
+        """Return True if any streaming consumer is registered."""
+        return (
+            self.stream_delta_callback is not None
+            or getattr(self, "_stream_callback", None) is not None
+        )
+
+    def _interruptible_streaming_api_call(
+        self, api_kwargs: dict, *, on_first_delta: callable = None
+    ):
+        """Streaming variant of _interruptible_api_call for real-time token delivery.
+
+        Handles all three api_modes:
+        - chat_completions: stream=True on OpenAI-compatible endpoints
+        - anthropic_messages: client.messages.stream() via Anthropic SDK
+        - codex_responses: delegates to _run_codex_stream (already streaming)
+
+        Fires stream_delta_callback and _stream_callback for each text token.
+        Tool-call turns suppress the callback — only text-only final responses
+        stream to the consumer. Returns a SimpleNamespace that mimics the
+        non-streaming response shape so the rest of the agent loop is unchanged.
+
+        Falls back to _interruptible_api_call on provider errors indicating
+        streaming is not supported.
+        """
+        if self.api_mode == "codex_responses":
+            # Codex already streams internally; we just need to pass callbacks
+            return self._interruptible_api_call(api_kwargs)
+
+        result = {"response": None, "error": None}
+        request_client_holder = {"client": None}
+        first_delta_fired = {"done": False}
+
+        def _fire_first_delta():
+            if not first_delta_fired["done"] and on_first_delta:
+                first_delta_fired["done"] = True
+                try:
+                    on_first_delta()
+                except Exception:
+                    pass
+
+        def _call_chat_completions():
+            """Stream a chat completions response."""
+            stream_kwargs = {**api_kwargs, "stream": True, "stream_options": {"include_usage": True}}
+            request_client_holder["client"] = self._create_request_openai_client(
+                reason="chat_completion_stream_request"
+            )
+            stream = request_client_holder["client"].chat.completions.create(**stream_kwargs)
+
+            content_parts: list = []
+            tool_calls_acc: dict = {}
+            finish_reason = None
+            model_name = None
+            role = "assistant"
+            reasoning_parts: list = []
+            usage_obj = None
+
+            for chunk in stream:
+                if self._interrupt_requested:
+                    break
+
+                if not chunk.choices:
+                    if hasattr(chunk, "model") and chunk.model:
+                        model_name = chunk.model
+                    # Usage comes in the final chunk with empty choices
+                    if hasattr(chunk, "usage") and chunk.usage:
+                        usage_obj = chunk.usage
+                    continue
+
+                delta = chunk.choices[0].delta
+                if hasattr(chunk, "model") and chunk.model:
+                    model_name = chunk.model
+
+                # Accumulate reasoning content
+                reasoning_text = getattr(delta, "reasoning_content", None) or getattr(delta, "reasoning", None)
+                if reasoning_text:
+                    reasoning_parts.append(reasoning_text)
+                    self._fire_reasoning_delta(reasoning_text)
+
+                # Accumulate text content — fire callback only when no tool calls
+                if delta and delta.content:
+                    content_parts.append(delta.content)
+                    if not tool_calls_acc:
+                        _fire_first_delta()
+                        self._fire_stream_delta(delta.content)
+
+                # Accumulate tool call deltas (silently, no callback)
+                if delta and delta.tool_calls:
+                    for tc_delta in delta.tool_calls:
+                        idx = tc_delta.index if tc_delta.index is not None else 0
+                        if idx not in tool_calls_acc:
+                            tool_calls_acc[idx] = {
+                                "id": tc_delta.id or "",
+                                "type": "function",
+                                "function": {"name": "", "arguments": ""},
+                            }
+                        entry = tool_calls_acc[idx]
+                        if tc_delta.id:
+                            entry["id"] = tc_delta.id
+                        if tc_delta.function:
+                            if tc_delta.function.name:
+                                entry["function"]["name"] += tc_delta.function.name
+                            if tc_delta.function.arguments:
+                                entry["function"]["arguments"] += tc_delta.function.arguments
+
+                if chunk.choices[0].finish_reason:
+                    finish_reason = chunk.choices[0].finish_reason
+
+                # Usage in the final chunk
+                if hasattr(chunk, "usage") and chunk.usage:
+                    usage_obj = chunk.usage
+
+            # Build mock response matching non-streaming shape
+            full_content = "".join(content_parts) or None
+            mock_tool_calls = None
+            if tool_calls_acc:
+                mock_tool_calls = []
+                for idx in sorted(tool_calls_acc):
+                    tc = tool_calls_acc[idx]
+                    mock_tool_calls.append(SimpleNamespace(
+                        id=tc["id"],
+                        type=tc["type"],
+                        function=SimpleNamespace(
+                            name=tc["function"]["name"],
+                            arguments=tc["function"]["arguments"],
+                        ),
+                    ))
+
+            full_reasoning = "".join(reasoning_parts) or None
+            mock_message = SimpleNamespace(
+                role=role,
+                content=full_content,
+                tool_calls=mock_tool_calls,
+                reasoning_content=full_reasoning,
+            )
+            mock_choice = SimpleNamespace(
+                index=0,
+                message=mock_message,
+                finish_reason=finish_reason or "stop",
+            )
+            return SimpleNamespace(
+                id="stream-" + str(uuid.uuid4()),
+                model=model_name,
+                choices=[mock_choice],
+                usage=usage_obj,
+            )
+
+        def _call_anthropic():
+            """Stream an Anthropic Messages API response.
+
+            Fires delta callbacks for real-time token delivery, but returns
+            the native Anthropic Message object from get_final_message() so
+            the rest of the agent loop (validation, tool extraction, etc.)
+            works unchanged.
+            """
+            has_tool_use = False
+
+            # Use the Anthropic SDK's streaming context manager
+            with self._anthropic_client.messages.stream(**api_kwargs) as stream:
+                for event in stream:
+                    if self._interrupt_requested:
+                        break
+
+                    event_type = getattr(event, "type", None)
+
+                    if event_type == "content_block_start":
+                        block = getattr(event, "content_block", None)
+                        if block and getattr(block, "type", None) == "tool_use":
+                            has_tool_use = True
+
+                    elif event_type == "content_block_delta":
+                        delta = getattr(event, "delta", None)
+                        if delta:
+                            delta_type = getattr(delta, "type", None)
+                            if delta_type == "text_delta":
+                                text = getattr(delta, "text", "")
+                                if text and not has_tool_use:
+                                    _fire_first_delta()
+                                    self._fire_stream_delta(text)
+                            elif delta_type == "thinking_delta":
+                                thinking_text = getattr(delta, "thinking", "")
+                                if thinking_text:
+                                    self._fire_reasoning_delta(thinking_text)
+
+                # Return the native Anthropic Message for downstream processing
+                return stream.get_final_message()
+
+        def _call():
+            try:
+                if self.api_mode == "anthropic_messages":
+                    self._try_refresh_anthropic_client_credentials()
+                    result["response"] = _call_anthropic()
+                else:
+                    result["response"] = _call_chat_completions()
+            except Exception as e:
+                err_text = str(e).lower()
+                # Fall back to non-streaming if provider doesn't support it
+                stream_unsupported = any(
+                    kw in err_text
+                    for kw in ("stream", "not support", "unsupported", "not available")
+                )
+                if stream_unsupported:
+                    logger.info("Streaming not supported by provider, falling back to non-streaming: %s", e)
+                    try:
+                        result["response"] = self._interruptible_api_call(api_kwargs)
+                    except Exception as fallback_err:
+                        result["error"] = fallback_err
+                else:
+                    result["error"] = e
+            finally:
+                request_client = request_client_holder.get("client")
+                if request_client is not None:
+                    self._close_request_openai_client(request_client, reason="stream_request_complete")
+
+        t = threading.Thread(target=_call, daemon=True)
+        t.start()
+        while t.is_alive():
+            t.join(timeout=0.3)
+            if self._interrupt_requested:
+                try:
+                    if self.api_mode == "anthropic_messages":
+                        from agent.anthropic_adapter import build_anthropic_client
+
+                        self._anthropic_client.close()
+                        self._anthropic_client = build_anthropic_client(
+                            self._anthropic_api_key,
+                            getattr(self, "_anthropic_base_url", None),
+                        )
+                    else:
+                        request_client = request_client_holder.get("client")
+                        if request_client is not None:
+                            self._close_request_openai_client(request_client, reason="stream_interrupt_abort")
+                except Exception:
+                    pass
+                raise InterruptedError("Agent interrupted during streaming API call")
+        if result["error"] is not None:
+            raise result["error"]
+        return result["response"]
+
     # ── Provider fallback ──────────────────────────────────────────────────

     def _try_activate_fallback(self) -> bool:
@@ -4172,7 +4452,7 @@ class AIAgent:
                 spinner.stop(cute_msg)
             elif self.quiet_mode:
                 self._vprint(f" {cute_msg}")
-            elif self.quiet_mode and self._stream_callback is None:
+            elif self.quiet_mode and not self._has_stream_consumers():
                 face = random.choice(KawaiiSpinner.KAWAII_WAITING)
                 emoji = _get_tool_emoji(function_name)
                 preview = _build_tool_preview(function_name, function_args) or function_name
@@ -4807,8 +5087,8 @@ class AIAgent:
                 self._vprint(f"\n{self.log_prefix}🔄 Making API call #{api_call_count}/{self.max_iterations}...")
                 self._vprint(f"{self.log_prefix} 📊 Request size: {len(api_messages)} messages, ~{approx_tokens:,} tokens (~{total_chars:,} chars)")
                 self._vprint(f"{self.log_prefix} 🔧 Available tools: {len(self.tools) if self.tools else 0}")
-            elif self._stream_callback is None:
-                # Animated thinking spinner in quiet mode (skip during streaming TTS)
+            elif not self._has_stream_consumers():
+                # Animated thinking spinner in quiet mode (skip during streaming)
                 face = random.choice(KawaiiSpinner.KAWAII_THINKING)
                 verb = random.choice(KawaiiSpinner.THINKING_VERBS)
                 if self.thinking_callback:
@@ -4848,33 +5128,22 @@ class AIAgent:
             if os.getenv("HERMES_DUMP_REQUESTS", "").strip().lower() in {"1", "true", "yes", "on"}:
                 self._dump_api_request_debug(api_kwargs, reason="preflight")

-            cb = getattr(self, "_stream_callback", None)
-            if cb is not None and self.api_mode == "chat_completions":
-                response = self._streaming_api_call(api_kwargs, cb)
+            if self._has_stream_consumers():
+                # Streaming path: fire delta callbacks for real-time
+                # token delivery to CLI display, gateway, or TTS.
+                def _stop_spinner():
+                    nonlocal thinking_spinner
+                    if thinking_spinner:
+                        thinking_spinner.stop("")
+                        thinking_spinner = None
+                    if self.thinking_callback:
+                        self.thinking_callback("")
+
+                response = self._interruptible_streaming_api_call(
+                    api_kwargs, on_first_delta=_stop_spinner
+                )
             else:
                 response = self._interruptible_api_call(api_kwargs)
-                # Forward full response to TTS callback for non-streaming providers
-                # (e.g. Anthropic) so voice TTS still works via batch delivery.
-                if cb is not None and response:
-                    try:
-                        content = None
-                        # Try choices first — _interruptible_api_call converts all
-                        # providers (including Anthropic) to this format.
-                        try:
-                            content = response.choices[0].message.content
-                        except (AttributeError, IndexError):
-                            pass
-                        # Fallback: Anthropic native content blocks
-                        if not content and self.api_mode == "anthropic_messages":
-                            text_parts = [
-                                block.text for block in getattr(response, "content", [])
-                                if getattr(block, "type", None) == "text" and getattr(block, "text", None)
-                            ]
-                            content = " ".join(text_parts) if text_parts else None
-                        if content:
-                            cb(content)
-                    except Exception:
-                        pass

             api_duration = time.time() - api_start_time