Merge pull request #153 from tekelala/main
fix(agent): handle 413 payload-too-large via compression instead of aborting
This commit is contained in:
commit
2c817ce4a5
7 changed files with 895 additions and 9 deletions
157
tests/gateway/test_document_cache.py
Normal file
157
tests/gateway/test_document_cache.py
Normal file
|
|
@ -0,0 +1,157 @@
|
|||
"""
|
||||
Tests for document cache utilities in gateway/platforms/base.py.
|
||||
|
||||
Covers: get_document_cache_dir, cache_document_from_bytes,
|
||||
cleanup_document_cache, SUPPORTED_DOCUMENT_TYPES.
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.platforms.base import (
|
||||
SUPPORTED_DOCUMENT_TYPES,
|
||||
cache_document_from_bytes,
|
||||
cleanup_document_cache,
|
||||
get_document_cache_dir,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixture: redirect DOCUMENT_CACHE_DIR to a temp directory for every test
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _redirect_cache(tmp_path, monkeypatch):
|
||||
"""Point the module-level DOCUMENT_CACHE_DIR to a fresh tmp_path."""
|
||||
monkeypatch.setattr(
|
||||
"gateway.platforms.base.DOCUMENT_CACHE_DIR", tmp_path / "doc_cache"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestGetDocumentCacheDir
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGetDocumentCacheDir:
|
||||
def test_creates_directory(self, tmp_path):
|
||||
cache_dir = get_document_cache_dir()
|
||||
assert cache_dir.exists()
|
||||
assert cache_dir.is_dir()
|
||||
|
||||
def test_returns_existing_directory(self):
|
||||
first = get_document_cache_dir()
|
||||
second = get_document_cache_dir()
|
||||
assert first == second
|
||||
assert first.exists()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestCacheDocumentFromBytes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCacheDocumentFromBytes:
|
||||
def test_basic_caching(self):
|
||||
data = b"hello world"
|
||||
path = cache_document_from_bytes(data, "test.txt")
|
||||
assert os.path.exists(path)
|
||||
assert Path(path).read_bytes() == data
|
||||
|
||||
def test_filename_preserved_in_path(self):
|
||||
path = cache_document_from_bytes(b"data", "report.pdf")
|
||||
assert "report.pdf" in os.path.basename(path)
|
||||
|
||||
def test_empty_filename_uses_fallback(self):
|
||||
path = cache_document_from_bytes(b"data", "")
|
||||
assert "document" in os.path.basename(path)
|
||||
|
||||
def test_unique_filenames(self):
|
||||
p1 = cache_document_from_bytes(b"a", "same.txt")
|
||||
p2 = cache_document_from_bytes(b"b", "same.txt")
|
||||
assert p1 != p2
|
||||
|
||||
def test_path_traversal_blocked(self):
|
||||
"""Malicious directory components are stripped — only the leaf name survives."""
|
||||
path = cache_document_from_bytes(b"data", "../../etc/passwd")
|
||||
basename = os.path.basename(path)
|
||||
assert "passwd" in basename
|
||||
# Must NOT contain directory separators
|
||||
assert ".." not in basename
|
||||
# File must reside inside the cache directory
|
||||
cache_dir = get_document_cache_dir()
|
||||
assert Path(path).resolve().is_relative_to(cache_dir.resolve())
|
||||
|
||||
def test_null_bytes_stripped(self):
|
||||
path = cache_document_from_bytes(b"data", "file\x00.pdf")
|
||||
basename = os.path.basename(path)
|
||||
assert "\x00" not in basename
|
||||
assert "file.pdf" in basename
|
||||
|
||||
def test_dot_dot_filename_handled(self):
|
||||
"""A filename that is literally '..' falls back to 'document'."""
|
||||
path = cache_document_from_bytes(b"data", "..")
|
||||
basename = os.path.basename(path)
|
||||
assert "document" in basename
|
||||
|
||||
def test_none_filename_uses_fallback(self):
|
||||
path = cache_document_from_bytes(b"data", None)
|
||||
assert "document" in os.path.basename(path)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestCleanupDocumentCache
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCleanupDocumentCache:
|
||||
def test_removes_old_files(self, tmp_path):
|
||||
cache_dir = get_document_cache_dir()
|
||||
old_file = cache_dir / "old.txt"
|
||||
old_file.write_text("old")
|
||||
# Set modification time to 48 hours ago
|
||||
old_mtime = time.time() - 48 * 3600
|
||||
os.utime(old_file, (old_mtime, old_mtime))
|
||||
|
||||
removed = cleanup_document_cache(max_age_hours=24)
|
||||
assert removed == 1
|
||||
assert not old_file.exists()
|
||||
|
||||
def test_keeps_recent_files(self):
|
||||
cache_dir = get_document_cache_dir()
|
||||
recent = cache_dir / "recent.txt"
|
||||
recent.write_text("fresh")
|
||||
|
||||
removed = cleanup_document_cache(max_age_hours=24)
|
||||
assert removed == 0
|
||||
assert recent.exists()
|
||||
|
||||
def test_returns_removed_count(self):
|
||||
cache_dir = get_document_cache_dir()
|
||||
old_time = time.time() - 48 * 3600
|
||||
for i in range(3):
|
||||
f = cache_dir / f"old_{i}.txt"
|
||||
f.write_text("x")
|
||||
os.utime(f, (old_time, old_time))
|
||||
|
||||
assert cleanup_document_cache(max_age_hours=24) == 3
|
||||
|
||||
def test_empty_cache_dir(self):
|
||||
assert cleanup_document_cache(max_age_hours=24) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestSupportedDocumentTypes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSupportedDocumentTypes:
|
||||
def test_all_extensions_have_mime_types(self):
|
||||
for ext, mime in SUPPORTED_DOCUMENT_TYPES.items():
|
||||
assert ext.startswith("."), f"{ext} missing leading dot"
|
||||
assert "/" in mime, f"{mime} is not a valid MIME type"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ext",
|
||||
[".pdf", ".md", ".txt", ".docx", ".xlsx", ".pptx"],
|
||||
)
|
||||
def test_expected_extensions_present(self, ext):
|
||||
assert ext in SUPPORTED_DOCUMENT_TYPES
|
||||
338
tests/gateway/test_telegram_documents.py
Normal file
338
tests/gateway/test_telegram_documents.py
Normal file
|
|
@ -0,0 +1,338 @@
|
|||
"""
|
||||
Tests for Telegram document handling in gateway/platforms/telegram.py.
|
||||
|
||||
Covers: document type detection, download/cache flow, size limits,
|
||||
text injection, error handling.
|
||||
|
||||
Note: python-telegram-bot may not be installed in the test environment.
|
||||
We mock the telegram module at import time to avoid collection errors.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import (
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
SUPPORTED_DOCUMENT_TYPES,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mock the telegram package if it's not installed
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _ensure_telegram_mock():
|
||||
"""Install mock telegram modules so TelegramAdapter can be imported."""
|
||||
if "telegram" in sys.modules and hasattr(sys.modules["telegram"], "__file__"):
|
||||
# Real library is installed — no mocking needed
|
||||
return
|
||||
|
||||
telegram_mod = MagicMock()
|
||||
# ContextTypes needs DEFAULT_TYPE as an actual attribute for the annotation
|
||||
telegram_mod.ext.ContextTypes.DEFAULT_TYPE = type(None)
|
||||
telegram_mod.constants.ParseMode.MARKDOWN_V2 = "MarkdownV2"
|
||||
telegram_mod.constants.ChatType.GROUP = "group"
|
||||
telegram_mod.constants.ChatType.SUPERGROUP = "supergroup"
|
||||
telegram_mod.constants.ChatType.CHANNEL = "channel"
|
||||
telegram_mod.constants.ChatType.PRIVATE = "private"
|
||||
|
||||
for name in ("telegram", "telegram.ext", "telegram.constants"):
|
||||
sys.modules.setdefault(name, telegram_mod)
|
||||
|
||||
|
||||
_ensure_telegram_mock()
|
||||
|
||||
# Now we can safely import
|
||||
from gateway.platforms.telegram import TelegramAdapter # noqa: E402
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers to build mock Telegram objects
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_file_obj(data: bytes = b"hello"):
|
||||
"""Create a mock Telegram File with download_as_bytearray."""
|
||||
f = AsyncMock()
|
||||
f.download_as_bytearray = AsyncMock(return_value=bytearray(data))
|
||||
f.file_path = "documents/file.pdf"
|
||||
return f
|
||||
|
||||
|
||||
def _make_document(
|
||||
file_name="report.pdf",
|
||||
mime_type="application/pdf",
|
||||
file_size=1024,
|
||||
file_obj=None,
|
||||
):
|
||||
"""Create a mock Telegram Document object."""
|
||||
doc = MagicMock()
|
||||
doc.file_name = file_name
|
||||
doc.mime_type = mime_type
|
||||
doc.file_size = file_size
|
||||
doc.get_file = AsyncMock(return_value=file_obj or _make_file_obj())
|
||||
return doc
|
||||
|
||||
|
||||
def _make_message(document=None, caption=None):
|
||||
"""Build a mock Telegram Message with the given document."""
|
||||
msg = MagicMock()
|
||||
msg.message_id = 42
|
||||
msg.text = caption or ""
|
||||
msg.caption = caption
|
||||
msg.date = None
|
||||
# Media flags — all None except document
|
||||
msg.photo = None
|
||||
msg.video = None
|
||||
msg.audio = None
|
||||
msg.voice = None
|
||||
msg.sticker = None
|
||||
msg.document = document
|
||||
# Chat / user
|
||||
msg.chat = MagicMock()
|
||||
msg.chat.id = 100
|
||||
msg.chat.type = "private"
|
||||
msg.chat.title = None
|
||||
msg.chat.full_name = "Test User"
|
||||
msg.from_user = MagicMock()
|
||||
msg.from_user.id = 1
|
||||
msg.from_user.full_name = "Test User"
|
||||
msg.message_thread_id = None
|
||||
return msg
|
||||
|
||||
|
||||
def _make_update(msg):
|
||||
"""Wrap a message in a mock Update."""
|
||||
update = MagicMock()
|
||||
update.message = msg
|
||||
return update
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture()
|
||||
def adapter():
|
||||
config = PlatformConfig(enabled=True, token="fake-token")
|
||||
a = TelegramAdapter(config)
|
||||
# Capture events instead of processing them
|
||||
a.handle_message = AsyncMock()
|
||||
return a
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _redirect_cache(tmp_path, monkeypatch):
|
||||
"""Point document cache to tmp_path so tests don't touch ~/.hermes."""
|
||||
monkeypatch.setattr(
|
||||
"gateway.platforms.base.DOCUMENT_CACHE_DIR", tmp_path / "doc_cache"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestDocumentTypeDetection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDocumentTypeDetection:
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_detected_explicitly(self, adapter):
|
||||
doc = _make_document()
|
||||
msg = _make_message(document=doc)
|
||||
update = _make_update(msg)
|
||||
await adapter._handle_media_message(update, MagicMock())
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert event.message_type == MessageType.DOCUMENT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_is_document(self, adapter):
|
||||
"""When no specific media attr is set, message_type defaults to DOCUMENT."""
|
||||
msg = _make_message()
|
||||
msg.document = None # no media at all
|
||||
update = _make_update(msg)
|
||||
await adapter._handle_media_message(update, MagicMock())
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert event.message_type == MessageType.DOCUMENT
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestDocumentDownloadBlock
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDocumentDownloadBlock:
|
||||
@pytest.mark.asyncio
|
||||
async def test_supported_pdf_is_cached(self, adapter):
|
||||
pdf_bytes = b"%PDF-1.4 fake"
|
||||
file_obj = _make_file_obj(pdf_bytes)
|
||||
doc = _make_document(file_name="report.pdf", file_size=1024, file_obj=file_obj)
|
||||
msg = _make_message(document=doc)
|
||||
update = _make_update(msg)
|
||||
|
||||
await adapter._handle_media_message(update, MagicMock())
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert len(event.media_urls) == 1
|
||||
assert os.path.exists(event.media_urls[0])
|
||||
assert event.media_types == ["application/pdf"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_supported_txt_injects_content(self, adapter):
|
||||
content = b"Hello from a text file"
|
||||
file_obj = _make_file_obj(content)
|
||||
doc = _make_document(
|
||||
file_name="notes.txt", mime_type="text/plain",
|
||||
file_size=len(content), file_obj=file_obj,
|
||||
)
|
||||
msg = _make_message(document=doc)
|
||||
update = _make_update(msg)
|
||||
|
||||
await adapter._handle_media_message(update, MagicMock())
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert "Hello from a text file" in event.text
|
||||
assert "[Content of notes.txt]" in event.text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_supported_md_injects_content(self, adapter):
|
||||
content = b"# Title\nSome markdown"
|
||||
file_obj = _make_file_obj(content)
|
||||
doc = _make_document(
|
||||
file_name="readme.md", mime_type="text/markdown",
|
||||
file_size=len(content), file_obj=file_obj,
|
||||
)
|
||||
msg = _make_message(document=doc)
|
||||
update = _make_update(msg)
|
||||
|
||||
await adapter._handle_media_message(update, MagicMock())
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert "# Title" in event.text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_caption_preserved_with_injection(self, adapter):
|
||||
content = b"file text"
|
||||
file_obj = _make_file_obj(content)
|
||||
doc = _make_document(
|
||||
file_name="doc.txt", mime_type="text/plain",
|
||||
file_size=len(content), file_obj=file_obj,
|
||||
)
|
||||
msg = _make_message(document=doc, caption="Please summarize")
|
||||
update = _make_update(msg)
|
||||
|
||||
await adapter._handle_media_message(update, MagicMock())
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert "file text" in event.text
|
||||
assert "Please summarize" in event.text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unsupported_type_rejected(self, adapter):
|
||||
doc = _make_document(file_name="archive.zip", mime_type="application/zip", file_size=100)
|
||||
msg = _make_message(document=doc)
|
||||
update = _make_update(msg)
|
||||
|
||||
await adapter._handle_media_message(update, MagicMock())
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert "Unsupported document type" in event.text
|
||||
assert ".zip" in event.text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oversized_file_rejected(self, adapter):
|
||||
doc = _make_document(file_name="huge.pdf", file_size=25 * 1024 * 1024)
|
||||
msg = _make_message(document=doc)
|
||||
update = _make_update(msg)
|
||||
|
||||
await adapter._handle_media_message(update, MagicMock())
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert "too large" in event.text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_none_file_size_rejected(self, adapter):
|
||||
"""Security fix: file_size=None must be rejected (not silently allowed)."""
|
||||
doc = _make_document(file_name="tricky.pdf", file_size=None)
|
||||
msg = _make_message(document=doc)
|
||||
update = _make_update(msg)
|
||||
|
||||
await adapter._handle_media_message(update, MagicMock())
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert "too large" in event.text or "could not be verified" in event.text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_filename_uses_mime_lookup(self, adapter):
|
||||
"""No file_name but valid mime_type should resolve to extension."""
|
||||
content = b"some pdf bytes"
|
||||
file_obj = _make_file_obj(content)
|
||||
doc = _make_document(
|
||||
file_name=None, mime_type="application/pdf",
|
||||
file_size=len(content), file_obj=file_obj,
|
||||
)
|
||||
msg = _make_message(document=doc)
|
||||
update = _make_update(msg)
|
||||
|
||||
await adapter._handle_media_message(update, MagicMock())
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert len(event.media_urls) == 1
|
||||
assert event.media_types == ["application/pdf"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_filename_and_mime_rejected(self, adapter):
|
||||
doc = _make_document(file_name=None, mime_type=None, file_size=100)
|
||||
msg = _make_message(document=doc)
|
||||
update = _make_update(msg)
|
||||
|
||||
await adapter._handle_media_message(update, MagicMock())
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert "Unsupported" in event.text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unicode_decode_error_handled(self, adapter):
|
||||
"""Binary bytes that aren't valid UTF-8 in a .txt — content not injected but file still cached."""
|
||||
binary = bytes(range(128, 256)) # not valid UTF-8
|
||||
file_obj = _make_file_obj(binary)
|
||||
doc = _make_document(
|
||||
file_name="binary.txt", mime_type="text/plain",
|
||||
file_size=len(binary), file_obj=file_obj,
|
||||
)
|
||||
msg = _make_message(document=doc)
|
||||
update = _make_update(msg)
|
||||
|
||||
await adapter._handle_media_message(update, MagicMock())
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
# File should still be cached
|
||||
assert len(event.media_urls) == 1
|
||||
assert os.path.exists(event.media_urls[0])
|
||||
# Content NOT injected — text should be empty (no caption set)
|
||||
assert "[Content of" not in (event.text or "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_injection_capped(self, adapter):
|
||||
"""A .txt file over 100 KB should NOT have its content injected."""
|
||||
large = b"x" * (200 * 1024) # 200 KB
|
||||
file_obj = _make_file_obj(large)
|
||||
doc = _make_document(
|
||||
file_name="big.txt", mime_type="text/plain",
|
||||
file_size=len(large), file_obj=file_obj,
|
||||
)
|
||||
msg = _make_message(document=doc)
|
||||
update = _make_update(msg)
|
||||
|
||||
await adapter._handle_media_message(update, MagicMock())
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
# File should be cached
|
||||
assert len(event.media_urls) == 1
|
||||
# Content should NOT be injected
|
||||
assert "[Content of" not in (event.text or "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_exception_handled(self, adapter):
|
||||
"""If get_file() raises, the handler logs the error without crashing."""
|
||||
doc = _make_document(file_name="crash.pdf", file_size=100)
|
||||
doc.get_file = AsyncMock(side_effect=RuntimeError("Telegram API down"))
|
||||
msg = _make_message(document=doc)
|
||||
update = _make_update(msg)
|
||||
|
||||
# Should not raise
|
||||
await adapter._handle_media_message(update, MagicMock())
|
||||
# handle_message should still be called (the handler catches the exception)
|
||||
adapter.handle_message.assert_called_once()
|
||||
171
tests/test_413_compression.py
Normal file
171
tests/test_413_compression.py
Normal file
|
|
@ -0,0 +1,171 @@
|
|||
"""Tests for 413 payload-too-large → compression retry logic in AIAgent.
|
||||
|
||||
Verifies that HTTP 413 errors trigger history compression and retry,
|
||||
rather than being treated as non-retryable generic 4xx errors.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from run_agent import AIAgent
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_tool_defs(*names: str) -> list:
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": n,
|
||||
"description": f"{n} tool",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
}
|
||||
for n in names
|
||||
]
|
||||
|
||||
|
||||
def _mock_response(content="Hello", finish_reason="stop", tool_calls=None, usage=None):
|
||||
msg = SimpleNamespace(
|
||||
content=content,
|
||||
tool_calls=tool_calls,
|
||||
reasoning_content=None,
|
||||
reasoning=None,
|
||||
)
|
||||
choice = SimpleNamespace(message=msg, finish_reason=finish_reason)
|
||||
resp = SimpleNamespace(choices=[choice], model="test/model")
|
||||
resp.usage = SimpleNamespace(**usage) if usage else None
|
||||
return resp
|
||||
|
||||
|
||||
def _make_413_error(*, use_status_code=True, message="Request entity too large"):
|
||||
"""Create an exception that mimics a 413 HTTP error."""
|
||||
err = Exception(message)
|
||||
if use_status_code:
|
||||
err.status_code = 413
|
||||
return err
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def agent():
|
||||
with (
|
||||
patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("web_search")),
|
||||
patch("run_agent.check_toolset_requirements", return_value={}),
|
||||
patch("run_agent.OpenAI"),
|
||||
):
|
||||
a = AIAgent(
|
||||
api_key="test-key-1234567890",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
a.client = MagicMock()
|
||||
a._cached_system_prompt = "You are helpful."
|
||||
a._use_prompt_caching = False
|
||||
a.tool_delay = 0
|
||||
a.compression_enabled = False
|
||||
a.save_trajectories = False
|
||||
return a
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestHTTP413Compression:
|
||||
"""413 errors should trigger compression, not abort as generic 4xx."""
|
||||
|
||||
def test_413_triggers_compression(self, agent):
|
||||
"""A 413 error should call _compress_context and retry, not abort."""
|
||||
# First call raises 413; second call succeeds after compression.
|
||||
err_413 = _make_413_error()
|
||||
ok_resp = _mock_response(content="Success after compression", finish_reason="stop")
|
||||
agent.client.chat.completions.create.side_effect = [err_413, ok_resp]
|
||||
|
||||
with (
|
||||
patch.object(agent, "_compress_context") as mock_compress,
|
||||
patch.object(agent, "_persist_session"),
|
||||
patch.object(agent, "_save_trajectory"),
|
||||
patch.object(agent, "_cleanup_task_resources"),
|
||||
):
|
||||
# Compression removes messages, enabling retry
|
||||
mock_compress.return_value = (
|
||||
[{"role": "user", "content": "hello"}],
|
||||
"compressed prompt",
|
||||
)
|
||||
result = agent.run_conversation("hello")
|
||||
|
||||
mock_compress.assert_called_once()
|
||||
assert result["completed"] is True
|
||||
assert result["final_response"] == "Success after compression"
|
||||
|
||||
def test_413_not_treated_as_generic_4xx(self, agent):
|
||||
"""413 must NOT hit the generic 4xx abort path; it should attempt compression."""
|
||||
err_413 = _make_413_error()
|
||||
ok_resp = _mock_response(content="Recovered", finish_reason="stop")
|
||||
agent.client.chat.completions.create.side_effect = [err_413, ok_resp]
|
||||
|
||||
with (
|
||||
patch.object(agent, "_compress_context") as mock_compress,
|
||||
patch.object(agent, "_persist_session"),
|
||||
patch.object(agent, "_save_trajectory"),
|
||||
patch.object(agent, "_cleanup_task_resources"),
|
||||
):
|
||||
mock_compress.return_value = (
|
||||
[{"role": "user", "content": "hello"}],
|
||||
"compressed",
|
||||
)
|
||||
result = agent.run_conversation("hello")
|
||||
|
||||
# If 413 were treated as generic 4xx, result would have "failed": True
|
||||
assert result.get("failed") is not True
|
||||
assert result["completed"] is True
|
||||
|
||||
def test_413_error_message_detection(self, agent):
|
||||
"""413 detected via error message string (no status_code attr)."""
|
||||
err = _make_413_error(use_status_code=False, message="error code: 413")
|
||||
ok_resp = _mock_response(content="OK", finish_reason="stop")
|
||||
agent.client.chat.completions.create.side_effect = [err, ok_resp]
|
||||
|
||||
with (
|
||||
patch.object(agent, "_compress_context") as mock_compress,
|
||||
patch.object(agent, "_persist_session"),
|
||||
patch.object(agent, "_save_trajectory"),
|
||||
patch.object(agent, "_cleanup_task_resources"),
|
||||
):
|
||||
mock_compress.return_value = (
|
||||
[{"role": "user", "content": "hello"}],
|
||||
"compressed",
|
||||
)
|
||||
result = agent.run_conversation("hello")
|
||||
|
||||
mock_compress.assert_called_once()
|
||||
assert result["completed"] is True
|
||||
|
||||
def test_413_cannot_compress_further(self, agent):
|
||||
"""When compression can't reduce messages, return partial result."""
|
||||
err_413 = _make_413_error()
|
||||
agent.client.chat.completions.create.side_effect = [err_413]
|
||||
|
||||
with (
|
||||
patch.object(agent, "_compress_context") as mock_compress,
|
||||
patch.object(agent, "_persist_session"),
|
||||
patch.object(agent, "_save_trajectory"),
|
||||
patch.object(agent, "_cleanup_task_resources"),
|
||||
):
|
||||
# Compression returns same number of messages → can't compress further
|
||||
mock_compress.return_value = (
|
||||
[{"role": "user", "content": "hello"}],
|
||||
"same prompt",
|
||||
)
|
||||
result = agent.run_conversation("hello")
|
||||
|
||||
assert result["completed"] is False
|
||||
assert result.get("partial") is True
|
||||
assert "413" in result["error"]
|
||||
Loading…
Add table
Add a link
Reference in a new issue