fix(context_compressor): prevent consecutive same-role messages after compression (#1743)
compress() checks both the head and tail neighbors when choosing the summary message role. When only the tail collides, the role is flipped. When BOTH roles would create consecutive same-role messages (e.g. head=assistant, tail=user), the summary is merged into the first tail message instead of inserting a standalone message that breaks role alternation and causes API 400 errors. The previous code handled head-side collision but left the tail-side uncovered — long conversations would crash mid-reply with no useful error, forcing the user to /reset and lose session history. Based on PR #1186 by @alireza78a, with improved double-collision handling (merge into tail instead of unconditional 'user' fallback). Co-authored-by: alireza78a <alireza78.crypto@gmail.com>
This commit is contained in:
parent
702191049f
commit
548cedb869
2 changed files with 160 additions and 3 deletions
|
|
@ -311,6 +311,7 @@ Write only the summary body. Do not include any preamble or prefix; the system w
|
||||||
)
|
)
|
||||||
compressed.append(msg)
|
compressed.append(msg)
|
||||||
|
|
||||||
|
_merge_summary_into_tail = False
|
||||||
if summary:
|
if summary:
|
||||||
last_head_role = messages[compress_start - 1].get("role", "user") if compress_start > 0 else "user"
|
last_head_role = messages[compress_start - 1].get("role", "user") if compress_start > 0 else "user"
|
||||||
first_tail_role = messages[compress_end].get("role", "user") if compress_end < n_messages else "user"
|
first_tail_role = messages[compress_end].get("role", "user") if compress_end < n_messages else "user"
|
||||||
|
|
@ -326,13 +327,25 @@ Write only the summary body. Do not include any preamble or prefix; the system w
|
||||||
flipped = "assistant" if summary_role == "user" else "user"
|
flipped = "assistant" if summary_role == "user" else "user"
|
||||||
if flipped != last_head_role:
|
if flipped != last_head_role:
|
||||||
summary_role = flipped
|
summary_role = flipped
|
||||||
compressed.append({"role": summary_role, "content": summary})
|
else:
|
||||||
|
# Both roles would create consecutive same-role messages
|
||||||
|
# (e.g. head=assistant, tail=user — neither role works).
|
||||||
|
# Merge the summary into the first tail message instead
|
||||||
|
# of inserting a standalone message that breaks alternation.
|
||||||
|
_merge_summary_into_tail = True
|
||||||
|
if not _merge_summary_into_tail:
|
||||||
|
compressed.append({"role": summary_role, "content": summary})
|
||||||
else:
|
else:
|
||||||
if not self.quiet_mode:
|
if not self.quiet_mode:
|
||||||
print(" ⚠️ No summary model available — middle turns dropped without summary")
|
print(" ⚠️ No summary model available — middle turns dropped without summary")
|
||||||
|
|
||||||
for i in range(compress_end, n_messages):
|
for i in range(compress_end, n_messages):
|
||||||
compressed.append(messages[i].copy())
|
msg = messages[i].copy()
|
||||||
|
if _merge_summary_into_tail and i == compress_end:
|
||||||
|
original = msg.get("content") or ""
|
||||||
|
msg["content"] = summary + "\n\n" + original
|
||||||
|
_merge_summary_into_tail = False
|
||||||
|
compressed.append(msg)
|
||||||
|
|
||||||
self.compression_count += 1
|
self.compression_count += 1
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -111,7 +111,11 @@ class TestCompress:
|
||||||
# First 2 messages should be preserved (protect_first_n=2)
|
# First 2 messages should be preserved (protect_first_n=2)
|
||||||
# Last 2 messages should be preserved (protect_last_n=2)
|
# Last 2 messages should be preserved (protect_last_n=2)
|
||||||
assert result[-1]["content"] == msgs[-1]["content"]
|
assert result[-1]["content"] == msgs[-1]["content"]
|
||||||
assert result[-2]["content"] == msgs[-2]["content"]
|
# The second-to-last tail message may have the summary merged
|
||||||
|
# into it when a double-collision prevents a standalone summary
|
||||||
|
# (head=assistant, tail=user in this fixture). Verify the
|
||||||
|
# original content is present in either case.
|
||||||
|
assert msgs[-2]["content"] in result[-2]["content"]
|
||||||
|
|
||||||
|
|
||||||
class TestGenerateSummaryNoneContent:
|
class TestGenerateSummaryNoneContent:
|
||||||
|
|
@ -329,6 +333,146 @@ class TestCompressWithClient:
|
||||||
assert len(summary_msg) == 1
|
assert len(summary_msg) == 1
|
||||||
assert summary_msg[0]["role"] == "assistant"
|
assert summary_msg[0]["role"] == "assistant"
|
||||||
|
|
||||||
|
def test_summary_role_flips_to_avoid_tail_collision(self):
|
||||||
|
"""When summary role collides with the first tail message but flipping
|
||||||
|
doesn't collide with head, the role should be flipped."""
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.choices = [MagicMock()]
|
||||||
|
mock_response.choices[0].message.content = "summary text"
|
||||||
|
|
||||||
|
with patch("agent.context_compressor.get_model_context_length", return_value=100000):
|
||||||
|
c = ContextCompressor(model="test", quiet_mode=True, protect_first_n=2, protect_last_n=2)
|
||||||
|
|
||||||
|
# Head ends with tool (index 1), tail starts with user (index 6).
|
||||||
|
# Default: tool → summary_role="user" → collides with tail.
|
||||||
|
# Flip to "assistant" → tool→assistant is fine.
|
||||||
|
msgs = [
|
||||||
|
{"role": "user", "content": "msg 0"},
|
||||||
|
{"role": "assistant", "content": "", "tool_calls": [
|
||||||
|
{"id": "call_1", "type": "function", "function": {"name": "t", "arguments": "{}"}},
|
||||||
|
]},
|
||||||
|
{"role": "tool", "tool_call_id": "call_1", "content": "result 1"},
|
||||||
|
{"role": "assistant", "content": "msg 3"},
|
||||||
|
{"role": "user", "content": "msg 4"},
|
||||||
|
{"role": "assistant", "content": "msg 5"},
|
||||||
|
{"role": "user", "content": "msg 6"},
|
||||||
|
{"role": "assistant", "content": "msg 7"},
|
||||||
|
]
|
||||||
|
with patch("agent.context_compressor.call_llm", return_value=mock_response):
|
||||||
|
result = c.compress(msgs)
|
||||||
|
# Verify no consecutive user or assistant messages
|
||||||
|
for i in range(1, len(result)):
|
||||||
|
r1 = result[i - 1].get("role")
|
||||||
|
r2 = result[i].get("role")
|
||||||
|
if r1 in ("user", "assistant") and r2 in ("user", "assistant"):
|
||||||
|
assert r1 != r2, f"consecutive {r1} at indices {i-1},{i}"
|
||||||
|
|
||||||
|
def test_double_collision_merges_summary_into_tail(self):
|
||||||
|
"""When neither role avoids collision with both neighbors, the summary
|
||||||
|
should be merged into the first tail message rather than creating a
|
||||||
|
standalone message that breaks role alternation.
|
||||||
|
|
||||||
|
Common scenario: head ends with 'assistant', tail starts with 'user'.
|
||||||
|
summary='user' collides with tail, summary='assistant' collides with head.
|
||||||
|
"""
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.choices = [MagicMock()]
|
||||||
|
mock_response.choices[0].message.content = "summary text"
|
||||||
|
|
||||||
|
with patch("agent.context_compressor.get_model_context_length", return_value=100000):
|
||||||
|
c = ContextCompressor(model="test", quiet_mode=True, protect_first_n=3, protect_last_n=3)
|
||||||
|
|
||||||
|
# Head: [system, user, assistant] → last head = assistant
|
||||||
|
# Tail: [user, assistant, user] → first tail = user
|
||||||
|
# summary_role="user" collides with tail, "assistant" collides with head → merge
|
||||||
|
msgs = [
|
||||||
|
{"role": "system", "content": "system prompt"},
|
||||||
|
{"role": "user", "content": "msg 1"},
|
||||||
|
{"role": "assistant", "content": "msg 2"},
|
||||||
|
{"role": "user", "content": "msg 3"}, # compressed
|
||||||
|
{"role": "assistant", "content": "msg 4"}, # compressed
|
||||||
|
{"role": "user", "content": "msg 5"}, # compressed
|
||||||
|
{"role": "user", "content": "msg 6"}, # tail start
|
||||||
|
{"role": "assistant", "content": "msg 7"},
|
||||||
|
{"role": "user", "content": "msg 8"},
|
||||||
|
]
|
||||||
|
with patch("agent.context_compressor.call_llm", return_value=mock_response):
|
||||||
|
result = c.compress(msgs)
|
||||||
|
|
||||||
|
# Verify no consecutive user or assistant messages
|
||||||
|
for i in range(1, len(result)):
|
||||||
|
r1 = result[i - 1].get("role")
|
||||||
|
r2 = result[i].get("role")
|
||||||
|
if r1 in ("user", "assistant") and r2 in ("user", "assistant"):
|
||||||
|
assert r1 != r2, f"consecutive {r1} at indices {i-1},{i}"
|
||||||
|
|
||||||
|
# The summary text should be merged into the first tail message
|
||||||
|
first_tail = [m for m in result if "msg 6" in (m.get("content") or "")]
|
||||||
|
assert len(first_tail) == 1
|
||||||
|
assert "summary text" in first_tail[0]["content"]
|
||||||
|
|
||||||
|
def test_double_collision_user_head_assistant_tail(self):
|
||||||
|
"""Reverse double collision: head ends with 'user', tail starts with 'assistant'.
|
||||||
|
summary='assistant' collides with tail, 'user' collides with head → merge."""
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.choices = [MagicMock()]
|
||||||
|
mock_response.choices[0].message.content = "summary text"
|
||||||
|
|
||||||
|
with patch("agent.context_compressor.get_model_context_length", return_value=100000):
|
||||||
|
c = ContextCompressor(model="test", quiet_mode=True, protect_first_n=2, protect_last_n=2)
|
||||||
|
|
||||||
|
# Head: [system, user] → last head = user
|
||||||
|
# Tail: [assistant, user] → first tail = assistant
|
||||||
|
# summary_role="assistant" collides with tail, "user" collides with head → merge
|
||||||
|
msgs = [
|
||||||
|
{"role": "system", "content": "system prompt"},
|
||||||
|
{"role": "user", "content": "msg 1"},
|
||||||
|
{"role": "assistant", "content": "msg 2"}, # compressed
|
||||||
|
{"role": "user", "content": "msg 3"}, # compressed
|
||||||
|
{"role": "assistant", "content": "msg 4"}, # compressed
|
||||||
|
{"role": "assistant", "content": "msg 5"}, # tail start
|
||||||
|
{"role": "user", "content": "msg 6"},
|
||||||
|
]
|
||||||
|
with patch("agent.context_compressor.call_llm", return_value=mock_response):
|
||||||
|
result = c.compress(msgs)
|
||||||
|
|
||||||
|
# Verify no consecutive user or assistant messages
|
||||||
|
for i in range(1, len(result)):
|
||||||
|
r1 = result[i - 1].get("role")
|
||||||
|
r2 = result[i].get("role")
|
||||||
|
if r1 in ("user", "assistant") and r2 in ("user", "assistant"):
|
||||||
|
assert r1 != r2, f"consecutive {r1} at indices {i-1},{i}"
|
||||||
|
|
||||||
|
# The summary should be merged into the first tail message (assistant)
|
||||||
|
first_tail = [m for m in result if "msg 5" in (m.get("content") or "")]
|
||||||
|
assert len(first_tail) == 1
|
||||||
|
assert "summary text" in first_tail[0]["content"]
|
||||||
|
|
||||||
|
def test_no_collision_scenarios_still_work(self):
|
||||||
|
"""Verify that the common no-collision cases (head=assistant/tail=assistant,
|
||||||
|
head=user/tail=user) still produce a standalone summary message."""
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.choices = [MagicMock()]
|
||||||
|
mock_response.choices[0].message.content = "summary text"
|
||||||
|
|
||||||
|
with patch("agent.context_compressor.get_model_context_length", return_value=100000):
|
||||||
|
c = ContextCompressor(model="test", quiet_mode=True, protect_first_n=2, protect_last_n=2)
|
||||||
|
|
||||||
|
# Head=assistant, Tail=assistant → summary_role="user", no collision
|
||||||
|
msgs = [
|
||||||
|
{"role": "user", "content": "msg 0"},
|
||||||
|
{"role": "assistant", "content": "msg 1"},
|
||||||
|
{"role": "user", "content": "msg 2"},
|
||||||
|
{"role": "assistant", "content": "msg 3"},
|
||||||
|
{"role": "assistant", "content": "msg 4"},
|
||||||
|
{"role": "user", "content": "msg 5"},
|
||||||
|
]
|
||||||
|
with patch("agent.context_compressor.call_llm", return_value=mock_response):
|
||||||
|
result = c.compress(msgs)
|
||||||
|
summary_msgs = [m for m in result if (m.get("content") or "").startswith(SUMMARY_PREFIX)]
|
||||||
|
assert len(summary_msgs) == 1, "should have a standalone summary message"
|
||||||
|
assert summary_msgs[0]["role"] == "user"
|
||||||
|
|
||||||
def test_summarization_does_not_start_tail_with_tool_outputs(self):
|
def test_summarization_does_not_start_tail_with_tool_outputs(self):
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.choices = [MagicMock()]
|
mock_response.choices = [MagicMock()]
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue