From c797314fcf5ffbfc4260984aced869be837318c0 Mon Sep 17 00:00:00 2001 From: 0xbyt4 <35742124+0xbyt4@users.noreply.github.com> Date: Fri, 13 Mar 2026 15:34:46 +0300 Subject: [PATCH] test: add security and hardening tests for voice mode fixes - Path traversal sanitization (Path.name strips ../) - Media endpoint authentication (401 without token, 404 on traversal) - hmac.compare_digest usage verification (no == for tokens) - DOMPurify XSS prevention in HTML template - Default bind 127.0.0.1 (adapter and config) - /remote-control token hiding in group chats - Opus find_library instead of hardcoded paths - Opus decode error logging (no silent swallow) - Interrupt _vprint force=True on all 6 calls - Anthropic interrupt handler in both API call paths - Update test_web_defaults for new 127.0.0.1 default --- tests/gateway/test_discord_opus.py | 35 ++++ tests/gateway/test_web.py | 282 ++++++++++++++++++++++++++++- tests/test_run_agent.py | 55 ++++++ 3 files changed, 371 insertions(+), 1 deletion(-) create mode 100644 tests/gateway/test_discord_opus.py diff --git a/tests/gateway/test_discord_opus.py b/tests/gateway/test_discord_opus.py new file mode 100644 index 00000000..6c2e1e6c --- /dev/null +++ b/tests/gateway/test_discord_opus.py @@ -0,0 +1,35 @@ +"""Tests for Discord Opus codec loading — must use ctypes.util.find_library.""" + +import inspect + + +class TestOpusFindLibrary: + """Opus loading must use ctypes.util.find_library, not hardcoded paths.""" + + def test_no_hardcoded_opus_path(self): + from gateway.platforms.discord import DiscordAdapter + source = inspect.getsource(DiscordAdapter.connect) + assert "/opt/homebrew" not in source, \ + "Opus loading must not use hardcoded /opt/homebrew path" + assert "libopus.so.0" not in source, \ + "Opus loading must not use hardcoded libopus.so.0 path" + + def test_uses_find_library(self): + from gateway.platforms.discord import DiscordAdapter + source = inspect.getsource(DiscordAdapter.connect) + assert "find_library" in source, \ + "Opus loading must use ctypes.util.find_library" + + def test_opus_decode_error_logged(self): + """Opus decode failure must log the error, not silently return.""" + from gateway.platforms.discord import VoiceReceiver + source = inspect.getsource(VoiceReceiver._on_packet) + assert "logger" in source, \ + "_on_packet must log Opus decode errors" + # Must not have bare `except Exception:\n return` + lines = source.split("\n") + for i, line in enumerate(lines): + if "except Exception" in line and i + 1 < len(lines): + next_line = lines[i + 1].strip() + assert next_line != "return", \ + f"_on_packet has bare 'except Exception: return' at line {i+1}" diff --git a/tests/gateway/test_web.py b/tests/gateway/test_web.py index efa1204a..6c5ae0b6 100644 --- a/tests/gateway/test_web.py +++ b/tests/gateway/test_web.py @@ -15,6 +15,12 @@ Covers: 12. Authorization bypass (Web platform always authorized) 13. Toolset registration (hermes-web in toolset maps) 14. LAN IP detection (_get_local_ip / _get_local_ips) +15. Security: path traversal sanitization +16. Security: media endpoint authentication +17. Security: hmac.compare_digest for token comparison +18. Security: DOMPurify XSS prevention +19. Security: default bind to 127.0.0.1 +20. Security: /remote-control token hiding in group chats """ import asyncio @@ -79,7 +85,7 @@ class TestConfigEnvOverrides(unittest.TestCase): _apply_env_overrides(config) self.assertIn(Platform.WEB, config.platforms) self.assertEqual(config.platforms[Platform.WEB].extra["port"], 8765) - self.assertEqual(config.platforms[Platform.WEB].extra["host"], "0.0.0.0") + self.assertEqual(config.platforms[Platform.WEB].extra["host"], "127.0.0.1") self.assertEqual(config.platforms[Platform.WEB].extra["token"], "") @patch.dict(os.environ, {}, clear=True) @@ -515,3 +521,277 @@ class TestMediaDirectory: }) adapter = WebAdapter(config) assert adapter._media_dir.exists() or True # may use default path + + +# =========================================================================== +# 15. Security: Path traversal sanitization +# =========================================================================== + + +class TestPathTraversalSanitization: + """Upload filenames with traversal sequences are sanitized.""" + + def test_path_name_strips_traversal(self): + """Path.name strips directory traversal from filenames.""" + assert Path("../../../etc/passwd").name == "passwd" + assert Path("normal_file.txt").name == "normal_file.txt" + assert Path("/absolute/path/file.txt").name == "file.txt" + + @pytest.mark.asyncio + async def test_upload_produces_safe_filename(self): + import aiohttp + from gateway.platforms.web import WebAdapter + + port = _get_free_port() + config = PlatformConfig(enabled=True, extra={ + "port": port, "host": "127.0.0.1", "token": "tok", + }) + adapter = WebAdapter(config) + try: + await adapter.connect() + async with aiohttp.ClientSession() as session: + data = aiohttp.FormData() + data.add_field("file", b"test content", + filename="safe_file.txt", + content_type="application/octet-stream") + async with session.post( + f"http://127.0.0.1:{port}/upload", + data=data, + headers={"Authorization": "Bearer tok"}, + ) as resp: + assert resp.status == 200 + result = await resp.json() + assert result["filename"].startswith("upload_") + assert "safe_file.txt" in result["filename"] + # File must be inside media dir, not escaped + assert result["url"].startswith("/media/") + finally: + await adapter.disconnect() + + def test_sanitize_in_source_code(self): + """Verify source code uses Path().name for filename sanitization.""" + import inspect + from gateway.platforms.web import WebAdapter + source = inspect.getsource(WebAdapter._handle_upload) + assert "Path(" in source and ".name" in source + + +# =========================================================================== +# 16. Security: Media endpoint authentication +# =========================================================================== + + +class TestMediaEndpointAuth: + """Media files require a valid token query parameter.""" + + @pytest.mark.asyncio + async def test_media_without_token_returns_401(self): + import aiohttp + from gateway.platforms.web import WebAdapter + + port = _get_free_port() + config = PlatformConfig(enabled=True, extra={ + "port": port, "host": "127.0.0.1", "token": "secret", + }) + adapter = WebAdapter(config) + try: + await adapter.connect() + async with aiohttp.ClientSession() as session: + async with session.get( + f"http://127.0.0.1:{port}/media/test.txt" + ) as resp: + assert resp.status == 401 + + finally: + await adapter.disconnect() + + @pytest.mark.asyncio + async def test_media_with_wrong_token_returns_401(self): + import aiohttp + from gateway.platforms.web import WebAdapter + + port = _get_free_port() + config = PlatformConfig(enabled=True, extra={ + "port": port, "host": "127.0.0.1", "token": "secret", + }) + adapter = WebAdapter(config) + try: + await adapter.connect() + async with aiohttp.ClientSession() as session: + async with session.get( + f"http://127.0.0.1:{port}/media/test.txt?token=wrong" + ) as resp: + assert resp.status == 401 + finally: + await adapter.disconnect() + + @pytest.mark.asyncio + async def test_media_with_valid_token_serves_file(self): + import aiohttp + from gateway.platforms.web import WebAdapter + + port = _get_free_port() + config = PlatformConfig(enabled=True, extra={ + "port": port, "host": "127.0.0.1", "token": "secret", + }) + adapter = WebAdapter(config) + try: + await adapter.connect() + # Create a test file in the media directory + test_file = adapter._media_dir / "testfile.txt" + test_file.write_text("hello") + + async with aiohttp.ClientSession() as session: + async with session.get( + f"http://127.0.0.1:{port}/media/testfile.txt?token=secret" + ) as resp: + assert resp.status == 200 + body = await resp.text() + assert body == "hello" + finally: + await adapter.disconnect() + + @pytest.mark.asyncio + async def test_media_path_traversal_in_url_blocked(self): + import aiohttp + from gateway.platforms.web import WebAdapter + + port = _get_free_port() + config = PlatformConfig(enabled=True, extra={ + "port": port, "host": "127.0.0.1", "token": "secret", + }) + adapter = WebAdapter(config) + try: + await adapter.connect() + async with aiohttp.ClientSession() as session: + async with session.get( + f"http://127.0.0.1:{port}/media/..%2F..%2Fetc%2Fpasswd?token=secret" + ) as resp: + assert resp.status == 404 + finally: + await adapter.disconnect() + + +# =========================================================================== +# 17. Security: hmac.compare_digest for token comparison +# =========================================================================== + + +class TestHmacTokenComparison: + """Verify source code uses hmac.compare_digest, not == / !=.""" + + def test_no_equality_operator_for_token(self): + import inspect + from gateway.platforms.web import WebAdapter + source = inspect.getsource(WebAdapter) + # There should be no `== self._token` or `!= self._token` in the source + assert "== self._token" not in source, \ + "Token comparison must use hmac.compare_digest, not ==" + assert "!= self._token" not in source, \ + "Token comparison must use hmac.compare_digest, not !=" + + def test_hmac_compare_digest_used(self): + import inspect + from gateway.platforms.web import WebAdapter + source = inspect.getsource(WebAdapter) + assert "hmac.compare_digest" in source + + +# =========================================================================== +# 18. Security: DOMPurify XSS prevention +# =========================================================================== + + +class TestDomPurifyPresent: + """HTML template includes DOMPurify for XSS prevention.""" + + def test_dompurify_script_included(self): + from gateway.platforms.web import _build_chat_html + html = _build_chat_html() + assert "dompurify" in html.lower() + assert "DOMPurify.sanitize" in html + + def test_marked_output_sanitized(self): + from gateway.platforms.web import _build_chat_html + html = _build_chat_html() + assert "DOMPurify.sanitize(marked.parse(" in html + + +# =========================================================================== +# 19. Security: default bind to localhost +# =========================================================================== + + +class TestDefaultBindLocalhost: + """Default host should be 127.0.0.1, not 0.0.0.0.""" + + def test_adapter_default_host(self): + from gateway.platforms.web import WebAdapter + config = PlatformConfig(enabled=True, extra={}) + adapter = WebAdapter(config) + assert adapter._host == "127.0.0.1" + + @patch.dict(os.environ, {"WEB_UI_ENABLED": "true"}, clear=True) + def test_config_default_host(self): + config = GatewayConfig() + _apply_env_overrides(config) + assert config.platforms[Platform.WEB].extra["host"] == "127.0.0.1" + + +# =========================================================================== +# 20. Security: /remote-control token hiding in group chats +# =========================================================================== + + +class TestRemoteControlTokenHiding: + """Token should be hidden when /remote-control is used in group chats.""" + + def _make_runner(self, tmp_path): + from gateway.run import GatewayRunner + runner = object.__new__(GatewayRunner) + runner.adapters = {} + runner._voice_mode = {} + runner._VOICE_MODE_PATH = tmp_path / "voice.json" + runner._session_db = None + runner.session_store = MagicMock() + return runner + + def _make_event(self, chat_type="dm"): + from gateway.platforms.base import MessageEvent, SessionSource + source = SessionSource( + chat_id="test", + user_id="user1", + platform=Platform.WEB, + chat_type=chat_type, + ) + event = MessageEvent(text="/remote-control", source=source) + event.message_id = "msg1" + return event + + @pytest.mark.asyncio + async def test_token_visible_in_dm(self, tmp_path): + from gateway.platforms.web import WebAdapter + runner = self._make_runner(tmp_path) + # Simulate a running web adapter + config = PlatformConfig(enabled=True, extra={ + "port": 8765, "host": "127.0.0.1", "token": "mysecret", + }) + adapter = WebAdapter(config) + runner.adapters[Platform.WEB] = adapter + event = self._make_event(chat_type="dm") + result = await runner._handle_remote_control_command(event) + assert "mysecret" in result + + @pytest.mark.asyncio + async def test_token_hidden_in_group(self, tmp_path): + from gateway.platforms.web import WebAdapter + runner = self._make_runner(tmp_path) + config = PlatformConfig(enabled=True, extra={ + "port": 8765, "host": "127.0.0.1", "token": "mysecret", + }) + adapter = WebAdapter(config) + runner.adapters[Platform.WEB] = adapter + event = self._make_event(chat_type="group") + result = await runner._handle_remote_control_command(event) + assert "mysecret" not in result + assert "hidden" in result.lower() diff --git a/tests/test_run_agent.py b/tests/test_run_agent.py index 50cf3c90..6e04534e 100644 --- a/tests/test_run_agent.py +++ b/tests/test_run_agent.py @@ -2238,3 +2238,58 @@ class TestStreamingApiCall: assert resp.choices[0].message.content == "Hello" assert resp.model == "gpt-4" + + +# =================================================================== +# Interrupt _vprint force=True verification +# =================================================================== + + +class TestInterruptVprintForceTrue: + """All interrupt _vprint calls must use force=True so they are always visible.""" + + def test_all_interrupt_vprint_have_force_true(self): + """Scan source for _vprint calls containing 'Interrupt' — each must have force=True.""" + import inspect + source = inspect.getsource(AIAgent) + lines = source.split("\n") + violations = [] + for i, line in enumerate(lines, 1): + stripped = line.strip() + if "_vprint(" in stripped and "Interrupt" in stripped: + if "force=True" not in stripped: + violations.append(f"line {i}: {stripped}") + assert not violations, ( + f"Interrupt _vprint calls missing force=True:\n" + + "\n".join(violations) + ) + + +# =================================================================== +# Anthropic interrupt handler in _interruptible_api_call +# =================================================================== + + +class TestAnthropicInterruptHandler: + """_interruptible_api_call must handle Anthropic mode when interrupted.""" + + def test_interruptible_has_anthropic_branch(self): + """The interrupt handler must check api_mode == 'anthropic_messages'.""" + import inspect + source = inspect.getsource(AIAgent._interruptible_api_call) + assert "anthropic_messages" in source, \ + "_interruptible_api_call must handle Anthropic interrupt (api_mode check)" + + def test_interruptible_rebuilds_anthropic_client(self): + """After interrupting, the Anthropic client should be rebuilt.""" + import inspect + source = inspect.getsource(AIAgent._interruptible_api_call) + assert "build_anthropic_client" in source, \ + "_interruptible_api_call must rebuild Anthropic client after interrupt" + + def test_streaming_has_anthropic_branch(self): + """_streaming_api_call must also handle Anthropic interrupt.""" + import inspect + source = inspect.getsource(AIAgent._streaming_api_call) + assert "anthropic_messages" in source, \ + "_streaming_api_call must handle Anthropic interrupt"