fix: audit fixes — 5 bugs found and resolved

Thorough code review found 5 issues across run_agent.py, cli.py, and gateway/:

1. CRITICAL — Gateway stream consumer task never started: `stream_consumer_holder` was checked BEFORE `run_sync` populated it. Fixed with an async polling pattern (same as `track_agent`).
2. MEDIUM-HIGH — Streaming fallback after partial delivery caused a double response: if streaming failed after some tokens were delivered, the fallback would re-deliver the full response. Now tracks `deltas_were_sent` and only falls back when no tokens have reached consumers yet.
3. MEDIUM — Codex mode lost the `on_first_delta` spinner callback: `_run_codex_stream` now accepts an `on_first_delta` parameter and fires it on the first text delta. Passed through from `_interruptible_streaming_api_call` via the `_codex_on_first_delta` instance attribute.
4. MEDIUM — CLI close-tag after-text bypassed tag filtering: text after a reasoning close tag was sent directly to `_emit_stream_text`, skipping open-tag detection. Now routes through `_stream_delta` for full filtering.
5. LOW — Removed 140 lines of dead code: the old `_streaming_api_call` method (superseded by `_interruptible_streaming_api_call`). Updated 13 tests in test_run_agent.py and test_openai_client_lifecycle.py to use the new method name and signature.

4573 tests passing.
This commit is contained in:
parent
99369b926c
commit
8e07f9ca56
5 changed files with 75 additions and 176 deletions
5
cli.py
5
cli.py
|
|
@ -1474,9 +1474,10 @@ class HermesCLI:
|
||||||
self._in_reasoning_block = False
|
self._in_reasoning_block = False
|
||||||
after = self._stream_prefilt[idx + len(tag):]
|
after = self._stream_prefilt[idx + len(tag):]
|
||||||
self._stream_prefilt = ""
|
self._stream_prefilt = ""
|
||||||
# Process remaining text after close tag
|
# Process remaining text after close tag through full
|
||||||
|
# filtering (it could contain another open tag)
|
||||||
if after:
|
if after:
|
||||||
self._emit_stream_text(after)
|
self._stream_delta(after)
|
||||||
return
|
return
|
||||||
# Still inside reasoning block — keep only the tail that could
|
# Still inside reasoning block — keep only the tail that could
|
||||||
# be a partial close tag prefix (save memory on long blocks).
|
# be a partial close tag prefix (save memory on long blocks).
|
||||||
|
|
|
||||||
|
|
@ -4371,10 +4371,19 @@ class GatewayRunner:
|
||||||
if tool_progress_enabled:
|
if tool_progress_enabled:
|
||||||
progress_task = asyncio.create_task(send_progress_messages())
|
progress_task = asyncio.create_task(send_progress_messages())
|
||||||
|
|
||||||
# Start stream consumer task if configured
|
# Start stream consumer task — polls for consumer creation since it
|
||||||
|
# happens inside run_sync (thread pool) after the agent is constructed.
|
||||||
stream_task = None
|
stream_task = None
|
||||||
if stream_consumer_holder[0] is not None:
|
|
||||||
stream_task = asyncio.create_task(stream_consumer_holder[0].run())
|
async def _start_stream_consumer():
|
||||||
|
"""Wait for the stream consumer to be created, then run it."""
|
||||||
|
for _ in range(200): # Up to 10s wait
|
||||||
|
if stream_consumer_holder[0] is not None:
|
||||||
|
await stream_consumer_holder[0].run()
|
||||||
|
return
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
|
||||||
|
stream_task = asyncio.create_task(_start_stream_consumer())
|
||||||
|
|
||||||
# Track this agent as running for this session (for interrupt support)
|
# Track this agent as running for this session (for interrupt support)
|
||||||
# We do this in a callback after the agent is created
|
# We do this in a callback after the agent is created
|
||||||
|
|
|
||||||
188
run_agent.py
188
run_agent.py
|
|
@ -2604,7 +2604,7 @@ class AIAgent:
|
||||||
def _close_request_openai_client(self, client: Any, *, reason: str) -> None:
|
def _close_request_openai_client(self, client: Any, *, reason: str) -> None:
|
||||||
self._close_openai_client(client, reason=reason, shared=False)
|
self._close_openai_client(client, reason=reason, shared=False)
|
||||||
|
|
||||||
def _run_codex_stream(self, api_kwargs: dict, client: Any = None):
|
def _run_codex_stream(self, api_kwargs: dict, client: Any = None, on_first_delta: callable = None):
|
||||||
"""Execute one streaming Responses API request and return the final response."""
|
"""Execute one streaming Responses API request and return the final response."""
|
||||||
active_client = client or self._ensure_primary_openai_client(reason="codex_stream_direct")
|
active_client = client or self._ensure_primary_openai_client(reason="codex_stream_direct")
|
||||||
max_stream_retries = 1
|
max_stream_retries = 1
|
||||||
|
|
@ -2623,6 +2623,11 @@ class AIAgent:
|
||||||
if delta_text and not has_tool_calls:
|
if delta_text and not has_tool_calls:
|
||||||
if not first_delta_fired:
|
if not first_delta_fired:
|
||||||
first_delta_fired = True
|
first_delta_fired = True
|
||||||
|
if on_first_delta:
|
||||||
|
try:
|
||||||
|
on_first_delta()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
self._fire_stream_delta(delta_text)
|
self._fire_stream_delta(delta_text)
|
||||||
# Track tool calls to suppress text streaming
|
# Track tool calls to suppress text streaming
|
||||||
elif "function_call" in event_type:
|
elif "function_call" in event_type:
|
||||||
|
|
@ -2812,6 +2817,7 @@ class AIAgent:
|
||||||
result["response"] = self._run_codex_stream(
|
result["response"] = self._run_codex_stream(
|
||||||
api_kwargs,
|
api_kwargs,
|
||||||
client=request_client_holder["client"],
|
client=request_client_holder["client"],
|
||||||
|
on_first_delta=getattr(self, "_codex_on_first_delta", None),
|
||||||
)
|
)
|
||||||
elif self.api_mode == "anthropic_messages":
|
elif self.api_mode == "anthropic_messages":
|
||||||
result["response"] = self._anthropic_messages_create(api_kwargs)
|
result["response"] = self._anthropic_messages_create(api_kwargs)
|
||||||
|
|
@ -2853,146 +2859,6 @@ class AIAgent:
|
||||||
raise result["error"]
|
raise result["error"]
|
||||||
return result["response"]
|
return result["response"]
|
||||||
|
|
||||||
def _streaming_api_call(self, api_kwargs: dict, stream_callback):
|
|
||||||
"""Streaming variant of _interruptible_api_call for voice TTS pipeline.
|
|
||||||
|
|
||||||
Uses ``stream=True`` and forwards content deltas to *stream_callback*
|
|
||||||
in real-time. Returns a ``SimpleNamespace`` that mimics a normal
|
|
||||||
``ChatCompletion`` so the rest of the agent loop works unchanged.
|
|
||||||
|
|
||||||
This method is separate from ``_interruptible_api_call`` to keep the
|
|
||||||
core agent loop untouched for non-voice users.
|
|
||||||
"""
|
|
||||||
result = {"response": None, "error": None}
|
|
||||||
request_client_holder = {"client": None}
|
|
||||||
|
|
||||||
def _call():
|
|
||||||
try:
|
|
||||||
stream_kwargs = {**api_kwargs, "stream": True}
|
|
||||||
request_client_holder["client"] = self._create_request_openai_client(
|
|
||||||
reason="chat_completion_stream_request"
|
|
||||||
)
|
|
||||||
stream = request_client_holder["client"].chat.completions.create(**stream_kwargs)
|
|
||||||
|
|
||||||
content_parts: list[str] = []
|
|
||||||
tool_calls_acc: dict[int, dict] = {}
|
|
||||||
finish_reason = None
|
|
||||||
model_name = None
|
|
||||||
role = "assistant"
|
|
||||||
|
|
||||||
for chunk in stream:
|
|
||||||
if not chunk.choices:
|
|
||||||
if hasattr(chunk, "model") and chunk.model:
|
|
||||||
model_name = chunk.model
|
|
||||||
continue
|
|
||||||
|
|
||||||
delta = chunk.choices[0].delta
|
|
||||||
if hasattr(chunk, "model") and chunk.model:
|
|
||||||
model_name = chunk.model
|
|
||||||
|
|
||||||
if delta and delta.content:
|
|
||||||
content_parts.append(delta.content)
|
|
||||||
try:
|
|
||||||
stream_callback(delta.content)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if delta and delta.tool_calls:
|
|
||||||
for tc_delta in delta.tool_calls:
|
|
||||||
idx = tc_delta.index if tc_delta.index is not None else 0
|
|
||||||
if idx in tool_calls_acc and tc_delta.id and tc_delta.id != tool_calls_acc[idx]["id"]:
|
|
||||||
matched = False
|
|
||||||
for eidx, eentry in tool_calls_acc.items():
|
|
||||||
if eentry["id"] == tc_delta.id:
|
|
||||||
idx = eidx
|
|
||||||
matched = True
|
|
||||||
break
|
|
||||||
if not matched:
|
|
||||||
idx = (max(k for k in tool_calls_acc if isinstance(k, int)) + 1) if tool_calls_acc else 0
|
|
||||||
if idx not in tool_calls_acc:
|
|
||||||
tool_calls_acc[idx] = {
|
|
||||||
"id": tc_delta.id or "",
|
|
||||||
"type": "function",
|
|
||||||
"function": {"name": "", "arguments": ""},
|
|
||||||
}
|
|
||||||
entry = tool_calls_acc[idx]
|
|
||||||
if tc_delta.id:
|
|
||||||
entry["id"] = tc_delta.id
|
|
||||||
if tc_delta.function:
|
|
||||||
if tc_delta.function.name:
|
|
||||||
entry["function"]["name"] += tc_delta.function.name
|
|
||||||
if tc_delta.function.arguments:
|
|
||||||
entry["function"]["arguments"] += tc_delta.function.arguments
|
|
||||||
|
|
||||||
if chunk.choices[0].finish_reason:
|
|
||||||
finish_reason = chunk.choices[0].finish_reason
|
|
||||||
|
|
||||||
full_content = "".join(content_parts) or None
|
|
||||||
mock_tool_calls = None
|
|
||||||
if tool_calls_acc:
|
|
||||||
mock_tool_calls = []
|
|
||||||
for idx in sorted(tool_calls_acc):
|
|
||||||
tc = tool_calls_acc[idx]
|
|
||||||
mock_tool_calls.append(SimpleNamespace(
|
|
||||||
id=tc["id"],
|
|
||||||
type=tc["type"],
|
|
||||||
function=SimpleNamespace(
|
|
||||||
name=tc["function"]["name"],
|
|
||||||
arguments=tc["function"]["arguments"],
|
|
||||||
),
|
|
||||||
))
|
|
||||||
|
|
||||||
mock_message = SimpleNamespace(
|
|
||||||
role=role,
|
|
||||||
content=full_content,
|
|
||||||
tool_calls=mock_tool_calls,
|
|
||||||
reasoning_content=None,
|
|
||||||
)
|
|
||||||
mock_choice = SimpleNamespace(
|
|
||||||
index=0,
|
|
||||||
message=mock_message,
|
|
||||||
finish_reason=finish_reason or "stop",
|
|
||||||
)
|
|
||||||
mock_response = SimpleNamespace(
|
|
||||||
id="stream-" + str(uuid.uuid4()),
|
|
||||||
model=model_name,
|
|
||||||
choices=[mock_choice],
|
|
||||||
usage=None,
|
|
||||||
)
|
|
||||||
result["response"] = mock_response
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
result["error"] = e
|
|
||||||
finally:
|
|
||||||
request_client = request_client_holder.get("client")
|
|
||||||
if request_client is not None:
|
|
||||||
self._close_request_openai_client(request_client, reason="stream_request_complete")
|
|
||||||
|
|
||||||
t = threading.Thread(target=_call, daemon=True)
|
|
||||||
t.start()
|
|
||||||
while t.is_alive():
|
|
||||||
t.join(timeout=0.3)
|
|
||||||
if self._interrupt_requested:
|
|
||||||
try:
|
|
||||||
if self.api_mode == "anthropic_messages":
|
|
||||||
from agent.anthropic_adapter import build_anthropic_client
|
|
||||||
|
|
||||||
self._anthropic_client.close()
|
|
||||||
self._anthropic_client = build_anthropic_client(
|
|
||||||
self._anthropic_api_key,
|
|
||||||
getattr(self, "_anthropic_base_url", None),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
request_client = request_client_holder.get("client")
|
|
||||||
if request_client is not None:
|
|
||||||
self._close_request_openai_client(request_client, reason="stream_interrupt_abort")
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
raise InterruptedError("Agent interrupted during API call")
|
|
||||||
if result["error"] is not None:
|
|
||||||
raise result["error"]
|
|
||||||
return result["response"]
|
|
||||||
|
|
||||||
# ── Unified streaming API call ─────────────────────────────────────────
|
# ── Unified streaming API call ─────────────────────────────────────────
|
||||||
|
|
||||||
def _fire_stream_delta(self, text: str) -> None:
|
def _fire_stream_delta(self, text: str) -> None:
|
||||||
|
|
@ -3039,12 +2905,20 @@ class AIAgent:
|
||||||
streaming is not supported.
|
streaming is not supported.
|
||||||
"""
|
"""
|
||||||
if self.api_mode == "codex_responses":
|
if self.api_mode == "codex_responses":
|
||||||
# Codex already streams internally; we just need to pass callbacks
|
# Codex streams internally via _run_codex_stream. The main dispatch
|
||||||
return self._interruptible_api_call(api_kwargs)
|
# in _interruptible_api_call already calls it; we just need to
|
||||||
|
# ensure on_first_delta reaches it. Store it on the instance
|
||||||
|
# temporarily so _run_codex_stream can pick it up.
|
||||||
|
self._codex_on_first_delta = on_first_delta
|
||||||
|
try:
|
||||||
|
return self._interruptible_api_call(api_kwargs)
|
||||||
|
finally:
|
||||||
|
self._codex_on_first_delta = None
|
||||||
|
|
||||||
result = {"response": None, "error": None}
|
result = {"response": None, "error": None}
|
||||||
request_client_holder = {"client": None}
|
request_client_holder = {"client": None}
|
||||||
first_delta_fired = {"done": False}
|
first_delta_fired = {"done": False}
|
||||||
|
deltas_were_sent = {"yes": False} # Track if any deltas were fired (for fallback)
|
||||||
|
|
||||||
def _fire_first_delta():
|
def _fire_first_delta():
|
||||||
if not first_delta_fired["done"] and on_first_delta:
|
if not first_delta_fired["done"] and on_first_delta:
|
||||||
|
|
@ -3098,6 +2972,7 @@ class AIAgent:
|
||||||
if not tool_calls_acc:
|
if not tool_calls_acc:
|
||||||
_fire_first_delta()
|
_fire_first_delta()
|
||||||
self._fire_stream_delta(delta.content)
|
self._fire_stream_delta(delta.content)
|
||||||
|
deltas_were_sent["yes"] = True
|
||||||
|
|
||||||
# Accumulate tool call deltas (silently, no callback)
|
# Accumulate tool call deltas (silently, no callback)
|
||||||
if delta and delta.tool_calls:
|
if delta and delta.tool_calls:
|
||||||
|
|
@ -3208,17 +3083,22 @@ class AIAgent:
|
||||||
else:
|
else:
|
||||||
result["response"] = _call_chat_completions()
|
result["response"] = _call_chat_completions()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Always fall back to non-streaming on ANY streaming error.
|
if deltas_were_sent["yes"]:
|
||||||
# Many third-party/extrinsic providers have partial or broken
|
# Streaming failed AFTER some tokens were already delivered
|
||||||
# streaming support — rejecting stream=True, crashing on
|
# to consumers. Don't fall back — that would cause
|
||||||
# stream_options, dropping connections mid-stream, etc.
|
# double-delivery (partial streamed + full non-streamed).
|
||||||
# A clean fallback to the standard request path ensures the
|
# Let the error propagate; the partial content already
|
||||||
# agent still works even if streaming doesn't.
|
# reached the user via the stream.
|
||||||
logger.info("Streaming failed, falling back to non-streaming: %s", e)
|
logger.warning("Streaming failed after partial delivery, not falling back: %s", e)
|
||||||
try:
|
result["error"] = e
|
||||||
result["response"] = self._interruptible_api_call(api_kwargs)
|
else:
|
||||||
except Exception as fallback_err:
|
# Streaming failed before any tokens reached consumers.
|
||||||
result["error"] = fallback_err
|
# Safe to fall back to the standard non-streaming path.
|
||||||
|
logger.info("Streaming failed before delivery, falling back to non-streaming: %s", e)
|
||||||
|
try:
|
||||||
|
result["response"] = self._interruptible_api_call(api_kwargs)
|
||||||
|
except Exception as fallback_err:
|
||||||
|
result["error"] = fallback_err
|
||||||
finally:
|
finally:
|
||||||
request_client = request_client_holder.get("client")
|
request_client = request_client_holder.get("client")
|
||||||
if request_client is not None:
|
if request_client is not None:
|
||||||
|
|
|
||||||
|
|
@ -59,8 +59,11 @@ def _build_agent(shared_client=None):
|
||||||
agent._interrupt_requested = False
|
agent._interrupt_requested = False
|
||||||
agent._interrupt_message = None
|
agent._interrupt_message = None
|
||||||
agent._client_lock = threading.RLock()
|
agent._client_lock = threading.RLock()
|
||||||
agent._client_kwargs = {"api_key": "test-key", "base_url": agent.base_url}
|
agent._client_kwargs = {"api_key": "***", "base_url": agent.base_url}
|
||||||
agent.client = shared_client or FakeSharedClient(lambda **kwargs: {"shared": True})
|
agent.client = shared_client or FakeSharedClient(lambda **kwargs: {"shared": True})
|
||||||
|
agent.stream_delta_callback = None
|
||||||
|
agent._stream_callback = None
|
||||||
|
agent.reasoning_callback = None
|
||||||
return agent
|
return agent
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -173,7 +176,11 @@ def test_streaming_call_recreates_closed_shared_client_before_request(monkeypatc
|
||||||
monkeypatch.setattr(run_agent, "OpenAI", factory)
|
monkeypatch.setattr(run_agent, "OpenAI", factory)
|
||||||
|
|
||||||
agent = _build_agent(shared_client=stale_shared)
|
agent = _build_agent(shared_client=stale_shared)
|
||||||
response = agent._streaming_api_call({"model": agent.model, "messages": []}, lambda _delta: None)
|
agent.stream_delta_callback = lambda _delta: None
|
||||||
|
# Force chat_completions mode so the streaming path uses
|
||||||
|
# chat.completions.create(stream=True) instead of Codex responses.stream()
|
||||||
|
agent.api_mode = "chat_completions"
|
||||||
|
response = agent._interruptible_streaming_api_call({"model": agent.model, "messages": []})
|
||||||
|
|
||||||
assert response.choices[0].message.content == "Hello world"
|
assert response.choices[0].message.content == "Hello world"
|
||||||
assert agent.client is replacement_shared
|
assert agent.client is replacement_shared
|
||||||
|
|
|
||||||
|
|
@ -2329,8 +2329,9 @@ class TestStreamingApiCall:
|
||||||
]
|
]
|
||||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||||
callback = MagicMock()
|
callback = MagicMock()
|
||||||
|
agent.stream_delta_callback = callback
|
||||||
|
|
||||||
resp = agent._streaming_api_call({"messages": []}, callback)
|
resp = agent._interruptible_streaming_api_call({"messages": []})
|
||||||
|
|
||||||
assert resp.choices[0].message.content == "Hello World"
|
assert resp.choices[0].message.content == "Hello World"
|
||||||
assert resp.choices[0].finish_reason == "stop"
|
assert resp.choices[0].finish_reason == "stop"
|
||||||
|
|
@ -2347,7 +2348,7 @@ class TestStreamingApiCall:
|
||||||
]
|
]
|
||||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||||
|
|
||||||
resp = agent._streaming_api_call({"messages": []}, MagicMock())
|
resp = agent._interruptible_streaming_api_call({"messages": []})
|
||||||
|
|
||||||
tc = resp.choices[0].message.tool_calls
|
tc = resp.choices[0].message.tool_calls
|
||||||
assert len(tc) == 1
|
assert len(tc) == 1
|
||||||
|
|
@ -2363,7 +2364,7 @@ class TestStreamingApiCall:
|
||||||
]
|
]
|
||||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||||
|
|
||||||
resp = agent._streaming_api_call({"messages": []}, MagicMock())
|
resp = agent._interruptible_streaming_api_call({"messages": []})
|
||||||
|
|
||||||
tc = resp.choices[0].message.tool_calls
|
tc = resp.choices[0].message.tool_calls
|
||||||
assert len(tc) == 2
|
assert len(tc) == 2
|
||||||
|
|
@ -2378,7 +2379,7 @@ class TestStreamingApiCall:
|
||||||
]
|
]
|
||||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||||
|
|
||||||
resp = agent._streaming_api_call({"messages": []}, MagicMock())
|
resp = agent._interruptible_streaming_api_call({"messages": []})
|
||||||
|
|
||||||
assert resp.choices[0].message.content == "I'll search"
|
assert resp.choices[0].message.content == "I'll search"
|
||||||
assert len(resp.choices[0].message.tool_calls) == 1
|
assert len(resp.choices[0].message.tool_calls) == 1
|
||||||
|
|
@ -2387,7 +2388,7 @@ class TestStreamingApiCall:
|
||||||
chunks = [_make_chunk(finish_reason="stop")]
|
chunks = [_make_chunk(finish_reason="stop")]
|
||||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||||
|
|
||||||
resp = agent._streaming_api_call({"messages": []}, MagicMock())
|
resp = agent._interruptible_streaming_api_call({"messages": []})
|
||||||
|
|
||||||
assert resp.choices[0].message.content is None
|
assert resp.choices[0].message.content is None
|
||||||
assert resp.choices[0].message.tool_calls is None
|
assert resp.choices[0].message.tool_calls is None
|
||||||
|
|
@ -2399,9 +2400,9 @@ class TestStreamingApiCall:
|
||||||
_make_chunk(finish_reason="stop"),
|
_make_chunk(finish_reason="stop"),
|
||||||
]
|
]
|
||||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||||
callback = MagicMock(side_effect=ValueError("boom"))
|
agent.stream_delta_callback = MagicMock(side_effect=ValueError("boom"))
|
||||||
|
|
||||||
resp = agent._streaming_api_call({"messages": []}, callback)
|
resp = agent._interruptible_streaming_api_call({"messages": []})
|
||||||
|
|
||||||
assert resp.choices[0].message.content == "Hello World"
|
assert resp.choices[0].message.content == "Hello World"
|
||||||
|
|
||||||
|
|
@ -2412,7 +2413,7 @@ class TestStreamingApiCall:
|
||||||
]
|
]
|
||||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||||
|
|
||||||
resp = agent._streaming_api_call({"messages": []}, MagicMock())
|
resp = agent._interruptible_streaming_api_call({"messages": []})
|
||||||
|
|
||||||
assert resp.model == "gpt-4o"
|
assert resp.model == "gpt-4o"
|
||||||
|
|
||||||
|
|
@ -2420,22 +2421,23 @@ class TestStreamingApiCall:
|
||||||
chunks = [_make_chunk(content="x"), _make_chunk(finish_reason="stop")]
|
chunks = [_make_chunk(content="x"), _make_chunk(finish_reason="stop")]
|
||||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||||
|
|
||||||
agent._streaming_api_call({"messages": [], "model": "test"}, MagicMock())
|
agent._interruptible_streaming_api_call({"messages": [], "model": "test"})
|
||||||
|
|
||||||
call_kwargs = agent.client.chat.completions.create.call_args
|
call_kwargs = agent.client.chat.completions.create.call_args
|
||||||
assert call_kwargs[1].get("stream") is True or call_kwargs.kwargs.get("stream") is True
|
assert call_kwargs[1].get("stream") is True or call_kwargs.kwargs.get("stream") is True
|
||||||
|
|
||||||
def test_api_exception_propagated(self, agent):
|
def test_api_exception_falls_back_to_non_streaming(self, agent):
|
||||||
|
"""When streaming fails before any deltas, fallback to non-streaming is attempted."""
|
||||||
agent.client.chat.completions.create.side_effect = ConnectionError("fail")
|
agent.client.chat.completions.create.side_effect = ConnectionError("fail")
|
||||||
|
# The fallback also uses the same client, so it'll fail too
|
||||||
with pytest.raises(ConnectionError, match="fail"):
|
with pytest.raises(ConnectionError, match="fail"):
|
||||||
agent._streaming_api_call({"messages": []}, MagicMock())
|
agent._interruptible_streaming_api_call({"messages": []})
|
||||||
|
|
||||||
def test_response_has_uuid_id(self, agent):
|
def test_response_has_uuid_id(self, agent):
|
||||||
chunks = [_make_chunk(content="x"), _make_chunk(finish_reason="stop")]
|
chunks = [_make_chunk(content="x"), _make_chunk(finish_reason="stop")]
|
||||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||||
|
|
||||||
resp = agent._streaming_api_call({"messages": []}, MagicMock())
|
resp = agent._interruptible_streaming_api_call({"messages": []})
|
||||||
|
|
||||||
assert resp.id.startswith("stream-")
|
assert resp.id.startswith("stream-")
|
||||||
assert len(resp.id) > len("stream-")
|
assert len(resp.id) > len("stream-")
|
||||||
|
|
@ -2449,7 +2451,7 @@ class TestStreamingApiCall:
|
||||||
]
|
]
|
||||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||||
|
|
||||||
resp = agent._streaming_api_call({"messages": []}, MagicMock())
|
resp = agent._interruptible_streaming_api_call({"messages": []})
|
||||||
|
|
||||||
assert resp.choices[0].message.content == "Hello"
|
assert resp.choices[0].message.content == "Hello"
|
||||||
assert resp.model == "gpt-4"
|
assert resp.model == "gpt-4"
|
||||||
|
|
@ -2505,7 +2507,7 @@ class TestAnthropicInterruptHandler:
|
||||||
def test_streaming_has_anthropic_branch(self):
|
def test_streaming_has_anthropic_branch(self):
|
||||||
"""_streaming_api_call must also handle Anthropic interrupt."""
|
"""_streaming_api_call must also handle Anthropic interrupt."""
|
||||||
import inspect
|
import inspect
|
||||||
source = inspect.getsource(AIAgent._streaming_api_call)
|
source = inspect.getsource(AIAgent._interruptible_streaming_api_call)
|
||||||
assert "anthropic_messages" in source, \
|
assert "anthropic_messages" in source, \
|
||||||
"_streaming_api_call must handle Anthropic interrupt"
|
"_streaming_api_call must handle Anthropic interrupt"
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue