The architecture has been updated

This commit is contained in:
Skyber_2 2026-03-31 23:31:36 +03:00
parent 805f7a017e
commit a01257ead9
1119 changed files with 226 additions and 352 deletions

View file

View file

@ -0,0 +1,238 @@
"""Integration tests for gateway AIAgent caching.
Verifies that the agent cache correctly:
- Reuses agents across messages (same config → same instance)
- Rebuilds agents when config changes (model, provider, toolsets)
- Updates reasoning_config in-place without rebuilding
- Evicts on session reset
- Evicts on fallback activation
- Preserves frozen system prompt across turns
"""
import hashlib
import json
import threading
from unittest.mock import MagicMock, patch
import pytest
def _make_runner():
    """Return a bare ``GatewayRunner`` that carries only the agent-cache state.

    ``__init__`` is bypassed through ``__new__``; the only attributes set are
    the empty ``_agent_cache`` dict and the ``_agent_cache_lock`` guarding it.
    """
    from gateway.run import GatewayRunner

    bare_runner = GatewayRunner.__new__(GatewayRunner)
    bare_runner._agent_cache_lock = threading.Lock()
    bare_runner._agent_cache = {}
    return bare_runner
class TestAgentConfigSignature:
    """``_agent_config_signature`` gives equal keys for equal configs, distinct keys otherwise."""

    @staticmethod
    def _openrouter_runtime(**extra):
        """Build a fresh OpenRouter runtime dict, extended with any *extra* keys."""
        runtime = {
            "api_key": "sk-test12345678",
            "base_url": "https://openrouter.ai/api/v1",
            "provider": "openrouter",
        }
        runtime.update(extra)
        return runtime

    @staticmethod
    def _signature(model, runtime, toolsets):
        """Call ``GatewayRunner._agent_config_signature`` with an empty fourth argument."""
        from gateway.run import GatewayRunner

        return GatewayRunner._agent_config_signature(model, runtime, toolsets, "")

    def test_same_config_same_signature(self):
        """Identical arguments produce identical signatures."""
        runtime = self._openrouter_runtime(api_mode="chat_completions")
        first = self._signature("claude-sonnet-4", runtime, ["hermes-telegram"])
        second = self._signature("claude-sonnet-4", runtime, ["hermes-telegram"])
        assert first == second

    def test_model_change_different_signature(self):
        """Changing only the model name changes the signature."""
        runtime = self._openrouter_runtime()
        sonnet = self._signature("claude-sonnet-4", runtime, ["hermes-telegram"])
        opus = self._signature("claude-opus-4.6", runtime, ["hermes-telegram"])
        assert sonnet != opus

    def test_provider_change_different_signature(self):
        """Changing provider and base URL changes the signature."""
        openrouter_rt = self._openrouter_runtime()
        anthropic_rt = {
            "api_key": "sk-test12345678",
            "base_url": "https://api.anthropic.com",
            "provider": "anthropic",
        }
        via_openrouter = self._signature("claude-sonnet-4", openrouter_rt, ["hermes-telegram"])
        via_anthropic = self._signature("claude-sonnet-4", anthropic_rt, ["hermes-telegram"])
        assert via_openrouter != via_anthropic

    def test_toolset_change_different_signature(self):
        """Changing only the toolset list changes the signature."""
        runtime = self._openrouter_runtime()
        telegram = self._signature("claude-sonnet-4", runtime, ["hermes-telegram"])
        discord = self._signature("claude-sonnet-4", runtime, ["hermes-discord"])
        assert telegram != discord

    def test_reasoning_not_in_signature(self):
        """Reasoning config is set per-message, not part of the signature."""
        runtime = self._openrouter_runtime()
        # The signature helper is never handed a reasoning_config, so two
        # calls with the same remaining inputs must agree.
        first = self._signature("claude-sonnet-4", runtime, ["hermes-telegram"])
        second = self._signature("claude-sonnet-4", runtime, ["hermes-telegram"])
        assert first == second
class TestAgentCacheLifecycle:
    """Cache behaviour exercised with real ``AIAgent`` construction.

    Entries are written to ``runner._agent_cache`` as ``(agent, signature)``
    tuples, always while holding ``runner._agent_cache_lock``.
    """

    _MODEL = "anthropic/claude-sonnet-4"
    _BASE_URL = "https://openrouter.ai/api/v1"
    _SESSION_KEY = "telegram:12345"

    @classmethod
    def _new_agent(cls, **overrides):
        """Construct the ``AIAgent`` shared by these tests.

        Args:
            **overrides: extra or replacement keyword arguments forwarded to
                ``AIAgent`` (e.g. ``platform`` or ``reasoning_config``).
        """
        from run_agent import AIAgent

        options = {
            "model": cls._MODEL,
            "api_key": "test",
            "base_url": cls._BASE_URL,
            "provider": "openrouter",
            "max_iterations": 5,
            "quiet_mode": True,
            "skip_context_files": True,
            "skip_memory": True,
        }
        options.update(overrides)
        return AIAgent(**options)

    @classmethod
    def _runtime(cls):
        """Return a fresh runtime dict matching the agent built by ``_new_agent``."""
        return {
            "api_key": "test",
            "base_url": cls._BASE_URL,
            "provider": "openrouter",
            "api_mode": "chat_completions",
        }

    def test_cache_hit_returns_same_agent(self):
        """Second message with same config reuses the cached agent instance."""
        runner = _make_runner()
        signature = runner._agent_config_signature(
            self._MODEL, self._runtime(), ["hermes-telegram"], ""
        )

        # First message: create and cache.
        agent = self._new_agent(platform="telegram")
        with runner._agent_cache_lock:
            runner._agent_cache[self._SESSION_KEY] = (agent, signature)

        # Second message: cache hit.
        with runner._agent_cache_lock:
            cached = runner._agent_cache.get(self._SESSION_KEY)

        assert cached is not None
        assert cached[1] == signature
        assert cached[0] is agent  # same instance

    def test_cache_miss_on_model_change(self):
        """Model change produces different signature → cache miss."""
        runner = _make_runner()
        runtime = self._runtime()
        old_signature = runner._agent_config_signature(
            self._MODEL, runtime, ["hermes-telegram"], ""
        )
        agent = self._new_agent(platform="telegram")
        with runner._agent_cache_lock:
            runner._agent_cache[self._SESSION_KEY] = (agent, old_signature)

        # New model → different signature.
        new_signature = runner._agent_config_signature(
            "anthropic/claude-opus-4.6", runtime, ["hermes-telegram"], ""
        )
        assert new_signature != old_signature

        with runner._agent_cache_lock:
            cached = runner._agent_cache.get(self._SESSION_KEY)

        # Fail with a clear assertion (not a TypeError) if the entry vanished.
        assert cached is not None
        assert cached[1] != new_signature  # mismatch → a new agent would be built

    def test_evict_on_session_reset(self):
        """_evict_cached_agent removes the entry."""
        runner = _make_runner()
        agent = self._new_agent()
        with runner._agent_cache_lock:
            runner._agent_cache[self._SESSION_KEY] = (agent, "sig123")

        runner._evict_cached_agent(self._SESSION_KEY)

        with runner._agent_cache_lock:
            assert self._SESSION_KEY not in runner._agent_cache

    def test_evict_does_not_affect_other_sessions(self):
        """Evicting one session leaves other sessions cached."""
        runner = _make_runner()
        with runner._agent_cache_lock:
            runner._agent_cache["session-A"] = ("agent-A", "sig-A")
            runner._agent_cache["session-B"] = ("agent-B", "sig-B")

        runner._evict_cached_agent("session-A")

        with runner._agent_cache_lock:
            assert "session-A" not in runner._agent_cache
            assert "session-B" in runner._agent_cache

    def test_reasoning_config_updates_in_place(self):
        """Reasoning config can be set on a cached agent without eviction."""
        agent = self._new_agent(
            reasoning_config={"enabled": True, "effort": "medium"},
        )

        # Simulate the per-message reasoning update.
        agent.reasoning_config = {"enabled": True, "effort": "high"}
        assert agent.reasoning_config["effort"] == "high"

        # The stored system prompt must survive a later reasoning change.
        built_prompt = agent._build_system_prompt()
        agent._cached_system_prompt = built_prompt  # simulate run_conversation caching
        agent.reasoning_config = {"enabled": True, "effort": "low"}
        assert built_prompt is agent._cached_system_prompt  # same object

    def test_system_prompt_frozen_across_cache_reuse(self):
        """The cached agent's system prompt stays identical across turns."""
        agent = self._new_agent(platform="telegram")

        # Build the system prompt (simulates the first run_conversation).
        first_turn_prompt = agent._build_system_prompt()
        agent._cached_system_prompt = first_turn_prompt

        # Second turn: the stored prompt is read back, not rebuilt.
        second_turn_prompt = agent._cached_system_prompt
        assert first_turn_prompt is second_turn_prompt

    def test_callbacks_update_without_cache_eviction(self):
        """Per-message callbacks can be set on cached agent."""

        def make_noop():
            """Return a new do-nothing callback; each call yields a distinct object."""

            def noop(*_args):
                return None

            return noop

        agent = self._new_agent()

        # First message: install callbacks the way a per-message setup would.
        progress_cb = make_noop()
        step_cb = make_noop()
        agent.tool_progress_callback = progress_cb
        agent.step_callback = step_cb
        agent.stream_delta_callback = None
        agent.status_callback = None
        assert agent.tool_progress_callback is progress_cb
        assert agent.step_callback is step_cb

        # Next message: swap in a different callback on the same agent.
        replacement_cb = make_noop()
        agent.tool_progress_callback = replacement_cb
        assert agent.tool_progress_callback is replacement_cb

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,597 @@
"""
Tests for the Cron Jobs API endpoints on the API server adapter.
Covers:
- CRUD operations for cron jobs (list, create, get, update, delete)
- Pause / resume / run (trigger) actions
- Input validation (missing name, name too long, prompt too long, invalid repeat)
- Job ID validation (invalid hex)
- Auth enforcement (401 when API_SERVER_KEY is set)
- Cron module unavailability (501 when _CRON_AVAILABLE is False)
"""
import json
from unittest.mock import MagicMock, patch
import pytest
from aiohttp import web
from aiohttp.test_utils import TestClient, TestServer
from gateway.config import PlatformConfig
from gateway.platforms.api_server import APIServerAdapter, cors_middleware
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
# Job id interpolated into the /api/jobs/{job_id} paths requested by the tests.
VALID_JOB_ID = "aabbccddeeff"

# Job payload handed back by the patched cron helpers; its id equals VALID_JOB_ID.
SAMPLE_JOB = dict(
    id="aabbccddeeff",
    name="test-job",
    schedule="*/5 * * * *",
    prompt="do something",
    deliver="local",
    enabled=True,
)
def _make_adapter(api_key: str = "") -> APIServerAdapter:
    """Build an ``APIServerAdapter`` from an enabled ``PlatformConfig``.

    A non-empty *api_key* is stored under ``extra["key"]``; otherwise
    ``extra`` is left empty.
    """
    extra = {"key": api_key} if api_key else {}
    return APIServerAdapter(PlatformConfig(enabled=True, extra=extra))
def _create_app(adapter: APIServerAdapter) -> web.Application:
    """Build an aiohttp application exposing ``/health`` and the job routes of *adapter*.

    The adapter is also stored on the app under ``"api_server_adapter"``.
    """
    app = web.Application(middlewares=[cors_middleware])
    app["api_server_adapter"] = adapter

    router = app.router
    # (registrar, path, handler) triples, registered in this order.
    registrations = (
        (router.add_get, "/health", adapter._handle_health),
        (router.add_get, "/api/jobs", adapter._handle_list_jobs),
        (router.add_post, "/api/jobs", adapter._handle_create_job),
        (router.add_get, "/api/jobs/{job_id}", adapter._handle_get_job),
        (router.add_patch, "/api/jobs/{job_id}", adapter._handle_update_job),
        (router.add_delete, "/api/jobs/{job_id}", adapter._handle_delete_job),
        (router.add_post, "/api/jobs/{job_id}/pause", adapter._handle_pause_job),
        (router.add_post, "/api/jobs/{job_id}/resume", adapter._handle_resume_job),
        (router.add_post, "/api/jobs/{job_id}/run", adapter._handle_run_job),
    )
    for register, path, handler in registrations:
        register(path, handler)
    return app
@pytest.fixture
def adapter():
    """Adapter built without an API key."""
    return _make_adapter(api_key="")
@pytest.fixture
def auth_adapter():
    """Adapter built with the API key ``sk-secret``."""
    return _make_adapter("sk-secret")
# ---------------------------------------------------------------------------
# 1. test_list_jobs
# ---------------------------------------------------------------------------
class TestListJobs:
    """GET /api/jobs: response shape and forwarding of ``include_disabled``."""

    @pytest.mark.asyncio
    async def test_list_jobs(self, adapter):
        """GET /api/jobs returns job list."""
        app = _create_app(adapter)
        cron_on = patch.object(APIServerAdapter, "_CRON_AVAILABLE", True)
        list_stub = patch.object(
            APIServerAdapter, "_cron_list", return_value=[SAMPLE_JOB]
        )
        async with TestClient(TestServer(app)) as client:
            with cron_on, list_stub:
                response = await client.get("/api/jobs")
                assert response.status == 200
                payload = await response.json()
                assert "jobs" in payload
                assert payload["jobs"] == [SAMPLE_JOB]

    # -------------------------------------------------------------------
    # 2. test_list_jobs_include_disabled
    # -------------------------------------------------------------------
    @pytest.mark.asyncio
    async def test_list_jobs_include_disabled(self, adapter):
        """GET /api/jobs?include_disabled=true passes the flag."""
        app = _create_app(adapter)
        mock_list = MagicMock(return_value=[SAMPLE_JOB])
        cron_on = patch.object(APIServerAdapter, "_CRON_AVAILABLE", True)
        list_stub = patch.object(APIServerAdapter, "_cron_list", mock_list)
        async with TestClient(TestServer(app)) as client:
            with cron_on, list_stub:
                response = await client.get("/api/jobs?include_disabled=true")
                assert response.status == 200
                mock_list.assert_called_once_with(include_disabled=True)

    @pytest.mark.asyncio
    async def test_list_jobs_default_excludes_disabled(self, adapter):
        """GET /api/jobs without flag passes include_disabled=False."""
        app = _create_app(adapter)
        mock_list = MagicMock(return_value=[])
        cron_on = patch.object(APIServerAdapter, "_CRON_AVAILABLE", True)
        list_stub = patch.object(APIServerAdapter, "_cron_list", mock_list)
        async with TestClient(TestServer(app)) as client:
            with cron_on, list_stub:
                response = await client.get("/api/jobs")
                assert response.status == 200
                mock_list.assert_called_once_with(include_disabled=False)
# ---------------------------------------------------------------------------
# 3-7. test_create_job and validation
# ---------------------------------------------------------------------------
class TestCreateJob:
    """POST /api/jobs: the happy path plus the request-validation failures."""

    @staticmethod
    async def _post_expecting_400(adapter, payload):
        """POST *payload* to /api/jobs with cron enabled and return the error text.

        Asserts that the response status is 400 before reading the JSON body.
        """
        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as client:
            with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
                response = await client.post("/api/jobs", json=payload)
                assert response.status == 400
                body = await response.json()
                return body["error"]

    @pytest.mark.asyncio
    async def test_create_job(self, adapter):
        """POST /api/jobs with valid body returns created job."""
        app = _create_app(adapter)
        mock_create = MagicMock(return_value=SAMPLE_JOB)
        cron_on = patch.object(APIServerAdapter, "_CRON_AVAILABLE", True)
        create_stub = patch.object(APIServerAdapter, "_cron_create", mock_create)
        request_body = {
            "name": "test-job",
            "schedule": "*/5 * * * *",
            "prompt": "do something",
        }
        async with TestClient(TestServer(app)) as client:
            with cron_on, create_stub:
                response = await client.post("/api/jobs", json=request_body)
                assert response.status == 200
                payload = await response.json()
                assert payload["job"] == SAMPLE_JOB
                mock_create.assert_called_once()
                # The handler must forward the fields as keyword arguments.
                forwarded = mock_create.call_args.kwargs
                assert forwarded["name"] == "test-job"
                assert forwarded["schedule"] == "*/5 * * * *"
                assert forwarded["prompt"] == "do something"

    @pytest.mark.asyncio
    async def test_create_job_missing_name(self, adapter):
        """POST /api/jobs without name returns 400."""
        error = await self._post_expecting_400(adapter, {
            "schedule": "*/5 * * * *",
            "prompt": "do something",
        })
        # Case-insensitive match already covers the capitalised spelling.
        assert "name" in error.lower()

    @pytest.mark.asyncio
    async def test_create_job_name_too_long(self, adapter):
        """POST /api/jobs with name > 200 chars returns 400."""
        error = await self._post_expecting_400(adapter, {
            "name": "x" * 201,
            "schedule": "*/5 * * * *",
        })
        assert "200" in error or "Name" in error

    @pytest.mark.asyncio
    async def test_create_job_prompt_too_long(self, adapter):
        """POST /api/jobs with prompt > 5000 chars returns 400."""
        error = await self._post_expecting_400(adapter, {
            "name": "test-job",
            "schedule": "*/5 * * * *",
            "prompt": "x" * 5001,
        })
        assert "5000" in error or "Prompt" in error

    @pytest.mark.asyncio
    async def test_create_job_invalid_repeat(self, adapter):
        """POST /api/jobs with repeat=0 returns 400."""
        error = await self._post_expecting_400(adapter, {
            "name": "test-job",
            "schedule": "*/5 * * * *",
            "repeat": 0,
        })
        assert "repeat" in error.lower()

    @pytest.mark.asyncio
    async def test_create_job_missing_schedule(self, adapter):
        """POST /api/jobs without schedule returns 400."""
        error = await self._post_expecting_400(adapter, {
            "name": "test-job",
        })
        assert "schedule" in error.lower()
# ---------------------------------------------------------------------------
# 8-10. test_get_job
# ---------------------------------------------------------------------------
class TestGetJob:
    """GET /api/jobs/{id}: found, not found, and malformed-id responses."""

    @pytest.mark.asyncio
    async def test_get_job(self, adapter):
        """GET /api/jobs/{id} returns job."""
        app = _create_app(adapter)
        mock_get = MagicMock(return_value=SAMPLE_JOB)
        cron_on = patch.object(APIServerAdapter, "_CRON_AVAILABLE", True)
        get_stub = patch.object(APIServerAdapter, "_cron_get", mock_get)
        async with TestClient(TestServer(app)) as client:
            with cron_on, get_stub:
                response = await client.get(f"/api/jobs/{VALID_JOB_ID}")
                assert response.status == 200
                payload = await response.json()
                assert payload["job"] == SAMPLE_JOB
                mock_get.assert_called_once_with(VALID_JOB_ID)

    @pytest.mark.asyncio
    async def test_get_job_not_found(self, adapter):
        """GET /api/jobs/{id} returns 404 when job doesn't exist."""
        app = _create_app(adapter)
        mock_get = MagicMock(return_value=None)
        cron_on = patch.object(APIServerAdapter, "_CRON_AVAILABLE", True)
        get_stub = patch.object(APIServerAdapter, "_cron_get", mock_get)
        async with TestClient(TestServer(app)) as client:
            with cron_on, get_stub:
                response = await client.get(f"/api/jobs/{VALID_JOB_ID}")
                assert response.status == 404

    @pytest.mark.asyncio
    async def test_get_job_invalid_id(self, adapter):
        """GET /api/jobs/{id} with non-hex id returns 400."""
        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as client:
            with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
                response = await client.get("/api/jobs/not-a-valid-hex!")
                assert response.status == 400
                payload = await response.json()
                assert "Invalid" in payload["error"]
# ---------------------------------------------------------------------------
# 11-12. test_update_job
# ---------------------------------------------------------------------------
class TestUpdateJob:
    """PATCH /api/jobs/{id}: field whitelisting and the empty-update rejection."""

    @pytest.mark.asyncio
    async def test_update_job(self, adapter):
        """PATCH /api/jobs/{id} updates with whitelisted fields."""
        app = _create_app(adapter)
        updated_job = {**SAMPLE_JOB, "name": "updated-name"}
        mock_update = MagicMock(return_value=updated_job)
        cron_on = patch.object(APIServerAdapter, "_CRON_AVAILABLE", True)
        update_stub = patch.object(APIServerAdapter, "_cron_update", mock_update)
        async with TestClient(TestServer(app)) as client:
            with cron_on, update_stub:
                response = await client.patch(
                    f"/api/jobs/{VALID_JOB_ID}",
                    json={"name": "updated-name", "schedule": "0 * * * *"},
                )
                assert response.status == 200
                payload = await response.json()
                assert payload["job"] == updated_job
                mock_update.assert_called_once()
                # Positional call shape: (job_id, sanitized_fields).
                positional = mock_update.call_args.args
                assert positional[0] == VALID_JOB_ID
                sanitized = positional[1]
                assert "name" in sanitized
                assert "schedule" in sanitized

    @pytest.mark.asyncio
    async def test_update_job_rejects_unknown_fields(self, adapter):
        """PATCH /api/jobs/{id} — only allowed fields pass through."""
        app = _create_app(adapter)
        updated_job = {**SAMPLE_JOB, "name": "new-name"}
        mock_update = MagicMock(return_value=updated_job)
        cron_on = patch.object(APIServerAdapter, "_CRON_AVAILABLE", True)
        update_stub = patch.object(APIServerAdapter, "_cron_update", mock_update)
        request_body = {
            "name": "new-name",
            "evil_field": "malicious",
            "__proto__": "hack",
        }
        async with TestClient(TestServer(app)) as client:
            with cron_on, update_stub:
                response = await client.patch(
                    f"/api/jobs/{VALID_JOB_ID}",
                    json=request_body,
                )
                assert response.status == 200
                sanitized = mock_update.call_args.args[1]
                assert "name" in sanitized
                assert "evil_field" not in sanitized
                assert "__proto__" not in sanitized

    @pytest.mark.asyncio
    async def test_update_job_no_valid_fields(self, adapter):
        """PATCH /api/jobs/{id} with only unknown fields returns 400."""
        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as client:
            with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
                response = await client.patch(
                    f"/api/jobs/{VALID_JOB_ID}",
                    json={"evil_field": "malicious"},
                )
                assert response.status == 400
                payload = await response.json()
                assert "No valid fields" in payload["error"]
# ---------------------------------------------------------------------------
# 13. test_delete_job
# ---------------------------------------------------------------------------
class TestDeleteJob:
    """DELETE /api/jobs/{id}: success and not-found responses."""

    @pytest.mark.asyncio
    async def test_delete_job(self, adapter):
        """DELETE /api/jobs/{id} returns ok."""
        app = _create_app(adapter)
        mock_remove = MagicMock(return_value=True)
        cron_on = patch.object(APIServerAdapter, "_CRON_AVAILABLE", True)
        remove_stub = patch.object(APIServerAdapter, "_cron_remove", mock_remove)
        async with TestClient(TestServer(app)) as client:
            with cron_on, remove_stub:
                response = await client.delete(f"/api/jobs/{VALID_JOB_ID}")
                assert response.status == 200
                payload = await response.json()
                assert payload["ok"] is True
                mock_remove.assert_called_once_with(VALID_JOB_ID)

    @pytest.mark.asyncio
    async def test_delete_job_not_found(self, adapter):
        """DELETE /api/jobs/{id} returns 404 when job doesn't exist."""
        app = _create_app(adapter)
        mock_remove = MagicMock(return_value=False)
        cron_on = patch.object(APIServerAdapter, "_CRON_AVAILABLE", True)
        remove_stub = patch.object(APIServerAdapter, "_cron_remove", mock_remove)
        async with TestClient(TestServer(app)) as client:
            with cron_on, remove_stub:
                response = await client.delete(f"/api/jobs/{VALID_JOB_ID}")
                assert response.status == 404
# ---------------------------------------------------------------------------
# 14. test_pause_job
# ---------------------------------------------------------------------------
class TestPauseJob:
    """POST /api/jobs/{id}/pause."""

    @pytest.mark.asyncio
    async def test_pause_job(self, adapter):
        """POST /api/jobs/{id}/pause returns updated job."""
        app = _create_app(adapter)
        paused_job = {**SAMPLE_JOB, "enabled": False}
        mock_pause = MagicMock(return_value=paused_job)
        cron_on = patch.object(APIServerAdapter, "_CRON_AVAILABLE", True)
        pause_stub = patch.object(APIServerAdapter, "_cron_pause", mock_pause)
        async with TestClient(TestServer(app)) as client:
            with cron_on, pause_stub:
                response = await client.post(f"/api/jobs/{VALID_JOB_ID}/pause")
                assert response.status == 200
                payload = await response.json()
                assert payload["job"] == paused_job
                assert payload["job"]["enabled"] is False
                mock_pause.assert_called_once_with(VALID_JOB_ID)
# ---------------------------------------------------------------------------
# 15. test_resume_job
# ---------------------------------------------------------------------------
class TestResumeJob:
    """POST /api/jobs/{id}/resume."""

    @pytest.mark.asyncio
    async def test_resume_job(self, adapter):
        """POST /api/jobs/{id}/resume returns updated job."""
        app = _create_app(adapter)
        resumed_job = {**SAMPLE_JOB, "enabled": True}
        mock_resume = MagicMock(return_value=resumed_job)
        cron_on = patch.object(APIServerAdapter, "_CRON_AVAILABLE", True)
        resume_stub = patch.object(APIServerAdapter, "_cron_resume", mock_resume)
        async with TestClient(TestServer(app)) as client:
            with cron_on, resume_stub:
                response = await client.post(f"/api/jobs/{VALID_JOB_ID}/resume")
                assert response.status == 200
                payload = await response.json()
                assert payload["job"] == resumed_job
                assert payload["job"]["enabled"] is True
                mock_resume.assert_called_once_with(VALID_JOB_ID)
# ---------------------------------------------------------------------------
# 16. test_run_job
# ---------------------------------------------------------------------------
class TestRunJob:
    """POST /api/jobs/{id}/run."""

    @pytest.mark.asyncio
    async def test_run_job(self, adapter):
        """POST /api/jobs/{id}/run returns triggered job."""
        app = _create_app(adapter)
        triggered_job = {**SAMPLE_JOB, "last_run": "2025-01-01T00:00:00Z"}
        mock_trigger = MagicMock(return_value=triggered_job)
        cron_on = patch.object(APIServerAdapter, "_CRON_AVAILABLE", True)
        trigger_stub = patch.object(APIServerAdapter, "_cron_trigger", mock_trigger)
        async with TestClient(TestServer(app)) as client:
            with cron_on, trigger_stub:
                response = await client.post(f"/api/jobs/{VALID_JOB_ID}/run")
                assert response.status == 200
                payload = await response.json()
                assert payload["job"] == triggered_job
                mock_trigger.assert_called_once_with(VALID_JOB_ID)
# ---------------------------------------------------------------------------
# 17. test_auth_required
# ---------------------------------------------------------------------------
class TestAuthRequired:
    """With an API key configured, job routes answer 401 unless the key is sent."""

    @staticmethod
    async def _status_without_key(auth_adapter, verb, path, **request_kwargs):
        """Issue an unauthenticated request with cron enabled and return its status.

        *verb* names the ``TestClient`` method to call (``"get"``, ``"post"``,
        ``"delete"``); *request_kwargs* are forwarded to it unchanged.
        """
        app = _create_app(auth_adapter)
        async with TestClient(TestServer(app)) as client:
            with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
                response = await getattr(client, verb)(path, **request_kwargs)
                return response.status

    @pytest.mark.asyncio
    async def test_auth_required_list_jobs(self, auth_adapter):
        """GET /api/jobs without API key returns 401 when key is set."""
        status = await self._status_without_key(auth_adapter, "get", "/api/jobs")
        assert status == 401

    @pytest.mark.asyncio
    async def test_auth_required_create_job(self, auth_adapter):
        """POST /api/jobs without API key returns 401 when key is set."""
        status = await self._status_without_key(
            auth_adapter,
            "post",
            "/api/jobs",
            json={"name": "test", "schedule": "* * * * *"},
        )
        assert status == 401

    @pytest.mark.asyncio
    async def test_auth_required_get_job(self, auth_adapter):
        """GET /api/jobs/{id} without API key returns 401 when key is set."""
        status = await self._status_without_key(
            auth_adapter, "get", f"/api/jobs/{VALID_JOB_ID}"
        )
        assert status == 401

    @pytest.mark.asyncio
    async def test_auth_required_delete_job(self, auth_adapter):
        """DELETE /api/jobs/{id} without API key returns 401."""
        status = await self._status_without_key(
            auth_adapter, "delete", f"/api/jobs/{VALID_JOB_ID}"
        )
        assert status == 401

    @pytest.mark.asyncio
    async def test_auth_passes_with_valid_key(self, auth_adapter):
        """GET /api/jobs with correct API key succeeds."""
        app = _create_app(auth_adapter)
        mock_list = MagicMock(return_value=[])
        cron_on = patch.object(APIServerAdapter, "_CRON_AVAILABLE", True)
        list_stub = patch.object(APIServerAdapter, "_cron_list", mock_list)
        async with TestClient(TestServer(app)) as client:
            with cron_on, list_stub:
                response = await client.get(
                    "/api/jobs",
                    headers={"Authorization": "Bearer sk-secret"},
                )
                assert response.status == 200
# ---------------------------------------------------------------------------
# 18. test_cron_unavailable
# ---------------------------------------------------------------------------
class TestCronUnavailable:
    """Every job route answers 501 while ``_CRON_AVAILABLE`` is patched to False."""

    @staticmethod
    async def _status_with_cron_disabled(adapter, verb, path, **request_kwargs):
        """Issue one request with ``_CRON_AVAILABLE`` patched to False; return the status.

        *verb* names the ``TestClient`` method to call; *request_kwargs* are
        forwarded to it unchanged.
        """
        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as client:
            with patch.object(APIServerAdapter, "_CRON_AVAILABLE", False):
                response = await getattr(client, verb)(path, **request_kwargs)
                return response.status

    @pytest.mark.asyncio
    async def test_cron_unavailable_list(self, adapter):
        """GET /api/jobs returns 501 when _CRON_AVAILABLE is False."""
        # Written out in full because this case also inspects the error body.
        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as client:
            with patch.object(APIServerAdapter, "_CRON_AVAILABLE", False):
                response = await client.get("/api/jobs")
                assert response.status == 501
                payload = await response.json()
                assert "not available" in payload["error"].lower()

    @pytest.mark.asyncio
    async def test_cron_unavailable_create(self, adapter):
        """POST /api/jobs returns 501 when _CRON_AVAILABLE is False."""
        status = await self._status_with_cron_disabled(
            adapter,
            "post",
            "/api/jobs",
            json={"name": "test", "schedule": "* * * * *"},
        )
        assert status == 501

    @pytest.mark.asyncio
    async def test_cron_unavailable_get(self, adapter):
        """GET /api/jobs/{id} returns 501 when _CRON_AVAILABLE is False."""
        status = await self._status_with_cron_disabled(
            adapter, "get", f"/api/jobs/{VALID_JOB_ID}"
        )
        assert status == 501

    @pytest.mark.asyncio
    async def test_cron_unavailable_delete(self, adapter):
        """DELETE /api/jobs/{id} returns 501 when _CRON_AVAILABLE is False."""
        status = await self._status_with_cron_disabled(
            adapter, "delete", f"/api/jobs/{VALID_JOB_ID}"
        )
        assert status == 501

    @pytest.mark.asyncio
    async def test_cron_unavailable_pause(self, adapter):
        """POST /api/jobs/{id}/pause returns 501 when _CRON_AVAILABLE is False."""
        status = await self._status_with_cron_disabled(
            adapter, "post", f"/api/jobs/{VALID_JOB_ID}/pause"
        )
        assert status == 501

    @pytest.mark.asyncio
    async def test_cron_unavailable_resume(self, adapter):
        """POST /api/jobs/{id}/resume returns 501 when _CRON_AVAILABLE is False."""
        status = await self._status_with_cron_disabled(
            adapter, "post", f"/api/jobs/{VALID_JOB_ID}/resume"
        )
        assert status == 501

    @pytest.mark.asyncio
    async def test_cron_unavailable_run(self, adapter):
        """POST /api/jobs/{id}/run returns 501 when _CRON_AVAILABLE is False."""
        status = await self._status_with_cron_disabled(
            adapter, "post", f"/api/jobs/{VALID_JOB_ID}/run"
        )
        assert status == 501

View file

@ -0,0 +1,240 @@
"""Tests for /approve and /deny gateway commands.
Verifies that dangerous command approvals require explicit /approve or /deny
slash commands, not bare "yes"/"no" text matching.
"""
import time
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from gateway.config import GatewayConfig, Platform, PlatformConfig
from gateway.platforms.base import MessageEvent
from gateway.session import SessionEntry, SessionSource, build_session_key
def _make_source() -> SessionSource:
    """Return the fixed Telegram DM source (user ``u1``, chat ``c1``) used by these tests."""
    source = SessionSource(
        platform=Platform.TELEGRAM,
        user_id="u1",
        chat_id="c1",
        user_name="tester",
        chat_type="dm",
    )
    return source
def _make_event(text: str) -> MessageEvent:
    """Wrap *text* in a ``MessageEvent`` with id ``m1`` from the source built by ``_make_source``."""
    origin = _make_source()
    return MessageEvent(text=text, source=origin, message_id="m1")
def _make_runner():
    """Return a ``GatewayRunner`` created without ``__init__`` and stubbed for command tests.

    The Telegram adapter is a ``MagicMock`` with an async ``send``, the session
    store and hooks are mocks, and authorization always succeeds.
    """
    from gateway.run import GatewayRunner

    runner = object.__new__(GatewayRunner)

    telegram_adapter = MagicMock()
    telegram_adapter.send = AsyncMock()

    # Attributes are applied in this order, one setattr per entry.
    stubbed_state = {
        "config": GatewayConfig(
            platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")}
        ),
        "adapters": {Platform.TELEGRAM: telegram_adapter},
        "_voice_mode": {},
        "hooks": SimpleNamespace(emit=AsyncMock(), loaded_hooks=False),
        "session_store": MagicMock(),
        "_running_agents": {},
        "_pending_messages": {},
        "_pending_approvals": {},
        "_session_db": None,
        "_reasoning_config": None,
        "_provider_routing": {},
        "_fallback_model": None,
        "_show_reasoning": False,
        "_is_user_authorized": lambda _source: True,
        "_set_session_env": lambda _context: None,
    }
    for attr_name, value in stubbed_state.items():
        setattr(runner, attr_name, value)
    return runner
def _make_pending_approval(command="sudo rm -rf /tmp/test", pattern_key="sudo"):
    """Build a pending-approval record for *command*, stamped with the current time.

    ``pattern_keys`` is a one-element list holding *pattern_key*.
    """
    approval = dict(
        command=command,
        pattern_key=pattern_key,
        pattern_keys=[pattern_key],
        description="sudo command",
        timestamp=time.time(),
    )
    return approval
# ------------------------------------------------------------------
# /approve command
# ------------------------------------------------------------------
class TestApproveCommand:
@pytest.mark.asyncio
async def test_approve_executes_pending_command(self):
    """A bare ``/approve`` runs the pending command with ``force=True`` and clears it."""
    runner = _make_runner()
    key = runner._session_key_for_source(_make_source())
    runner._pending_approvals[key] = _make_pending_approval()
    event = _make_event("/approve")

    terminal_patch = patch("tools.terminal_tool.terminal_tool", return_value="done")
    with terminal_patch as terminal_mock:
        reply = await runner._handle_approve_command(event)

    assert "✅ Command approved and executed" in reply
    terminal_mock.assert_called_once_with(command="sudo rm -rf /tmp/test", force=True)
    assert key not in runner._pending_approvals
@pytest.mark.asyncio
async def test_approve_session_remembers_pattern(self):
    """``/approve session`` calls ``approve_session`` with the session key and pattern."""
    runner = _make_runner()
    key = runner._session_key_for_source(_make_source())
    runner._pending_approvals[key] = _make_pending_approval()
    event = _make_event("/approve session")

    terminal_patch = patch("tools.terminal_tool.terminal_tool", return_value="done")
    session_patch = patch("tools.approval.approve_session")
    with terminal_patch, session_patch as session_mock:
        reply = await runner._handle_approve_command(event)

    assert "pattern approved for this session" in reply
    session_mock.assert_called_once_with(key, "sudo")
@pytest.mark.asyncio
async def test_approve_always_approves_permanently(self):
"""/approve always approves the pattern permanently."""
runner = _make_runner()
source = _make_source()
session_key = runner._session_key_for_source(source)
runner._pending_approvals[session_key] = _make_pending_approval()
event = _make_event("/approve always")
with (
patch("tools.terminal_tool.terminal_tool", return_value="done"),
patch("tools.approval.approve_permanent") as mock_perm,
):
result = await runner._handle_approve_command(event)
assert "pattern approved permanently" in result
mock_perm.assert_called_once_with("sudo")
@pytest.mark.asyncio
async def test_approve_no_pending(self):
"""/approve with no pending approval returns helpful message."""
runner = _make_runner()
event = _make_event("/approve")
result = await runner._handle_approve_command(event)
assert "No pending command" in result
@pytest.mark.asyncio
async def test_approve_expired(self):
"""/approve on a timed-out approval rejects it."""
runner = _make_runner()
source = _make_source()
session_key = runner._session_key_for_source(source)
approval = _make_pending_approval()
approval["timestamp"] = time.time() - 600 # 10 minutes ago
runner._pending_approvals[session_key] = approval
event = _make_event("/approve")
result = await runner._handle_approve_command(event)
assert "expired" in result
assert session_key not in runner._pending_approvals
# ------------------------------------------------------------------
# /deny command
# ------------------------------------------------------------------
class TestDenyCommand:
    """Behaviour of ``GatewayRunner._handle_deny_command``."""

    @pytest.mark.asyncio
    async def test_deny_clears_pending(self):
        """/deny reports the denial and removes the queued approval."""
        runner = _make_runner()
        pending_key = runner._session_key_for_source(_make_source())
        runner._pending_approvals[pending_key] = _make_pending_approval()

        reply = await runner._handle_deny_command(_make_event("/deny"))

        assert "❌ Command denied" in reply
        assert pending_key not in runner._pending_approvals

    @pytest.mark.asyncio
    async def test_deny_no_pending(self):
        """With nothing queued, /deny answers with a 'No pending command' message."""
        runner = _make_runner()
        reply = await runner._handle_deny_command(_make_event("/deny"))
        assert "No pending command" in reply
# ------------------------------------------------------------------
# Bare "yes" must NOT trigger approval
# ------------------------------------------------------------------
class TestBareTextNoLongerApproves:
    """Regression guard for issue #1888: plain chat text is not an approval."""

    @pytest.mark.asyncio
    async def test_yes_does_not_execute_pending_command(self):
        """A bare 'yes' message must not execute a pending command.

        Issue #1888: matching free text against 'yes'/'no' could intercept
        unrelated user messages and run the queued command.
        """
        runner = _make_runner()
        pending_key = runner._session_key_for_source(_make_source())
        runner._pending_approvals[pending_key] = _make_pending_approval()

        # The user types "yes" as ordinary conversation. Under the old
        # text-matching code this would have executed the queued command;
        # now it is expected to fall through to normal agent processing.
        bare_yes_event = _make_event("yes")

        # NOTE(review): the event above is built but never dispatched
        # (running _handle_message end-to-end is not practical here), so this
        # only checks that the approval stored a few lines up is still present.
        assert pending_key in runner._pending_approvals
# ------------------------------------------------------------------
# Approval hint appended to response
# ------------------------------------------------------------------
class TestApprovalHint:
    """Format check for the approval instructions appended to agent replies."""

    def test_approval_hint_appended_to_response(self):
        """The hint must mention /approve, /deny and the command itself.

        The hint text is assembled locally in the same shape as the one the
        original comments attribute to the end of ``_handle_message``; only
        the format is verified here.
        """
        cmd = "sudo rm -rf /tmp/dangerous"
        cmd_preview = cmd
        hint_parts = [
            "\n\n⚠️ **Dangerous command requires approval:**\n",
            f"```\n{cmd_preview}\n```\n",
            "Reply `/approve` to execute, `/approve session` to approve this pattern ",
            "for the session, or `/deny` to cancel.",
        ]
        hint = "".join(hint_parts)

        for expected in ("/approve", "/deny", cmd):
            assert expected in hint

View file

@ -0,0 +1,180 @@
"""Tests for proactive memory flush on session expiry.
Verifies that:
1. _is_session_expired() works from a SessionEntry alone (no source needed)
2. The sync callback is no longer called in get_or_create_session
3. _pre_flushed_sessions tracking works correctly
4. The background watcher can detect expired sessions
"""
import pytest
from datetime import datetime, timedelta
from pathlib import Path
from unittest.mock import patch, MagicMock
from gateway.config import Platform, GatewayConfig, SessionResetPolicy
from gateway.session import SessionSource, SessionStore, SessionEntry
@pytest.fixture()
def idle_store(tmp_path):
    """SessionStore whose default policy resets after 60 idle minutes.

    ``_ensure_loaded`` is patched out during construction, and the store is
    then marked as loaded with no database attached.
    """
    idle_policy = SessionResetPolicy(mode="idle", idle_minutes=60)
    gateway_config = GatewayConfig(default_reset_policy=idle_policy)
    with patch("gateway.session.SessionStore._ensure_loaded"):
        store = SessionStore(sessions_dir=tmp_path, config=gateway_config)
    store._db = None
    store._loaded = True
    return store
@pytest.fixture()
def no_reset_store(tmp_path):
    """SessionStore whose default reset policy is ``mode="none"``.

    Built the same way as ``idle_store``: ``_ensure_loaded`` is patched out
    during construction and the store is marked loaded with no database.
    """
    never_policy = SessionResetPolicy(mode="none")
    gateway_config = GatewayConfig(default_reset_policy=never_policy)
    with patch("gateway.session.SessionStore._ensure_loaded"):
        store = SessionStore(sessions_dir=tmp_path, config=gateway_config)
    store._db = None
    store._loaded = True
    return store
class TestIsSessionExpired:
    """``_is_session_expired`` must decide from a SessionEntry alone."""

    @staticmethod
    def _telegram_dm_entry(session_id, *, created_ago, updated_ago):
        """Build a Telegram DM entry whose timestamps lie the given deltas in the past."""
        return SessionEntry(
            session_key="agent:main:telegram:dm",
            session_id=session_id,
            created_at=datetime.now() - created_ago,
            updated_at=datetime.now() - updated_ago,
            platform=Platform.TELEGRAM,
            chat_type="dm",
        )

    def test_idle_session_expired(self, idle_store):
        """120 idle minutes against a 60-minute policy counts as expired."""
        entry = self._telegram_dm_entry(
            "sid_1",
            created_ago=timedelta(hours=3),
            updated_ago=timedelta(minutes=120),
        )
        assert idle_store._is_session_expired(entry) is True

    def test_active_session_not_expired(self, idle_store):
        """10 idle minutes against a 60-minute policy is still live."""
        entry = self._telegram_dm_entry(
            "sid_2",
            created_ago=timedelta(hours=1),
            updated_ago=timedelta(minutes=10),
        )
        assert idle_store._is_session_expired(entry) is False

    def test_none_mode_never_expires(self, no_reset_store):
        """With ``mode="none"`` even a 30-day-old entry is not expired."""
        month = timedelta(days=30)
        entry = self._telegram_dm_entry("sid_3", created_ago=month, updated_ago=month)
        assert no_reset_store._is_session_expired(entry) is False

    def test_active_processes_prevent_expiry(self, idle_store):
        """Sessions with active background processes should never expire."""
        # Report every session key as having live processes.
        idle_store._has_active_processes_fn = lambda key: True
        five_hours = timedelta(hours=5)
        entry = self._telegram_dm_entry(
            "sid_4", created_ago=five_hours, updated_ago=five_hours
        )
        assert idle_store._is_session_expired(entry) is False

    def test_daily_mode_expired(self, tmp_path):
        """Daily mode should expire sessions from before today's reset hour."""
        daily_policy = SessionResetPolicy(mode="daily", at_hour=4)
        gateway_config = GatewayConfig(default_reset_policy=daily_policy)
        with patch("gateway.session.SessionStore._ensure_loaded"):
            store = SessionStore(sessions_dir=tmp_path, config=gateway_config)
        store._db = None
        store._loaded = True

        two_days = timedelta(days=2)
        entry = self._telegram_dm_entry("sid_5", created_ago=two_days, updated_ago=two_days)
        assert store._is_session_expired(entry) is True
class TestGetOrCreateSessionNoCallback:
    """``get_or_create_session`` must not invoke a synchronous flush callback."""

    @staticmethod
    def _telegram_dm_source():
        """Source used by both tests: Telegram DM in chat ``"123"``."""
        return SessionSource(platform=Platform.TELEGRAM, chat_id="123", chat_type="dm")

    @staticmethod
    def _age_entry(store, entry):
        """Back-date *entry* by 120 minutes and persist the store."""
        entry.updated_at = datetime.now() - timedelta(minutes=120)
        store._save()

    def test_auto_reset_cleans_pre_flushed_marker(self, idle_store):
        """When a session auto-resets, the pre_flushed marker should be discarded."""
        source = self._telegram_dm_source()

        first = idle_store.get_or_create_session(source)
        previous_id = first.session_id

        # Pretend the background watcher already flushed this session ...
        idle_store._pre_flushed_sessions.add(previous_id)
        # ... and that it has since gone idle past the 60-minute policy.
        self._age_entry(idle_store, first)

        second = idle_store.get_or_create_session(source)

        assert second.session_id != previous_id
        assert second.was_auto_reset is True
        # The stale marker for the old session id must be gone.
        assert previous_id not in idle_store._pre_flushed_sessions

    def test_no_sync_callback_invoked(self, idle_store):
        """No synchronous callback should block during auto-reset."""
        source = self._telegram_dm_source()
        first = idle_store.get_or_create_session(source)
        self._age_entry(idle_store, first)

        # The store must not expose an _on_auto_reset attribute at all.
        assert not hasattr(idle_store, '_on_auto_reset')

        # Expected to return without blocking (no sync LLM call).
        second = idle_store.get_or_create_session(source)
        assert second.was_auto_reset is True
class TestPreFlushedSessionsTracking:
    """The ``_pre_flushed_sessions`` set should prevent double-flushing."""

    def test_starts_empty(self, idle_store):
        """A freshly built store has no pre-flushed session ids."""
        flushed = idle_store._pre_flushed_sessions
        assert len(flushed) == 0

    def test_add_and_check(self, idle_store):
        """Only ids that were added are reported as members."""
        flushed = idle_store._pre_flushed_sessions
        flushed.add("sid_old")
        assert "sid_old" in flushed
        assert "sid_other" not in flushed

    def test_discard_on_reset(self, idle_store):
        """discard should remove without raising if not present."""
        flushed = idle_store._pre_flushed_sessions
        flushed.add("sid_a")
        flushed.discard("sid_a")
        assert "sid_a" not in flushed
        # Discarding an id that was never added must be a silent no-op.
        flushed.discard("sid_nonexistent")

View file

@ -0,0 +1,322 @@
"""Tests for /background gateway slash command.
Tests the _handle_background_command handler (run a prompt in a separate
background session) across gateway messenger platforms.
"""
import asyncio
import os
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from gateway.config import Platform
from gateway.platforms.base import MessageEvent
from gateway.session import SessionSource
def _make_event(text="/background", platform=Platform.TELEGRAM,
                user_id="12345", chat_id="67890"):
    """Build a MessageEvent carrying *text* from a user named ``"testuser"``.

    The platform, user id and chat id of the underlying SessionSource are
    taken from the arguments.
    """
    origin = SessionSource(
        platform=platform,
        user_id=user_id,
        chat_id=chat_id,
        user_name="testuser",
    )
    event = MessageEvent(text=text, source=origin)
    return event
def _make_runner():
    """Create a bare GatewayRunner (``__init__`` skipped) with minimal mocks.

    Only the attributes assigned below exist on the returned instance: no
    adapters, a mocked session store, a fresh HookRegistry and empty/None
    bookkeeping fields.
    """
    from gateway.run import GatewayRunner

    runner = object.__new__(GatewayRunner)
    plain_state = (
        ("adapters", {}),
        ("_voice_mode", {}),
        ("_session_db", None),
        ("_reasoning_config", None),
        ("_provider_routing", {}),
        ("_fallback_model", None),
        ("_running_agents", {}),
        ("session_store", MagicMock()),
    )
    for attr_name, attr_value in plain_state:
        setattr(runner, attr_name, attr_value)

    from gateway.hooks import HookRegistry

    runner.hooks = HookRegistry()
    return runner
# ---------------------------------------------------------------------------
# _handle_background_command
# ---------------------------------------------------------------------------
class TestHandleBackgroundCommand:
    """Tests for GatewayRunner._handle_background_command."""

    @staticmethod
    def _discard_task(coro, *args, **kwargs):
        """Stand-in for ``asyncio.create_task``.

        Closes *coro* so Python does not warn about a coroutine that was
        never awaited, and hands back a mock in place of the real task.
        Extra positional/keyword arguments are accepted and ignored.
        """
        coro.close()
        return MagicMock()

    def _patched_create_task(self):
        """Patch ``gateway.run.asyncio.create_task`` with ``_discard_task``."""
        return patch("gateway.run.asyncio.create_task", side_effect=self._discard_task)

    @pytest.mark.asyncio
    async def test_no_prompt_shows_usage(self):
        """Running /background with no prompt shows usage."""
        runner = _make_runner()
        event = _make_event(text="/background")
        result = await runner._handle_background_command(event)
        assert "Usage:" in result
        assert "/background" in result

    @pytest.mark.asyncio
    async def test_bg_alias_no_prompt_shows_usage(self):
        """Running /bg with no prompt shows usage."""
        runner = _make_runner()
        event = _make_event(text="/bg")
        result = await runner._handle_background_command(event)
        assert "Usage:" in result

    @pytest.mark.asyncio
    async def test_empty_prompt_shows_usage(self):
        """Running /background with only whitespace shows usage."""
        runner = _make_runner()
        event = _make_event(text="/background ")
        result = await runner._handle_background_command(event)
        assert "Usage:" in result

    @pytest.mark.asyncio
    async def test_valid_prompt_starts_task(self):
        """Running /background with a prompt returns confirmation and starts task."""
        runner = _make_runner()
        with self._patched_create_task() as create_task_mock:
            event = _make_event(text="/background Summarize the top HN stories")
            result = await runner._handle_background_command(event)
        assert "🔄" in result
        assert "Background task started" in result
        assert "bg_" in result  # task ID starts with bg_
        assert "Summarize the top HN stories" in result
        # Exactly one background task must have been scheduled.
        assert create_task_mock.call_count == 1

    @pytest.mark.asyncio
    async def test_prompt_truncated_in_preview(self):
        """Long prompts are truncated (with an ellipsis) in the confirmation message."""
        runner = _make_runner()
        long_prompt = "A" * 100
        with self._patched_create_task():
            event = _make_event(text=f"/background {long_prompt}")
            result = await runner._handle_background_command(event)
        assert "..." in result
        # Should not contain the full prompt
        assert long_prompt not in result

    @pytest.mark.asyncio
    async def test_task_id_is_unique(self):
        """Each background task gets a unique task ID."""
        runner = _make_runner()
        task_ids = set()
        with self._patched_create_task():
            for i in range(5):
                event = _make_event(text=f"/background task {i}")
                result = await runner._handle_background_command(event)
                # Pull the id out of the "Task ID: ..." line of the reply.
                for line in result.split("\n"):
                    if "Task ID:" in line:
                        task_ids.add(line.split("Task ID:")[1].strip())
        assert len(task_ids) == 5  # all unique

    @pytest.mark.asyncio
    async def test_works_across_platforms(self):
        """The /background command works for all platforms."""
        for platform in [Platform.TELEGRAM, Platform.DISCORD, Platform.SLACK]:
            runner = _make_runner()
            with self._patched_create_task():
                event = _make_event(
                    text="/background test task",
                    platform=platform,
                )
                result = await runner._handle_background_command(event)
            assert "Background task started" in result
# ---------------------------------------------------------------------------
# _run_background_task
# ---------------------------------------------------------------------------
class TestRunBackgroundTask:
    """Tests for GatewayRunner._run_background_task (the actual execution)."""

    @staticmethod
    def _telegram_source():
        """Source shared by every test in this class."""
        return SessionSource(
            platform=Platform.TELEGRAM,
            user_id="12345",
            chat_id="67890",
            user_name="testuser",
        )

    @staticmethod
    def _attach_adapter(runner):
        """Register an AsyncMock Telegram adapter on *runner* and return it."""
        adapter = AsyncMock()
        adapter.send = AsyncMock()
        runner.adapters[Platform.TELEGRAM] = adapter
        return adapter

    @staticmethod
    def _sent_content(adapter):
        """Return the message text of the last ``send`` call.

        Prefers the ``content`` keyword; otherwise falls back to the second
        positional argument, or ``""`` when there is none.
        """
        positional, keyword = adapter.send.call_args
        fallback = positional[1] if len(positional) > 1 else ""
        return keyword.get("content", fallback)

    @pytest.mark.asyncio
    async def test_no_adapter_returns_silently(self):
        """When no adapter is available, the task returns without error."""
        runner = _make_runner()
        # runner.adapters is empty here; the call must simply not raise.
        await runner._run_background_task("test prompt", self._telegram_source(), "bg_test")

    @pytest.mark.asyncio
    async def test_no_credentials_sends_error(self):
        """When provider credentials are missing, an error is sent."""
        runner = _make_runner()
        adapter = self._attach_adapter(runner)
        source = self._telegram_source()

        with patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": None}):
            await runner._run_background_task("test prompt", source, "bg_test")

        # Exactly one message, and it reports a failure.
        adapter.send.assert_called_once()
        assert "failed" in self._sent_content(adapter).lower()

    @pytest.mark.asyncio
    async def test_successful_task_sends_result(self):
        """When the agent completes successfully, the result is sent."""
        runner = _make_runner()
        adapter = self._attach_adapter(runner)
        adapter.extract_media = MagicMock(return_value=([], "Hello from background!"))
        adapter.extract_images = MagicMock(return_value=([], "Hello from background!"))
        source = self._telegram_source()
        agent_result = {"final_response": "Hello from background!", "messages": []}

        with patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "test-key"}), \
                patch("run_agent.AIAgent") as MockAgent:
            fake_agent = MagicMock()
            fake_agent.run_conversation.return_value = agent_result
            MockAgent.return_value = fake_agent
            await runner._run_background_task("say hello", source, "bg_test")

        # Exactly one message, carrying the completion banner and the reply.
        adapter.send.assert_called_once()
        content = self._sent_content(adapter)
        assert "Background task complete" in content
        assert "Hello from background!" in content

    @pytest.mark.asyncio
    async def test_exception_sends_error_message(self):
        """When the agent raises an exception, an error message is sent."""
        runner = _make_runner()
        adapter = self._attach_adapter(runner)
        source = self._telegram_source()

        with patch("gateway.run._resolve_runtime_agent_kwargs", side_effect=RuntimeError("boom")):
            await runner._run_background_task("test prompt", source, "bg_test")

        adapter.send.assert_called_once()
        assert "failed" in self._sent_content(adapter).lower()
# ---------------------------------------------------------------------------
# /background in help and known_commands
# ---------------------------------------------------------------------------
class TestBackgroundInHelp:
    """Verify /background appears in help text and known commands."""

    @pytest.mark.asyncio
    async def test_background_in_help_output(self):
        """The /help output includes /background."""
        runner = _make_runner()
        help_text = await runner._handle_help_command(_make_event(text="/help"))
        assert "/background" in help_text

    def test_background_is_known_command(self):
        """The /background command is in GATEWAY_KNOWN_COMMANDS."""
        from hermes_cli.commands import GATEWAY_KNOWN_COMMANDS as known

        assert "background" in known

    def test_bg_alias_is_known_command(self):
        """The /bg alias is in GATEWAY_KNOWN_COMMANDS."""
        from hermes_cli.commands import GATEWAY_KNOWN_COMMANDS as known

        assert "bg" in known
# ---------------------------------------------------------------------------
# CLI /background command definition
# ---------------------------------------------------------------------------
class TestBackgroundInCLICommands:
    """Verify /background is registered in the CLI command system."""

    def test_background_in_commands_dict(self):
        """The /background command is in the COMMANDS dict."""
        from hermes_cli.commands import COMMANDS as registry

        assert "/background" in registry

    def test_bg_alias_in_commands_dict(self):
        """The /bg alias is in the COMMANDS dict."""
        from hermes_cli.commands import COMMANDS as registry

        assert "/bg" in registry

    def test_background_in_session_category(self):
        """The /background command is in the Session category."""
        from hermes_cli.commands import COMMANDS_BY_CATEGORY as by_category

        session_commands = by_category["Session"]
        assert "/background" in session_commands

    def test_background_autocompletes(self):
        """The /background command appears in autocomplete results."""
        from hermes_cli.commands import SlashCommandCompleter
        from prompt_toolkit.document import Document

        completer = SlashCommandCompleter()

        # Without the leading slash the partial text yields nothing.
        without_slash = list(completer.get_completions(Document("backgro"), None))
        assert len(without_slash) == 0

        # With the slash prefix, /background must be among the suggestions.
        with_slash = list(completer.get_completions(Document("/backgro"), None))
        shown = [str(item.display) for item in with_slash]
        assert any("/background" in label for label in shown)

View file

@ -0,0 +1,245 @@
"""Tests for configurable background process notification modes.
The gateway process watcher pushes status updates to users' chats when
background terminal commands run. ``display.background_process_notifications``
controls verbosity: off | result | error | all (default).
Contributed by @PeterFile (PR #593), reimplemented on current main.
"""
import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
import pytest
from gateway.config import GatewayConfig, Platform
from gateway.run import GatewayRunner
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
class _FakeRegistry:
    """Process-registry stub that replays a fixed queue of sessions.

    Each ``get`` call hands out the next queued item regardless of the id
    asked for; once the queue is empty every call returns ``None``.
    """

    def __init__(self, sessions):
        # Copy so the caller's sequence is never mutated by get().
        self._sessions = list(sessions)

    def get(self, session_id):
        """Return the next queued session (ignoring *session_id*), or None."""
        if not self._sessions:
            return None
        next_session = self._sessions.pop(0)
        return next_session
def _build_runner(monkeypatch, tmp_path, mode: str) -> GatewayRunner:
    """Create a GatewayRunner with a fake config for the given mode.

    Args:
        monkeypatch: pytest's monkeypatch fixture; used to redirect
            ``gateway.run._hermes_home`` and to isolate the environment.
        tmp_path: Directory that receives the generated ``config.yaml``.
        mode: Value written to ``display.background_process_notifications``.

    Returns:
        A runner built from a default ``GatewayConfig`` with a fake Telegram
        adapter whose ``send`` is an ``AsyncMock``.
    """
    (tmp_path / "config.yaml").write_text(
        f"display:\n background_process_notifications: {mode}\n",
        encoding="utf-8",
    )
    import gateway.run as gateway_run

    monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
    # HERMES_BACKGROUND_NOTIFICATIONS takes precedence over config.yaml (see
    # test_env_var_overrides_config), so an ambient value in the developer's
    # or CI environment would silently replace the mode under test.
    monkeypatch.delenv("HERMES_BACKGROUND_NOTIFICATIONS", raising=False)
    runner = GatewayRunner(GatewayConfig())
    adapter = SimpleNamespace(send=AsyncMock())
    runner.adapters[Platform.TELEGRAM] = adapter
    return runner
def _watcher_dict(session_id="proc_test", thread_id=""):
    """Build the watcher description passed to ``_run_process_watcher``.

    The dict always targets Telegram chat ``"123"`` with a zero check
    interval; a ``"thread_id"`` key is included only when *thread_id* is
    truthy.
    """
    watcher = dict(
        session_id=session_id,
        check_interval=0,
        platform="telegram",
        chat_id="123",
    )
    if not thread_id:
        return watcher
    watcher["thread_id"] = thread_id
    return watcher
# ---------------------------------------------------------------------------
# _load_background_notifications_mode unit tests
# ---------------------------------------------------------------------------
class TestLoadBackgroundNotificationsMode:
    """Unit tests for ``GatewayRunner._load_background_notifications_mode``."""

    @staticmethod
    def _write_config(tmp_path, raw_value):
        """Write a config.yaml whose notification setting is *raw_value*."""
        (tmp_path / "config.yaml").write_text(
            f"display:\n background_process_notifications: {raw_value}\n"
        )

    @staticmethod
    def _use_home(monkeypatch, tmp_path):
        """Point ``gateway.run._hermes_home`` at *tmp_path*."""
        import gateway.run as gw

        monkeypatch.setattr(gw, "_hermes_home", tmp_path)

    def test_defaults_to_all(self, monkeypatch, tmp_path):
        """No config file and no env var: the mode is ``"all"``."""
        self._use_home(monkeypatch, tmp_path)
        monkeypatch.delenv("HERMES_BACKGROUND_NOTIFICATIONS", raising=False)
        assert GatewayRunner._load_background_notifications_mode() == "all"

    def test_reads_config_yaml(self, monkeypatch, tmp_path):
        """The value from config.yaml is used when the env var is unset."""
        self._write_config(tmp_path, "error")
        self._use_home(monkeypatch, tmp_path)
        monkeypatch.delenv("HERMES_BACKGROUND_NOTIFICATIONS", raising=False)
        assert GatewayRunner._load_background_notifications_mode() == "error"

    def test_env_var_overrides_config(self, monkeypatch, tmp_path):
        """HERMES_BACKGROUND_NOTIFICATIONS wins over config.yaml."""
        self._write_config(tmp_path, "error")
        self._use_home(monkeypatch, tmp_path)
        monkeypatch.setenv("HERMES_BACKGROUND_NOTIFICATIONS", "off")
        assert GatewayRunner._load_background_notifications_mode() == "off"

    def test_false_value_maps_to_off(self, monkeypatch, tmp_path):
        """A YAML ``false`` is reported as ``"off"``."""
        self._write_config(tmp_path, "false")
        self._use_home(monkeypatch, tmp_path)
        monkeypatch.delenv("HERMES_BACKGROUND_NOTIFICATIONS", raising=False)
        assert GatewayRunner._load_background_notifications_mode() == "off"

    def test_invalid_value_defaults_to_all(self, monkeypatch, tmp_path):
        """An unrecognised value falls back to ``"all"``."""
        self._write_config(tmp_path, "banana")
        self._use_home(monkeypatch, tmp_path)
        monkeypatch.delenv("HERMES_BACKGROUND_NOTIFICATIONS", raising=False)
        assert GatewayRunner._load_background_notifications_mode() == "all"
# ---------------------------------------------------------------------------
# _run_process_watcher integration tests
# ---------------------------------------------------------------------------
def _running_session(output):
    """Fake process snapshot that is still running with *output* buffered."""
    return SimpleNamespace(output_buffer=output, exited=False, exit_code=None)


def _exited_session(output, exit_code):
    """Fake process snapshot that has exited with *exit_code*."""
    return SimpleNamespace(output_buffer=output, exited=True, exit_code=exit_code)


@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("mode", "sessions", "expected_calls", "expected_fragment"),
    [
        # all: still-running output produces a progress update; the trailing
        # None makes the process "disappear" so the watcher exits.
        ("all", [_running_session("building...\n"), None], 1, "is still running"),
        # result: still-running output is not reported
        ("result", [_running_session("building...\n"), None], 0, None),
        # off: even a finished process is not reported
        ("off", [_exited_session("done\n", 0)], 0, None),
        # result: a finished process is reported
        ("result", [_exited_session("done\n", 0)], 1, "finished with exit code 0"),
        # error: a clean exit is not reported
        ("error", [_exited_session("done\n", 0)], 0, None),
        # error: a failing exit is reported
        ("error", [_exited_session("traceback\n", 1)], 1, "finished with exit code 1"),
        # all: a finished process is reported
        ("all", [_exited_session("ok\n", 0)], 1, "finished with exit code 0"),
    ],
)
async def test_run_process_watcher_respects_notification_mode(
    monkeypatch, tmp_path, mode, sessions, expected_calls, expected_fragment
):
    """The watcher sends exactly the notifications the configured mode allows."""
    import tools.process_registry as pr_module

    async def _no_wait(*_args, **_kwargs):
        # Replacement for asyncio.sleep so the watcher loop never really waits.
        return None

    monkeypatch.setattr(pr_module, "process_registry", _FakeRegistry(sessions))
    monkeypatch.setattr(asyncio, "sleep", _no_wait)

    runner = _build_runner(monkeypatch, tmp_path, mode)
    telegram = runner.adapters[Platform.TELEGRAM]

    await runner._run_process_watcher(_watcher_dict())

    assert telegram.send.await_count == expected_calls, (
        f"mode={mode}: expected {expected_calls} sends, got {telegram.send.await_count}"
    )
    if expected_fragment is None:
        return
    # The message text is the second positional argument of send().
    assert expected_fragment in telegram.send.await_args.args[1]
@pytest.mark.asyncio
async def test_thread_id_passed_to_send(monkeypatch, tmp_path):
    """thread_id from watcher dict is forwarded as metadata to adapter.send()."""
    import tools.process_registry as pr_module

    async def _no_wait(*_args, **_kwargs):
        # Replacement for asyncio.sleep so the watcher loop never really waits.
        return None

    finished = SimpleNamespace(output_buffer="done\n", exited=True, exit_code=0)
    monkeypatch.setattr(pr_module, "process_registry", _FakeRegistry([finished]))
    monkeypatch.setattr(asyncio, "sleep", _no_wait)

    runner = _build_runner(monkeypatch, tmp_path, "all")
    telegram = runner.adapters[Platform.TELEGRAM]

    await runner._run_process_watcher(_watcher_dict(thread_id="42"))

    assert telegram.send.await_count == 1
    _positional, keyword = telegram.send.call_args
    assert keyword["metadata"] == {"thread_id": "42"}
@pytest.mark.asyncio
async def test_no_thread_id_sends_no_metadata(monkeypatch, tmp_path):
    """When thread_id is empty, metadata should be None (general topic)."""
    import tools.process_registry as pr_module

    async def _no_wait(*_args, **_kwargs):
        # Replacement for asyncio.sleep so the watcher loop never really waits.
        return None

    finished = SimpleNamespace(output_buffer="done\n", exited=True, exit_code=0)
    monkeypatch.setattr(pr_module, "process_registry", _FakeRegistry([finished]))
    monkeypatch.setattr(asyncio, "sleep", _no_wait)

    runner = _build_runner(monkeypatch, tmp_path, "all")
    telegram = runner.adapters[Platform.TELEGRAM]

    await runner._run_process_watcher(_watcher_dict())

    assert telegram.send.await_count == 1
    _positional, keyword = telegram.send.call_args
    assert keyword["metadata"] is None

View file

@ -0,0 +1,135 @@
"""Tests for BasePlatformAdapter topic-aware session handling."""
import asyncio
from types import SimpleNamespace
import pytest
from gateway.config import Platform, PlatformConfig
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, SendResult
from gateway.session import SessionSource, build_session_key
class DummyTelegramAdapter(BasePlatformAdapter):
    """In-memory Telegram adapter that records outgoing traffic.

    ``sent`` collects one dict per ``send`` call and ``typing`` one dict per
    ``send_typing`` call; nothing touches the network.
    """

    def __init__(self):
        fake_config = PlatformConfig(enabled=True, token="fake-token")
        super().__init__(fake_config, Platform.TELEGRAM)
        self.sent = []
        self.typing = []

    async def connect(self) -> bool:
        """Pretend the connection always succeeds."""
        return True

    async def disconnect(self) -> None:
        """Nothing to tear down."""
        return None

    async def send(self, chat_id, content, reply_to=None, metadata=None) -> SendResult:
        """Record the outgoing message and report success with message id ``"1"``."""
        record = dict(
            chat_id=chat_id,
            content=content,
            reply_to=reply_to,
            metadata=metadata,
        )
        self.sent.append(record)
        return SendResult(success=True, message_id="1")

    async def send_typing(self, chat_id: str, metadata=None) -> None:
        """Record a typing indicator for *chat_id*."""
        self.typing.append(dict(chat_id=chat_id, metadata=metadata))
        return None

    async def get_chat_info(self, chat_id: str):
        """Return a minimal chat description containing only the id."""
        return dict(id=chat_id)
def _make_event(chat_id: str, thread_id: str, message_id: str = "1") -> MessageEvent:
    """Build a "hello" MessageEvent from a Telegram group topic.

    Args:
        chat_id: Group chat the message belongs to.
        thread_id: Topic/thread inside that chat.
        message_id: Id given to the event (defaults to ``"1"``).
    """
    topic_source = SessionSource(
        platform=Platform.TELEGRAM,
        chat_id=chat_id,
        chat_type="group",
        thread_id=thread_id,
    )
    return MessageEvent(text="hello", source=topic_source, message_id=message_id)
class TestBasePlatformTopicSessions:
    """Topic-aware session handling in ``BasePlatformAdapter``."""

    @staticmethod
    def _adapter_busy_in_topic(thread_id):
        """Return an adapter with an active session in chat -1001, topic *thread_id*."""
        adapter = DummyTelegramAdapter()
        adapter.set_message_handler(lambda event: asyncio.sleep(0, result=None))
        active_event = _make_event("-1001", thread_id)
        active_key = build_session_key(active_event.source)
        adapter._active_sessions[active_key] = asyncio.Event()
        return adapter

    @staticmethod
    def _capture_create_task(monkeypatch):
        """Replace ``asyncio.create_task`` and return the list of captured coroutines."""
        scheduled = []

        def fake_create_task(coro):
            scheduled.append(coro)
            # Close immediately so the coroutine is never left un-awaited.
            coro.close()
            return SimpleNamespace()

        monkeypatch.setattr(asyncio, "create_task", fake_create_task)
        return scheduled

    @pytest.mark.asyncio
    async def test_handle_message_does_not_interrupt_different_topic(self, monkeypatch):
        """A message in topic 11 is scheduled even while topic 10 is busy."""
        adapter = self._adapter_busy_in_topic("10")
        scheduled = self._capture_create_task(monkeypatch)

        await adapter.handle_message(_make_event("-1001", "11"))

        assert len(scheduled) == 1
        assert adapter._pending_messages == {}

    @pytest.mark.asyncio
    async def test_handle_message_interrupts_same_topic(self, monkeypatch):
        """A message in the busy topic is queued as pending instead of scheduled."""
        adapter = self._adapter_busy_in_topic("10")
        scheduled = self._capture_create_task(monkeypatch)

        pending_event = _make_event("-1001", "10", message_id="2")
        await adapter.handle_message(pending_event)

        assert scheduled == []
        pending_key = build_session_key(pending_event.source)
        assert adapter.get_pending_message(pending_key) == pending_event

    @pytest.mark.asyncio
    async def test_process_message_background_replies_in_same_topic(self):
        """Both the reply and the typing indicator carry the event's thread id."""
        adapter = DummyTelegramAdapter()
        typing_calls = []

        async def handler(_event):
            await asyncio.sleep(0)
            return "ack"

        async def hold_typing(_chat_id, interval=2.0, metadata=None):
            typing_calls.append({"chat_id": _chat_id, "metadata": metadata})
            # Wait on an event that is never set.
            await asyncio.Event().wait()

        adapter.set_message_handler(handler)
        adapter._keep_typing = hold_typing

        event = _make_event("-1001", "17585")
        await adapter._process_message_background(event, build_session_key(event.source))

        topic_metadata = {"thread_id": "17585"}
        expected_sent = [
            {
                "chat_id": "-1001",
                "content": "ack",
                "reply_to": "1",
                "metadata": topic_metadata,
            }
        ]
        expected_typing = [{"chat_id": "-1001", "metadata": topic_metadata}]
        assert adapter.sent == expected_sent
        assert typing_calls == expected_typing

View file

@ -0,0 +1,252 @@
"""Tests for gateway/channel_directory.py — channel resolution and display."""
import json
import os
from pathlib import Path
from unittest.mock import patch
from gateway.channel_directory import (
resolve_channel_name,
format_directory_for_display,
load_directory,
_build_from_sessions,
DIRECTORY_PATH,
)
def _write_directory(tmp_path, platforms):
    """Write a fake channel directory JSON file into *tmp_path*.

    Returns the path of the written ``channel_directory.json``.
    """
    payload = json.dumps(
        {"updated_at": "2026-01-01T00:00:00", "platforms": platforms}
    )
    target = tmp_path / "channel_directory.json"
    target.write_text(payload)
    return target
class TestLoadDirectory:
    """load_directory() against a missing, a valid and a corrupt cache file."""

    def test_missing_file(self, tmp_path):
        absent = tmp_path / "nope.json"
        with patch("gateway.channel_directory.DIRECTORY_PATH", absent):
            loaded = load_directory()
        assert loaded["updated_at"] is None
        assert loaded["platforms"] == {}

    def test_valid_file(self, tmp_path):
        entry = {"id": "123", "name": "John", "type": "dm"}
        directory_file = _write_directory(tmp_path, {"telegram": [entry]})
        with patch("gateway.channel_directory.DIRECTORY_PATH", directory_file):
            loaded = load_directory()
        assert loaded["platforms"]["telegram"][0]["name"] == "John"

    def test_corrupt_file(self, tmp_path):
        # Deliberately malformed JSON.
        directory_file = tmp_path / "channel_directory.json"
        directory_file.write_text("{bad json")
        with patch("gateway.channel_directory.DIRECTORY_PATH", directory_file):
            loaded = load_directory()
        assert loaded["updated_at"] is None
class TestResolveChannelName:
    """resolve_channel_name() matching rules against a patched directory file."""

    def _setup(self, tmp_path, platforms):
        """Write *platforms* to disk and return a patcher aiming DIRECTORY_PATH at it."""
        directory_file = _write_directory(tmp_path, platforms)
        return patch("gateway.channel_directory.DIRECTORY_PATH", directory_file)

    def test_exact_match(self, tmp_path):
        channels = [
            {"id": "111", "name": "bot-home", "guild": "MyServer", "type": "channel"},
            {"id": "222", "name": "general", "guild": "MyServer", "type": "channel"},
        ]
        with self._setup(tmp_path, {"discord": channels}):
            # The name resolves with and without a leading "#".
            for query in ("bot-home", "#bot-home"):
                assert resolve_channel_name("discord", query) == "111"

    def test_case_insensitive(self, tmp_path):
        channels = [{"id": "C01", "name": "Engineering", "type": "channel"}]
        with self._setup(tmp_path, {"slack": channels}):
            for query in ("engineering", "ENGINEERING"):
                assert resolve_channel_name("slack", query) == "C01"

    def test_guild_qualified_match(self, tmp_path):
        channels = [
            {"id": "111", "name": "general", "guild": "ServerA", "type": "channel"},
            {"id": "222", "name": "general", "guild": "ServerB", "type": "channel"},
        ]
        with self._setup(tmp_path, {"discord": channels}):
            for query, expected_id in (
                ("ServerA/general", "111"),
                ("ServerB/general", "222"),
            ):
                assert resolve_channel_name("discord", query) == expected_id

    def test_prefix_match_unambiguous(self, tmp_path):
        channels = [
            {"id": "C01", "name": "engineering-backend", "type": "channel"},
            {"id": "C02", "name": "design-team", "type": "channel"},
        ]
        with self._setup(tmp_path, {"slack": channels}):
            # Only one channel name starts with "engineering".
            assert resolve_channel_name("slack", "engineering") == "C01"

    def test_prefix_match_ambiguous_returns_none(self, tmp_path):
        channels = [
            {"id": "C01", "name": "eng-backend", "type": "channel"},
            {"id": "C02", "name": "eng-frontend", "type": "channel"},
        ]
        with self._setup(tmp_path, {"slack": channels}):
            assert resolve_channel_name("slack", "eng") is None

    def test_no_channels_returns_none(self, tmp_path):
        with self._setup(tmp_path, {}):
            assert resolve_channel_name("telegram", "someone") is None

    def test_no_match_returns_none(self, tmp_path):
        channels = [{"id": "123", "name": "John", "type": "dm"}]
        with self._setup(tmp_path, {"telegram": channels}):
            assert resolve_channel_name("telegram", "nonexistent") is None

    def test_topic_name_resolves_to_composite_id(self, tmp_path):
        topic = {
            "id": "-1001:17585",
            "name": "Coaching Chat / topic 17585",
            "type": "group",
        }
        with self._setup(tmp_path, {"telegram": [topic]}):
            resolved = resolve_channel_name("telegram", "Coaching Chat / topic 17585")
            assert resolved == "-1001:17585"
class TestBuildFromSessions:
    """_build_from_sessions() reading ``$HERMES_HOME/sessions/sessions.json``."""

    def _write_sessions(self, tmp_path, sessions_data):
        """Write sessions.json at the path _build_from_sessions expects."""
        sessions_dir = tmp_path / "sessions"
        sessions_dir.mkdir(parents=True)
        (sessions_dir / "sessions.json").write_text(json.dumps(sessions_data))

    def test_builds_from_sessions_json(self, tmp_path):
        alice = {
            "origin": {
                "platform": "telegram",
                "chat_id": "12345",
                "chat_name": "Alice",
            },
            "chat_type": "dm",
        }
        # Bob has only a user_name, no chat_name.
        bob = {
            "origin": {
                "platform": "telegram",
                "chat_id": "67890",
                "user_name": "Bob",
            },
            "chat_type": "group",
        }
        # Different platform; must not show up in the telegram result.
        discord_only = {
            "origin": {
                "platform": "discord",
                "chat_id": "99999",
            },
        }
        self._write_sessions(
            tmp_path,
            {"session_1": alice, "session_2": bob, "session_3": discord_only},
        )

        with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
            entries = _build_from_sessions("telegram")

        assert len(entries) == 2
        found_names = {item["name"] for item in entries}
        assert "Alice" in found_names
        assert "Bob" in found_names

    def test_missing_sessions_file(self, tmp_path):
        with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
            entries = _build_from_sessions("telegram")
        assert entries == []

    def test_deduplication_by_chat_id(self, tmp_path):
        self._write_sessions(
            tmp_path,
            {
                "s1": {"origin": {"platform": "telegram", "chat_id": "123", "chat_name": "X"}},
                "s2": {"origin": {"platform": "telegram", "chat_id": "123", "chat_name": "X"}},
            },
        )
        with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
            entries = _build_from_sessions("telegram")
        assert len(entries) == 1

    def test_keeps_distinct_topics_with_same_chat_id(self, tmp_path):
        group_root = {
            "origin": {"platform": "telegram", "chat_id": "-1001", "chat_name": "Coaching Chat"},
            "chat_type": "group",
        }
        topic_a = {
            "origin": {
                "platform": "telegram",
                "chat_id": "-1001",
                "chat_name": "Coaching Chat",
                "thread_id": "17585",
            },
            "chat_type": "group",
        }
        topic_b = {
            "origin": {
                "platform": "telegram",
                "chat_id": "-1001",
                "chat_name": "Coaching Chat",
                "thread_id": "17587",
            },
            "chat_type": "group",
        }
        self._write_sessions(
            tmp_path,
            {"group_root": group_root, "topic_a": topic_a, "topic_b": topic_b},
        )

        with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
            entries = _build_from_sessions("telegram")

        found_ids = {item["id"] for item in entries}
        found_names = {item["name"] for item in entries}
        # The group root and each topic get their own entry.
        assert found_ids == {"-1001", "-1001:17585", "-1001:17587"}
        assert "Coaching Chat" in found_names
        assert "Coaching Chat / topic 17585" in found_names
        assert "Coaching Chat / topic 17587" in found_names
class TestFormatDirectoryForDisplay:
    """format_directory_for_display() text output for several directories."""

    def test_empty_directory(self, tmp_path):
        absent = tmp_path / "nope.json"
        with patch("gateway.channel_directory.DIRECTORY_PATH", absent):
            rendered = format_directory_for_display()
        assert "No messaging platforms" in rendered

    def test_telegram_display(self, tmp_path):
        telegram_entries = [
            {"id": "123", "name": "Alice", "type": "dm"},
            {"id": "456", "name": "Dev Group", "type": "group"},
            {"id": "-1001:17585", "name": "Coaching Chat / topic 17585", "type": "group"},
        ]
        directory_file = _write_directory(tmp_path, {"telegram": telegram_entries})
        with patch("gateway.channel_directory.DIRECTORY_PATH", directory_file):
            rendered = format_directory_for_display()
        assert "Telegram:" in rendered
        assert "telegram:Alice" in rendered
        assert "telegram:Dev Group" in rendered
        assert "telegram:Coaching Chat / topic 17585" in rendered

    def test_discord_grouped_by_guild(self, tmp_path):
        discord_entries = [
            {"id": "1", "name": "general", "guild": "Server1", "type": "channel"},
            {"id": "2", "name": "bot-home", "guild": "Server1", "type": "channel"},
            {"id": "3", "name": "chat", "guild": "Server2", "type": "channel"},
        ]
        directory_file = _write_directory(tmp_path, {"discord": discord_entries})
        with patch("gateway.channel_directory.DIRECTORY_PATH", directory_file):
            rendered = format_directory_for_display()
        # One heading per guild; channel names are shown with a "#" prefix.
        assert "Discord (Server1):" in rendered
        assert "Discord (Server2):" in rendered
        assert "discord:#general" in rendered

View file

@ -0,0 +1,194 @@
"""Tests for gateway configuration management."""
from gateway.config import (
GatewayConfig,
HomeChannel,
Platform,
PlatformConfig,
SessionResetPolicy,
load_gateway_config,
)
class TestHomeChannelRoundtrip:
    """HomeChannel survives a to_dict() / from_dict() round trip."""

    def test_to_dict_from_dict(self):
        original = HomeChannel(platform=Platform.DISCORD, chat_id="999", name="general")
        restored = HomeChannel.from_dict(original.to_dict())
        assert restored.platform == Platform.DISCORD
        assert restored.chat_id == "999"
        assert restored.name == "general"
class TestPlatformConfigRoundtrip:
    """PlatformConfig survives a to_dict() / from_dict() round trip."""

    def test_to_dict_from_dict(self):
        home = HomeChannel(
            platform=Platform.TELEGRAM,
            chat_id="555",
            name="Home",
        )
        original = PlatformConfig(
            enabled=True,
            token="tok_123",
            home_channel=home,
            extra={"foo": "bar"},
        )
        restored = PlatformConfig.from_dict(original.to_dict())
        assert restored.enabled is True
        assert restored.token == "tok_123"
        assert restored.home_channel.chat_id == "555"
        assert restored.extra == {"foo": "bar"}

    def test_disabled_no_token(self):
        # A default-constructed config round-trips as disabled with no token.
        restored = PlatformConfig.from_dict(PlatformConfig().to_dict())
        assert restored.enabled is False
        assert restored.token is None
class TestGetConnectedPlatforms:
    """GatewayConfig.get_connected_platforms() filtering."""

    def test_returns_enabled_with_token(self):
        platform_table = {
            Platform.TELEGRAM: PlatformConfig(enabled=True, token="t"),
            Platform.DISCORD: PlatformConfig(enabled=False, token="d"),
            Platform.SLACK: PlatformConfig(enabled=True),  # no token
        }
        connected = GatewayConfig(platforms=platform_table).get_connected_platforms()
        # Only the platform that is both enabled and has a token is reported.
        assert Platform.TELEGRAM in connected
        assert Platform.DISCORD not in connected
        assert Platform.SLACK not in connected

    def test_empty_platforms(self):
        assert GatewayConfig().get_connected_platforms() == []
class TestSessionResetPolicy:
    """SessionResetPolicy serialisation and default values."""

    def test_roundtrip(self):
        original = SessionResetPolicy(mode="idle", at_hour=6, idle_minutes=120)
        restored = SessionResetPolicy.from_dict(original.to_dict())
        assert restored.mode == "idle"
        assert restored.at_hour == 6
        assert restored.idle_minutes == 120

    def test_defaults(self):
        fresh = SessionResetPolicy()
        assert fresh.mode == "both"
        assert fresh.at_hour == 4
        assert fresh.idle_minutes == 1440

    def test_from_dict_treats_null_values_as_defaults(self):
        # Explicit nulls must yield the same values as test_defaults.
        all_null = {"mode": None, "at_hour": None, "idle_minutes": None}
        restored = SessionResetPolicy.from_dict(all_null)
        assert restored.mode == "both"
        assert restored.at_hour == 4
        assert restored.idle_minutes == 1440
class TestGatewayConfigRoundtrip:
    """GatewayConfig survives a to_dict() / from_dict() round trip."""

    def test_full_roundtrip(self):
        telegram = PlatformConfig(
            enabled=True,
            token="tok_123",
            home_channel=HomeChannel(Platform.TELEGRAM, "123", "Home"),
        )
        original = GatewayConfig(
            platforms={Platform.TELEGRAM: telegram},
            reset_triggers=["/new"],
            quick_commands={"limits": {"type": "exec", "command": "echo ok"}},
            group_sessions_per_user=False,
        )
        restored = GatewayConfig.from_dict(original.to_dict())
        assert Platform.TELEGRAM in restored.platforms
        assert restored.platforms[Platform.TELEGRAM].token == "tok_123"
        assert restored.reset_triggers == ["/new"]
        assert restored.quick_commands == {"limits": {"type": "exec", "command": "echo ok"}}
        assert restored.group_sessions_per_user is False

    def test_roundtrip_preserves_unauthorized_dm_behavior(self):
        # The global value and the per-platform value deliberately differ.
        whatsapp = PlatformConfig(
            enabled=True,
            extra={"unauthorized_dm_behavior": "pair"},
        )
        original = GatewayConfig(
            unauthorized_dm_behavior="ignore",
            platforms={Platform.WHATSAPP: whatsapp},
        )
        restored = GatewayConfig.from_dict(original.to_dict())
        assert restored.unauthorized_dm_behavior == "ignore"
        assert restored.platforms[Platform.WHATSAPP].extra["unauthorized_dm_behavior"] == "pair"
class TestLoadGatewayConfig:
    """load_gateway_config() picking up settings from ``$HERMES_HOME/config.yaml``."""

    @staticmethod
    def _install_config(tmp_path, monkeypatch, yaml_text):
        """Create ``<tmp_path>/.hermes/config.yaml`` and point HERMES_HOME at it.

        Args:
            tmp_path: pytest temporary directory for this test.
            monkeypatch: pytest fixture used to set HERMES_HOME.
            yaml_text: exact file content to write (UTF-8).
        """
        hermes_home = tmp_path / ".hermes"
        hermes_home.mkdir()
        (hermes_home / "config.yaml").write_text(yaml_text, encoding="utf-8")
        monkeypatch.setenv("HERMES_HOME", str(hermes_home))

    def test_bridges_quick_commands_from_config_yaml(self, tmp_path, monkeypatch):
        # YAML nesting is indentation-based: `type`/`command` must be indented
        # deeper than `limits` to become its children, which is what the
        # assertion below expects.
        self._install_config(
            tmp_path,
            monkeypatch,
            "quick_commands:\n"
            "  limits:\n"
            "    type: exec\n"
            "    command: echo ok\n",
        )
        config = load_gateway_config()
        assert config.quick_commands == {"limits": {"type": "exec", "command": "echo ok"}}

    def test_bridges_group_sessions_per_user_from_config_yaml(self, tmp_path, monkeypatch):
        self._install_config(tmp_path, monkeypatch, "group_sessions_per_user: false\n")
        config = load_gateway_config()
        assert config.group_sessions_per_user is False

    def test_invalid_quick_commands_in_config_yaml_are_ignored(self, tmp_path, monkeypatch):
        # A scalar where a mapping is required must yield an empty mapping.
        self._install_config(tmp_path, monkeypatch, "quick_commands: not-a-mapping\n")
        config = load_gateway_config()
        assert config.quick_commands == {}

    def test_bridges_unauthorized_dm_behavior_from_config_yaml(self, tmp_path, monkeypatch):
        # Top-level value and the whatsapp-scoped value deliberately differ;
        # the nested key is indented so it belongs to the `whatsapp` mapping.
        self._install_config(
            tmp_path,
            monkeypatch,
            "unauthorized_dm_behavior: ignore\n"
            "whatsapp:\n"
            "  unauthorized_dm_behavior: pair\n",
        )
        config = load_gateway_config()
        assert config.unauthorized_dm_behavior == "ignore"
        assert config.platforms[Platform.WHATSAPP].extra["unauthorized_dm_behavior"] == "pair"

View file

@ -0,0 +1,148 @@
"""Tests for the config.yaml → env var bridge logic in gateway/run.py.
Specifically tests that top-level `cwd:` and `backend:` in config.yaml
are correctly bridged to TERMINAL_CWD / TERMINAL_ENV env vars as
convenience aliases for `terminal.cwd` / `terminal.backend`.
The bridge logic is module-level code in gateway/run.py, so we test
the semantics by reimplementing the relevant config bridge snippet and
asserting the expected env var outcomes.
"""
import os
import json
import pytest
def _simulate_config_bridge(cfg: dict, initial_env: dict | None = None):
"""Simulate the gateway config bridge logic from gateway/run.py.
Returns the resulting env dict (only TERMINAL_* and MESSAGING_CWD keys).
"""
env = dict(initial_env or {})
# --- Replicate lines 54-56: generic top-level bridge (for context) ---
for key, val in cfg.items():
if isinstance(val, (str, int, float, bool)) and key not in env:
env[key] = str(val)
# --- Replicate lines 59-87: terminal config bridge ---
terminal_cfg = cfg.get("terminal", {})
if terminal_cfg and isinstance(terminal_cfg, dict):
terminal_env_map = {
"backend": "TERMINAL_ENV",
"cwd": "TERMINAL_CWD",
"timeout": "TERMINAL_TIMEOUT",
}
for cfg_key, env_var in terminal_env_map.items():
if cfg_key in terminal_cfg:
val = terminal_cfg[cfg_key]
if isinstance(val, list):
env[env_var] = json.dumps(val)
else:
env[env_var] = str(val)
# --- NEW: top-level aliases (the fix being tested) ---
top_level_aliases = {
"cwd": "TERMINAL_CWD",
"backend": "TERMINAL_ENV",
}
for alias_key, alias_env in top_level_aliases.items():
if alias_env not in env:
alias_val = cfg.get(alias_key)
if isinstance(alias_val, str) and alias_val.strip():
env[alias_env] = alias_val.strip()
# --- Replicate lines 144-147: MESSAGING_CWD fallback ---
configured_cwd = env.get("TERMINAL_CWD", "")
if not configured_cwd or configured_cwd in (".", "auto", "cwd"):
messaging_cwd = env.get("MESSAGING_CWD") or "/root" # Path.home() for root
env["TERMINAL_CWD"] = messaging_cwd
return env
class TestTopLevelCwdAlias:
    """Top-level `cwd:` should be treated as `terminal.cwd`."""

    def test_top_level_cwd_sets_terminal_cwd(self):
        env = _simulate_config_bridge({"cwd": "/home/hermes/projects"})
        assert env["TERMINAL_CWD"] == "/home/hermes/projects"

    def test_top_level_backend_sets_terminal_env(self):
        env = _simulate_config_bridge({"backend": "docker"})
        assert env["TERMINAL_ENV"] == "docker"

    def test_top_level_cwd_and_backend(self):
        env = _simulate_config_bridge(
            {"backend": "local", "cwd": "/home/hermes/projects"}
        )
        assert env["TERMINAL_CWD"] == "/home/hermes/projects"
        assert env["TERMINAL_ENV"] == "local"

    def test_nested_terminal_takes_precedence_over_top_level(self):
        """terminal.cwd should win over top-level cwd."""
        env = _simulate_config_bridge(
            {
                "cwd": "/should/not/use",
                "terminal": {"cwd": "/home/hermes/real"},
            }
        )
        assert env["TERMINAL_CWD"] == "/home/hermes/real"

    def test_nested_terminal_backend_takes_precedence(self):
        env = _simulate_config_bridge(
            {
                "backend": "should-not-use",
                "terminal": {"backend": "docker"},
            }
        )
        assert env["TERMINAL_ENV"] == "docker"

    def test_no_cwd_falls_back_to_messaging_cwd(self):
        env = _simulate_config_bridge({}, {"MESSAGING_CWD": "/home/hermes/projects"})
        assert env["TERMINAL_CWD"] == "/home/hermes/projects"

    def test_no_cwd_no_messaging_cwd_falls_back_to_home(self):
        env = _simulate_config_bridge({})
        assert env["TERMINAL_CWD"] == "/root"  # Path.home() for root user

    def test_dot_cwd_triggers_messaging_fallback(self):
        """cwd: '.' should trigger MESSAGING_CWD fallback."""
        env = _simulate_config_bridge({"cwd": "."}, {"MESSAGING_CWD": "/home/hermes"})
        # The alias step first stores "." as TERMINAL_CWD; because "." is one
        # of the placeholder values (".", "auto", "cwd"), the fallback step
        # then replaces it with MESSAGING_CWD.
        assert env["TERMINAL_CWD"] == "/home/hermes"

    def test_auto_cwd_triggers_messaging_fallback(self):
        env = _simulate_config_bridge({"cwd": "auto"}, {"MESSAGING_CWD": "/home/hermes"})
        assert env["TERMINAL_CWD"] == "/home/hermes"

    def test_empty_cwd_ignored(self):
        env = _simulate_config_bridge({"cwd": ""}, {"MESSAGING_CWD": "/home/hermes"})
        assert env["TERMINAL_CWD"] == "/home/hermes"

    def test_whitespace_only_cwd_ignored(self):
        env = _simulate_config_bridge({"cwd": " "}, {"MESSAGING_CWD": "/fallback"})
        assert env["TERMINAL_CWD"] == "/fallback"

    def test_messaging_cwd_env_var_works(self):
        """MESSAGING_CWD in initial env should be picked up as fallback."""
        env = _simulate_config_bridge({}, {"MESSAGING_CWD": "/home/hermes/projects"})
        assert env["TERMINAL_CWD"] == "/home/hermes/projects"

    def test_top_level_cwd_beats_messaging_cwd(self):
        """Explicit top-level cwd should take precedence over MESSAGING_CWD."""
        env = _simulate_config_bridge({"cwd": "/from/config"}, {"MESSAGING_CWD": "/from/env"})
        assert env["TERMINAL_CWD"] == "/from/config"

View file

@ -0,0 +1,96 @@
"""Tests for the delivery routing module."""
from gateway.config import Platform, GatewayConfig, PlatformConfig, HomeChannel
from gateway.delivery import DeliveryRouter, DeliveryTarget, parse_deliver_spec
from gateway.session import SessionSource
class TestParseTargetPlatformChat:
    """DeliveryTarget.parse() for each supported spec shape."""

    def test_explicit_telegram_chat(self):
        parsed = DeliveryTarget.parse("telegram:12345")
        assert parsed.platform == Platform.TELEGRAM
        assert parsed.chat_id == "12345"
        assert parsed.is_explicit is True

    def test_platform_only_no_chat_id(self):
        parsed = DeliveryTarget.parse("discord")
        assert parsed.platform == Platform.DISCORD
        assert parsed.chat_id is None
        assert parsed.is_explicit is False

    def test_local_target(self):
        parsed = DeliveryTarget.parse("local")
        assert parsed.platform == Platform.LOCAL
        assert parsed.chat_id is None

    def test_origin_with_source(self):
        source = SessionSource(platform=Platform.TELEGRAM, chat_id="789", thread_id="42")
        parsed = DeliveryTarget.parse("origin", origin=source)
        # Platform, chat and thread are all taken from the session source.
        assert parsed.platform == Platform.TELEGRAM
        assert parsed.chat_id == "789"
        assert parsed.thread_id == "42"
        assert parsed.is_origin is True

    def test_origin_without_source(self):
        parsed = DeliveryTarget.parse("origin")
        assert parsed.platform == Platform.LOCAL
        assert parsed.is_origin is True

    def test_unknown_platform(self):
        parsed = DeliveryTarget.parse("unknown_platform")
        assert parsed.platform == Platform.LOCAL
class TestParseDeliverSpec:
    """parse_deliver_spec() defaulting and pass-through behaviour."""

    def test_none_returns_default(self):
        assert parse_deliver_spec(None) == "origin"

    def test_empty_string_returns_default(self):
        assert parse_deliver_spec("") == "origin"

    def test_custom_default(self):
        assert parse_deliver_spec(None, default="local") == "local"

    def test_passthrough_string(self):
        assert parse_deliver_spec("telegram") == "telegram"

    def test_passthrough_list(self):
        spec = parse_deliver_spec(["local", "telegram"])
        assert spec == ["local", "telegram"]
class TestTargetToStringRoundtrip:
    """DeliveryTarget.to_string() reproduces the spec that was parsed."""

    def test_origin_roundtrip(self):
        source = SessionSource(platform=Platform.TELEGRAM, chat_id="111", thread_id="42")
        parsed = DeliveryTarget.parse("origin", origin=source)
        assert parsed.to_string() == "origin"

    def test_local_roundtrip(self):
        assert DeliveryTarget.parse("local").to_string() == "local"

    def test_platform_only_roundtrip(self):
        assert DeliveryTarget.parse("discord").to_string() == "discord"

    def test_explicit_chat_roundtrip(self):
        rendered = DeliveryTarget.parse("telegram:999").to_string()
        assert rendered == "telegram:999"
        # Parsing the rendered string again yields the same target fields.
        reparsed = DeliveryTarget.parse(rendered)
        assert reparsed.platform == Platform.TELEGRAM
        assert reparsed.chat_id == "999"
class TestDeliveryRouter:
    """DeliveryRouter.resolve_targets() behaviour."""

    def test_resolve_targets_does_not_duplicate_local_when_explicit(self):
        # always_log_local is on AND "local" is requested explicitly; the
        # resolved list must still contain LOCAL exactly once.
        router = DeliveryRouter(GatewayConfig(always_log_local=True))
        resolved_platforms = [t.platform for t in router.resolve_targets(["local"])]
        assert resolved_platforms == [Platform.LOCAL]

View file

@ -0,0 +1,274 @@
"""Tests for DingTalk platform adapter."""
import asyncio
import json
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch, PropertyMock
import pytest
from gateway.config import Platform, PlatformConfig
# ---------------------------------------------------------------------------
# Requirements check
# ---------------------------------------------------------------------------
class TestDingTalkRequirements:
    """check_dingtalk_requirements() under different availability flags and env vars."""

    @staticmethod
    def _mark_libraries_available(monkeypatch):
        """Force both availability flags in the dingtalk module to True."""
        monkeypatch.setattr(
            "gateway.platforms.dingtalk.DINGTALK_STREAM_AVAILABLE", True
        )
        monkeypatch.setattr("gateway.platforms.dingtalk.HTTPX_AVAILABLE", True)

    def test_returns_false_when_sdk_missing(self, monkeypatch):
        # A None entry in sys.modules makes `import dingtalk_stream` fail.
        with patch.dict("sys.modules", {"dingtalk_stream": None}):
            monkeypatch.setattr(
                "gateway.platforms.dingtalk.DINGTALK_STREAM_AVAILABLE", False
            )
            from gateway.platforms.dingtalk import check_dingtalk_requirements

            assert check_dingtalk_requirements() is False

    def test_returns_false_when_env_vars_missing(self, monkeypatch):
        self._mark_libraries_available(monkeypatch)
        for variable in ("DINGTALK_CLIENT_ID", "DINGTALK_CLIENT_SECRET"):
            monkeypatch.delenv(variable, raising=False)
        from gateway.platforms.dingtalk import check_dingtalk_requirements

        assert check_dingtalk_requirements() is False

    def test_returns_true_when_all_available(self, monkeypatch):
        self._mark_libraries_available(monkeypatch)
        monkeypatch.setenv("DINGTALK_CLIENT_ID", "test-id")
        monkeypatch.setenv("DINGTALK_CLIENT_SECRET", "test-secret")
        from gateway.platforms.dingtalk import check_dingtalk_requirements

        assert check_dingtalk_requirements() is True
# ---------------------------------------------------------------------------
# Adapter construction
# ---------------------------------------------------------------------------
class TestDingTalkAdapterInit:
    """Where DingTalkAdapter takes its credentials from at construction time."""

    def test_reads_config_from_extra(self):
        from gateway.platforms.dingtalk import DingTalkAdapter

        credentials = {"client_id": "cfg-id", "client_secret": "cfg-secret"}
        adapter = DingTalkAdapter(PlatformConfig(enabled=True, extra=credentials))
        assert adapter._client_id == "cfg-id"
        assert adapter._client_secret == "cfg-secret"
        assert adapter.name == "Dingtalk"  # base class uses .title()

    def test_falls_back_to_env_vars(self, monkeypatch):
        monkeypatch.setenv("DINGTALK_CLIENT_ID", "env-id")
        monkeypatch.setenv("DINGTALK_CLIENT_SECRET", "env-secret")
        from gateway.platforms.dingtalk import DingTalkAdapter

        # No `extra` given, so the values must come from the environment.
        adapter = DingTalkAdapter(PlatformConfig(enabled=True))
        assert adapter._client_id == "env-id"
        assert adapter._client_secret == "env-secret"
# ---------------------------------------------------------------------------
# Message text extraction
# ---------------------------------------------------------------------------
class TestExtractText:
    """DingTalkAdapter._extract_text() over the message shapes it accepts."""

    @staticmethod
    def _message(text, rich_text):
        """Build a mock message exposing only ``text`` and ``rich_text``."""
        message = MagicMock()
        message.text = text
        message.rich_text = rich_text
        return message

    def test_extracts_dict_text(self):
        from gateway.platforms.dingtalk import DingTalkAdapter

        # Surrounding whitespace in the content is stripped.
        message = self._message({"content": " hello world "}, None)
        assert DingTalkAdapter._extract_text(message) == "hello world"

    def test_extracts_string_text(self):
        from gateway.platforms.dingtalk import DingTalkAdapter

        message = self._message("plain text", None)
        assert DingTalkAdapter._extract_text(message) == "plain text"

    def test_falls_back_to_rich_text(self):
        from gateway.platforms.dingtalk import DingTalkAdapter

        # The non-text (image) part does not appear in the result.
        parts = [{"text": "part1"}, {"text": "part2"}, {"image": "url"}]
        message = self._message("", parts)
        assert DingTalkAdapter._extract_text(message) == "part1 part2"

    def test_returns_empty_for_no_content(self):
        from gateway.platforms.dingtalk import DingTalkAdapter

        message = self._message("", None)
        assert DingTalkAdapter._extract_text(message) == ""
# ---------------------------------------------------------------------------
# Deduplication
# ---------------------------------------------------------------------------
class TestDeduplication:
    """DingTalkAdapter._is_duplicate() bookkeeping of seen message ids."""

    @staticmethod
    def _new_adapter():
        """Return a fresh adapter with an enabled, otherwise default config."""
        from gateway.platforms.dingtalk import DingTalkAdapter

        return DingTalkAdapter(PlatformConfig(enabled=True))

    def test_first_message_not_duplicate(self):
        adapter = self._new_adapter()
        assert adapter._is_duplicate("msg-1") is False

    def test_second_same_message_is_duplicate(self):
        adapter = self._new_adapter()
        adapter._is_duplicate("msg-1")
        assert adapter._is_duplicate("msg-1") is True

    def test_different_messages_not_duplicate(self):
        adapter = self._new_adapter()
        adapter._is_duplicate("msg-1")
        assert adapter._is_duplicate("msg-2") is False

    def test_cache_cleanup_on_overflow(self):
        from gateway.platforms.dingtalk import DEDUP_MAX_SIZE

        adapter = self._new_adapter()
        # Fill beyond max
        inserted = DEDUP_MAX_SIZE + 10
        for index in range(inserted):
            adapter._is_duplicate(f"msg-{index}")
        # Cache should have been pruned
        # NOTE(review): the bound equals the number of ids inserted, so this
        # assertion cannot fail — confirm the intended post-prune size.
        assert len(adapter._seen_messages) <= DEDUP_MAX_SIZE + 10
# ---------------------------------------------------------------------------
# Send
# ---------------------------------------------------------------------------
class TestSend:
    """DingTalkAdapter.send() against a mocked HTTP client."""

    @staticmethod
    def _adapter_with_response(status_code, text=None):
        """Return ``(adapter, client)`` where ``client.post`` yields a canned response.

        ``text`` is only set on the response when given, so callers that
        leave it out get the mock's default attribute.
        """
        from gateway.platforms.dingtalk import DingTalkAdapter

        adapter = DingTalkAdapter(PlatformConfig(enabled=True))
        response = MagicMock()
        response.status_code = status_code
        if text is not None:
            response.text = text
        client = AsyncMock()
        client.post = AsyncMock(return_value=response)
        adapter._http_client = client
        return adapter, client

    @pytest.mark.asyncio
    async def test_send_posts_to_webhook(self):
        adapter, client = self._adapter_with_response(200, "OK")

        result = await adapter.send(
            "chat-123", "Hello!",
            metadata={"session_webhook": "https://dingtalk.example/webhook"}
        )

        assert result.success is True
        client.post.assert_called_once()
        recorded = client.post.call_args
        # First positional argument is the URL; the body is passed as `json=`.
        assert recorded.args[0] == "https://dingtalk.example/webhook"
        body = recorded.kwargs["json"]
        assert body["msgtype"] == "markdown"
        assert body["markdown"]["title"] == "Hermes"
        assert body["markdown"]["text"] == "Hello!"

    @pytest.mark.asyncio
    async def test_send_fails_without_webhook(self):
        from gateway.platforms.dingtalk import DingTalkAdapter

        adapter = DingTalkAdapter(PlatformConfig(enabled=True))
        adapter._http_client = AsyncMock()

        # Neither metadata nor a cached webhook is available for this chat.
        result = await adapter.send("chat-123", "Hello!")
        assert result.success is False
        assert "session_webhook" in result.error

    @pytest.mark.asyncio
    async def test_send_uses_cached_webhook(self):
        adapter, client = self._adapter_with_response(200)
        adapter._session_webhooks["chat-123"] = "https://cached.example/webhook"

        result = await adapter.send("chat-123", "Hello!")
        assert result.success is True
        assert client.post.call_args.args[0] == "https://cached.example/webhook"

    @pytest.mark.asyncio
    async def test_send_handles_http_error(self):
        adapter, _client = self._adapter_with_response(400, "Bad Request")

        result = await adapter.send(
            "chat-123", "Hello!",
            metadata={"session_webhook": "https://example/webhook"}
        )
        assert result.success is False
        assert "400" in result.error
# ---------------------------------------------------------------------------
# Connect / disconnect
# ---------------------------------------------------------------------------
class TestConnect:
    """DingTalkAdapter.connect() failure paths and disconnect() cleanup."""

    @pytest.mark.asyncio
    async def test_connect_fails_without_sdk(self, monkeypatch):
        monkeypatch.setattr(
            "gateway.platforms.dingtalk.DINGTALK_STREAM_AVAILABLE", False
        )
        from gateway.platforms.dingtalk import DingTalkAdapter

        adapter = DingTalkAdapter(PlatformConfig(enabled=True))
        connected = await adapter.connect()
        assert connected is False

    @pytest.mark.asyncio
    async def test_connect_fails_without_credentials(self):
        from gateway.platforms.dingtalk import DingTalkAdapter

        adapter = DingTalkAdapter(PlatformConfig(enabled=True))
        # Blank out both credentials after construction.
        adapter._client_id = ""
        adapter._client_secret = ""
        connected = await adapter.connect()
        assert connected is False

    @pytest.mark.asyncio
    async def test_disconnect_cleans_up(self):
        from gateway.platforms.dingtalk import DingTalkAdapter

        adapter = DingTalkAdapter(PlatformConfig(enabled=True))
        # Seed every piece of state that disconnect() is expected to reset.
        adapter._session_webhooks["a"] = "http://x"
        adapter._seen_messages["b"] = 1.0
        adapter._http_client = AsyncMock()
        adapter._stream_task = None

        await adapter.disconnect()

        assert len(adapter._session_webhooks) == 0
        assert len(adapter._seen_messages) == 0
        assert adapter._http_client is None
# ---------------------------------------------------------------------------
# Platform enum
# ---------------------------------------------------------------------------
class TestPlatformEnum:
    """The Platform enum exposes a DINGTALK member."""

    def test_dingtalk_in_platform_enum(self):
        member = Platform.DINGTALK
        assert member.value == "dingtalk"

View file

@ -0,0 +1,117 @@
"""Tests for Discord bot message filtering (DISCORD_ALLOW_BOTS)."""
import asyncio
import os
import unittest
from unittest.mock import AsyncMock, MagicMock, patch
def _make_author(*, bot: bool = False, is_self: bool = False):
    """Create a mock Discord author.

    ``bot`` sets the ``bot`` flag and selects the name ("TestBot" or
    "TestUser"); ``is_self`` selects the id (99999 or 12345).
    """
    if bot:
        label = "TestBot"
    else:
        label = "TestUser"

    author = MagicMock()
    author.bot = bot
    author.id = 99999 if is_self else 12345
    author.name = label
    author.display_name = label
    return author
def _make_message(*, author=None, content="hello", mentions=None, is_dm=False):
    """Create a mock Discord message.

    With ``is_dm`` the channel is a mock spec'd on ``discord.DMChannel``
    (id 111); otherwise it is a plain mock guild channel (id 222).
    """
    message = MagicMock()
    message.author = author or _make_author()
    message.content = content
    message.attachments = []
    message.mentions = mentions or []

    if is_dm:
        # Imported lazily: discord is only needed for the DM variant.
        import discord

        message.channel = MagicMock(spec=discord.DMChannel)
        message.channel.id = 111
        return message

    channel = message.channel = MagicMock()
    channel.id = 222
    channel.name = "test-channel"
    channel.guild = MagicMock()
    channel.guild.name = "TestServer"
    # Make isinstance checks fail for DMChannel and Thread
    type(channel).__name__ = "TextChannel"
    return message
class TestDiscordBotFilter(unittest.TestCase):
    """Test the DISCORD_ALLOW_BOTS filtering logic."""

    def _run_filter(self, message, allow_bots="none", client_user=None):
        """Simulate the on_message filter logic and return whether message was accepted.

        Args:
            message: object exposing ``author`` and ``mentions``.
            allow_bots: policy string; compared after lower-casing and
                stripping. "none" rejects bot authors, "mentions" accepts
                them only when ``client_user`` is mentioned, and any other
                value accepts them.
            client_user: the bot's own user, or None.

        Returns:
            True when the message would be processed, False when dropped.
        """
        # Replicate the exact filter logic from discord.py on_message
        if message.author == client_user:
            return False  # own messages always ignored
        if not getattr(message.author, "bot", False):
            return True  # non-bot authors are never filtered here
        policy = allow_bots.lower().strip()
        if policy == "none":
            return False
        if policy == "mentions" and (
            not client_user or client_user not in message.mentions
        ):
            return False
        # "all" (and a satisfied "mentions") falls through
        return True  # message accepted

    def test_own_messages_always_ignored(self):
        """Bot's own messages are always ignored regardless of allow_bots."""
        bot_user = _make_author(is_self=True)
        msg = _make_message(author=bot_user)
        self.assertFalse(self._run_filter(msg, "all", bot_user))

    def test_human_messages_always_accepted(self):
        """Human messages are always accepted regardless of allow_bots."""
        human = _make_author(bot=False)
        msg = _make_message(author=human)
        self.assertTrue(self._run_filter(msg, "none"))
        self.assertTrue(self._run_filter(msg, "mentions"))
        self.assertTrue(self._run_filter(msg, "all"))

    def test_allow_bots_none_rejects_bots(self):
        """With allow_bots=none, all other bot messages are rejected."""
        bot = _make_author(bot=True)
        msg = _make_message(author=bot)
        self.assertFalse(self._run_filter(msg, "none"))

    def test_allow_bots_all_accepts_bots(self):
        """With allow_bots=all, all bot messages are accepted."""
        bot = _make_author(bot=True)
        msg = _make_message(author=bot)
        self.assertTrue(self._run_filter(msg, "all"))

    def test_allow_bots_mentions_rejects_without_mention(self):
        """With allow_bots=mentions, bot messages without @mention are rejected."""
        our_user = _make_author(is_self=True)
        bot = _make_author(bot=True)
        msg = _make_message(author=bot, mentions=[])
        self.assertFalse(self._run_filter(msg, "mentions", our_user))

    def test_allow_bots_mentions_accepts_with_mention(self):
        """With allow_bots=mentions, bot messages with @mention are accepted."""
        our_user = _make_author(is_self=True)
        bot = _make_author(bot=True)
        msg = _make_message(author=bot, mentions=[our_user])
        self.assertTrue(self._run_filter(msg, "mentions", our_user))

    def test_default_is_none(self):
        """Default behavior (no env var) should be 'none'."""
        # Isolate from the runner's environment: patch.dict restores
        # os.environ on exit, so removing the variable here cannot leak and
        # an ambient DISCORD_ALLOW_BOTS can no longer fail this test.
        with patch.dict(os.environ):
            os.environ.pop("DISCORD_ALLOW_BOTS", None)
            default = os.getenv("DISCORD_ALLOW_BOTS", "none")
        self.assertEqual(default, "none")

    def test_case_insensitive(self):
        """Allow_bots value should be case-insensitive."""
        bot = _make_author(bot=True)
        msg = _make_message(author=bot)
        self.assertTrue(self._run_filter(msg, "ALL"))
        self.assertTrue(self._run_filter(msg, "All"))
        self.assertFalse(self._run_filter(msg, "NONE"))
        self.assertFalse(self._run_filter(msg, "None"))
# Allow running this test module directly as a script through unittest's runner.
if __name__ == "__main__":
    unittest.main()

View file

@ -0,0 +1,347 @@
"""Tests for Discord incoming document/file attachment handling.
Covers the document branch in DiscordAdapter._handle_message(): the `else`
clause of the attachment content-type loop, which was added to download,
cache, and optionally inject text from non-image/audio files.
"""
import os
import sys
from datetime import datetime, timezone
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from gateway.config import PlatformConfig
from gateway.platforms.base import MessageType
# ---------------------------------------------------------------------------
# Discord mock setup (copied from test_discord_free_response.py)
# ---------------------------------------------------------------------------
def _ensure_discord_mock():
    """Register stand-in ``discord`` modules in ``sys.modules`` when needed.

    Returns immediately when ``sys.modules`` already holds a ``discord``
    entry that has a ``__file__`` attribute. Otherwise builds MagicMock-based
    ``discord``, ``discord.ext`` and ``discord.ext.commands`` objects and
    registers each with ``setdefault``, so existing entries are kept.
    """
    loaded = sys.modules.get("discord")
    if loaded is not None and hasattr(loaded, "__file__"):
        return

    def _identity(fn):
        return fn

    def _any_args_decorator(*_args, **_kwargs):
        # Decorator factory that ignores its arguments and returns fn unchanged.
        return _identity

    def _kwargs_only_decorator(**_kwargs):
        # Same as above but, like the attributes it backs, keyword-only.
        return _identity

    fake_discord = MagicMock()
    fake_discord.Intents.default.return_value = MagicMock()
    for mock_backed in ("Client", "File", "Embed"):
        setattr(fake_discord, mock_backed, MagicMock)
    # Real (empty) classes so isinstance() checks against them are valid.
    for channel_kind in ("DMChannel", "Thread", "ForumChannel"):
        setattr(fake_discord, channel_kind, type(channel_kind, (), {}))
    fake_discord.Interaction = object
    fake_discord.ui = SimpleNamespace(View=object, button=_any_args_decorator, Button=object)
    fake_discord.ButtonStyle = SimpleNamespace(success=1, primary=2, danger=3, green=1, blurple=2, red=3)
    fake_discord.Color = SimpleNamespace(orange=lambda: 1, green=lambda: 2, blue=lambda: 3, red=lambda: 4)
    fake_discord.app_commands = SimpleNamespace(
        describe=_kwargs_only_decorator,
        choices=_kwargs_only_decorator,
        Choice=lambda **fields: SimpleNamespace(**fields),
    )

    fake_commands = MagicMock()
    fake_commands.Bot = MagicMock
    fake_ext = MagicMock()
    fake_ext.commands = fake_commands

    replacements = {
        "discord": fake_discord,
        "discord.ext": fake_ext,
        "discord.ext.commands": fake_commands,
    }
    for module_name, fake_module in replacements.items():
        sys.modules.setdefault(module_name, fake_module)
# Run before the adapter imports below so a `discord` entry (real or mocked)
# is already in sys.modules when they execute; presumably the adapter module
# imports `discord` at import time.
_ensure_discord_mock()
import gateway.platforms.discord as discord_platform # noqa: E402
from gateway.platforms.discord import DiscordAdapter # noqa: E402
# ---------------------------------------------------------------------------
# Fake channel / thread types
# ---------------------------------------------------------------------------
class FakeDMChannel:
    """Minimal DM-channel stand-in exposing only ``id`` and ``name``."""

    def __init__(self, channel_id: int = 1):
        # The name is fixed; only the id is configurable.
        self.id, self.name = channel_id, "dm"
class FakeThread:
    """Minimal thread stand-in with no parent and a fixed guild name."""

    def __init__(self, channel_id: int = 10):
        self.id, self.name = channel_id, "thread"
        # No parent channel is modelled for these threads.
        self.parent = None
        self.parent_id = None
        self.guild = SimpleNamespace(name="TestServer")
        self.topic = None
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def _redirect_cache(tmp_path, monkeypatch):
    """Redirect DOCUMENT_CACHE_DIR to a per-test directory under tmp_path.

    Applied automatically to every test in this module so cached documents
    never land in ~/.hermes.
    """
    per_test_cache = tmp_path / "doc_cache"
    monkeypatch.setattr("gateway.platforms.base.DOCUMENT_CACHE_DIR", per_test_cache)
@pytest.fixture
def adapter(monkeypatch):
    """Provide a DiscordAdapter with fake channel types and a mocked dispatcher.

    ``discord.DMChannel`` / ``discord.Thread`` on the adapter module are
    replaced with the local fakes, the client exposes only a user with id
    999, and ``handle_message`` is an AsyncMock so tests can inspect the
    event that would have been dispatched.
    """
    discord_ns = discord_platform.discord
    for attr_name, fake_cls in (("DMChannel", FakeDMChannel), ("Thread", FakeThread)):
        monkeypatch.setattr(discord_ns, attr_name, fake_cls, raising=False)

    discord_adapter = DiscordAdapter(PlatformConfig(enabled=True, token="fake-token"))
    bot_user = SimpleNamespace(id=999)
    discord_adapter._client = SimpleNamespace(user=bot_user)
    discord_adapter.handle_message = AsyncMock()
    return discord_adapter
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def make_attachment(
    *,
    filename: str,
    content_type: str,
    size: int = 1024,
    url: str = "https://cdn.discordapp.com/attachments/fake/file",
) -> SimpleNamespace:
    """Return a namespace carrying the given attachment fields.

    Args:
        filename: Stored as ``filename``.
        content_type: Stored as ``content_type``.
        size: Stored as ``size``; defaults to 1024.
        url: Stored as ``url``; defaults to a fake CDN address.
    """
    fields = {
        "filename": filename,
        "content_type": content_type,
        "size": size,
        "url": url,
    }
    return SimpleNamespace(**fields)
def make_message(attachments: list, content: str = "") -> SimpleNamespace:
    """Return a namespace shaped like a DM message carrying *attachments*.

    The message always has id 123, no mentions, no reference, a fresh
    ``FakeDMChannel`` as channel, a timezone-aware UTC timestamp and a fixed
    author named "Tester" with id 42.
    """
    message = SimpleNamespace()
    message.id = 123
    message.content = content
    message.attachments = attachments
    message.mentions = []
    message.reference = None
    message.created_at = datetime.now(timezone.utc)
    message.channel = FakeDMChannel()
    message.author = SimpleNamespace(id=42, display_name="Tester", name="Tester")
    return message
def _mock_aiohttp_download(raw_bytes: bytes):
    """Return a patcher that swaps ``aiohttp.ClientSession`` for a canned session.

    The fake session works as an async context manager, and its ``get()``
    returns a response (also an async context manager) with status 200 whose
    ``read()`` resolves to *raw_bytes*.
    """

    def _as_async_context(target):
        # Entering yields the object itself; returning False on exit means
        # exceptions raised inside the ``async with`` body propagate.
        target.__aenter__ = AsyncMock(return_value=target)
        target.__aexit__ = AsyncMock(return_value=False)
        return target

    fake_response = _as_async_context(AsyncMock())
    fake_response.status = 200
    fake_response.read = AsyncMock(return_value=raw_bytes)

    fake_session = _as_async_context(AsyncMock())
    # get() is a plain (non-async) call whose result is used with ``async with``.
    fake_session.get = MagicMock(return_value=fake_response)

    return patch("aiohttp.ClientSession", return_value=fake_session)
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestIncomingDocumentHandling:
    """Exercise DiscordAdapter._handle_message() with file attachments.

    Every test feeds a fake DM message to the adapter and inspects the event
    passed to the mocked ``adapter.handle_message``.
    """

    @pytest.mark.asyncio
    async def test_pdf_document_cached(self, adapter):
        """A PDF attachment should be downloaded, cached, typed as DOCUMENT."""
        pdf_bytes = b"%PDF-1.4 fake content"

        with _mock_aiohttp_download(pdf_bytes):
            msg = make_message([make_attachment(filename="report.pdf", content_type="application/pdf")])
            await adapter._handle_message(msg)

        event = adapter.handle_message.call_args[0][0]
        assert event.message_type == MessageType.DOCUMENT
        assert len(event.media_urls) == 1
        assert os.path.exists(event.media_urls[0])
        assert event.media_types == ["application/pdf"]
        assert "[Content of" not in (event.text or "")

    @pytest.mark.asyncio
    async def test_txt_content_injected(self, adapter):
        """.txt file under 100KB should have its content injected into event.text."""
        file_content = b"Hello from a text file"

        with _mock_aiohttp_download(file_content):
            msg = make_message(
                attachments=[make_attachment(filename="notes.txt", content_type="text/plain")],
                content="summarize this",
            )
            await adapter._handle_message(msg)

        event = adapter.handle_message.call_args[0][0]
        assert "[Content of notes.txt]:" in event.text
        assert "Hello from a text file" in event.text
        assert "summarize this" in event.text
        # injection prepended before caption
        assert event.text.index("[Content of") < event.text.index("summarize this")

    @pytest.mark.asyncio
    async def test_md_content_injected(self, adapter):
        """.md file under 100KB should have its content injected."""
        file_content = b"# Title\nSome markdown content"

        with _mock_aiohttp_download(file_content):
            msg = make_message(
                attachments=[make_attachment(filename="readme.md", content_type="text/markdown")],
                content="",
            )
            await adapter._handle_message(msg)

        event = adapter.handle_message.call_args[0][0]
        assert "[Content of readme.md]:" in event.text
        assert "# Title" in event.text

    @pytest.mark.asyncio
    async def test_oversized_document_skipped(self, adapter):
        """A document over 20MB should be skipped — media_urls stays empty."""
        msg = make_message([
            make_attachment(
                filename="huge.pdf",
                content_type="application/pdf",
                size=25 * 1024 * 1024,
            )
        ])
        await adapter._handle_message(msg)

        event = adapter.handle_message.call_args[0][0]
        assert event.media_urls == []
        # handler must still be called
        adapter.handle_message.assert_called_once()

    @pytest.mark.asyncio
    async def test_unsupported_type_skipped(self, adapter):
        """An unsupported file type (.zip) should be skipped silently."""
        msg = make_message([
            make_attachment(filename="archive.zip", content_type="application/zip")
        ])
        await adapter._handle_message(msg)

        event = adapter.handle_message.call_args[0][0]
        assert event.media_urls == []
        assert event.message_type == MessageType.TEXT

    @pytest.mark.asyncio
    async def test_download_error_handled(self, adapter):
        """If the HTTP download raises, the handler should not crash."""
        # Entering the response context raises, simulating a failed request.
        resp = AsyncMock()
        resp.__aenter__ = AsyncMock(side_effect=RuntimeError("connection reset"))
        resp.__aexit__ = AsyncMock(return_value=False)

        session = AsyncMock()
        session.get = MagicMock(return_value=resp)
        session.__aenter__ = AsyncMock(return_value=session)
        session.__aexit__ = AsyncMock(return_value=False)

        with patch("aiohttp.ClientSession", return_value=session):
            msg = make_message([
                make_attachment(filename="report.pdf", content_type="application/pdf")
            ])
            await adapter._handle_message(msg)

        # Must still deliver an event
        adapter.handle_message.assert_called_once()
        event = adapter.handle_message.call_args[0][0]
        assert event.media_urls == []

    @pytest.mark.asyncio
    async def test_large_txt_cached_not_injected(self, adapter):
        """.txt over 100KB should be cached but NOT injected into event.text."""
        large_content = b"x" * (200 * 1024)

        with _mock_aiohttp_download(large_content):
            msg = make_message(
                attachments=[make_attachment(filename="big.txt", content_type="text/plain", size=len(large_content))],
                content="",
            )
            await adapter._handle_message(msg)

        event = adapter.handle_message.call_args[0][0]
        assert len(event.media_urls) == 1
        assert os.path.exists(event.media_urls[0])
        assert "[Content of" not in (event.text or "")

    @pytest.mark.asyncio
    async def test_multiple_text_files_both_injected(self, adapter):
        """Two text file attachments should both be injected into event.text in order."""
        content1 = b"First file content"
        content2 = b"Second file content"

        def make_session(_responses):
            """Build a fake session whose successive get() calls cycle through *_responses*."""
            idx = 0

            class FakeSession:
                async def __aenter__(self):
                    return self

                async def __aexit__(self, *_):
                    pass

                def get(self, url, **kwargs):
                    nonlocal idx
                    # Hand out payloads in request order, wrapping around.
                    data = _responses[idx % len(_responses)]
                    idx += 1

                    resp = AsyncMock()
                    resp.status = 200
                    resp.read = AsyncMock(return_value=data)
                    resp.__aenter__ = AsyncMock(return_value=resp)
                    resp.__aexit__ = AsyncMock(return_value=False)
                    return resp

            return FakeSession()

        with patch("aiohttp.ClientSession", return_value=make_session([content1, content2])):
            msg = make_message(
                attachments=[
                    make_attachment(filename="file1.txt", content_type="text/plain"),
                    make_attachment(filename="file2.txt", content_type="text/plain"),
                ],
                content="",
            )
            await adapter._handle_message(msg)

        event = adapter.handle_message.call_args[0][0]
        assert "[Content of file1.txt]:" in event.text
        assert "First file content" in event.text
        assert "[Content of file2.txt]:" in event.text
        assert "Second file content" in event.text
        assert event.text.index("file1") < event.text.index("file2")

    @pytest.mark.asyncio
    async def test_image_attachment_unaffected(self, adapter):
        """Image attachments should still go through the image path, not the document path."""
        with patch(
            "gateway.platforms.discord.cache_image_from_url",
            new_callable=AsyncMock,
            return_value="/tmp/cached_image.png",
        ):
            msg = make_message([
                make_attachment(filename="photo.png", content_type="image/png")
            ])
            await adapter._handle_message(msg)

        event = adapter.handle_message.call_args[0][0]
        assert event.message_type == MessageType.PHOTO
        assert event.media_urls == ["/tmp/cached_image.png"]
        assert event.media_types == ["image/png"]

View file

@ -0,0 +1,360 @@
"""Tests for Discord free-response defaults and mention gating."""
from datetime import datetime, timezone
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
import sys
import pytest
from gateway.config import PlatformConfig
def _ensure_discord_mock():
    """Register stand-in ``discord`` modules in ``sys.modules`` when needed.

    Returns immediately when ``sys.modules`` already holds a ``discord``
    entry that has a ``__file__`` attribute. Otherwise builds MagicMock-based
    ``discord``, ``discord.ext`` and ``discord.ext.commands`` objects and
    registers each with ``setdefault``, so existing entries are kept.
    """
    loaded = sys.modules.get("discord")
    if loaded is not None and hasattr(loaded, "__file__"):
        return

    def _identity(fn):
        return fn

    def _any_args_decorator(*_args, **_kwargs):
        # Decorator factory that ignores its arguments and returns fn unchanged.
        return _identity

    def _kwargs_only_decorator(**_kwargs):
        # Same as above but, like the attributes it backs, keyword-only.
        return _identity

    fake_discord = MagicMock()
    fake_discord.Intents.default.return_value = MagicMock()
    for mock_backed in ("Client", "File", "Embed"):
        setattr(fake_discord, mock_backed, MagicMock)
    # Real (empty) classes so isinstance() checks against them are valid.
    for channel_kind in ("DMChannel", "Thread", "ForumChannel"):
        setattr(fake_discord, channel_kind, type(channel_kind, (), {}))
    fake_discord.Interaction = object
    fake_discord.ui = SimpleNamespace(View=object, button=_any_args_decorator, Button=object)
    fake_discord.ButtonStyle = SimpleNamespace(success=1, primary=2, danger=3, green=1, blurple=2, red=3)
    fake_discord.Color = SimpleNamespace(orange=lambda: 1, green=lambda: 2, blue=lambda: 3, red=lambda: 4)
    fake_discord.app_commands = SimpleNamespace(
        describe=_kwargs_only_decorator,
        choices=_kwargs_only_decorator,
        Choice=lambda **fields: SimpleNamespace(**fields),
    )

    fake_commands = MagicMock()
    fake_commands.Bot = MagicMock
    fake_ext = MagicMock()
    fake_ext.commands = fake_commands

    replacements = {
        "discord": fake_discord,
        "discord.ext": fake_ext,
        "discord.ext.commands": fake_commands,
    }
    for module_name, fake_module in replacements.items():
        sys.modules.setdefault(module_name, fake_module)
# Run before the adapter imports below so a `discord` entry (real or mocked)
# is already in sys.modules when they execute; presumably the adapter module
# imports `discord` at import time.
_ensure_discord_mock()
import gateway.platforms.discord as discord_platform # noqa: E402
from gateway.platforms.discord import DiscordAdapter # noqa: E402
class FakeDMChannel:
    """Minimal DM-channel stand-in exposing only ``id`` and ``name``."""

    def __init__(self, channel_id: int = 1, name: str = "dm"):
        self.id, self.name = channel_id, name
class FakeTextChannel:
    """Minimal guild text-channel stand-in with id, name, guild and topic."""

    def __init__(self, channel_id: int = 1, name: str = "general", guild_name: str = "Hermes Server"):
        self.id, self.name = channel_id, name
        # Only the guild's name is modelled.
        self.guild = SimpleNamespace(name=guild_name)
        self.topic = None
class FakeForumChannel:
    """Minimal forum-channel stand-in; ``type`` is always 15."""

    def __init__(self, channel_id: int = 1, name: str = "support-forum", guild_name: str = "Hermes Server"):
        self.id, self.name = channel_id, name
        self.guild = SimpleNamespace(name=guild_name)
        self.type = 15
        self.topic = None
class FakeThread:
    """Minimal thread stand-in that derives parent id and guild from *parent*."""

    def __init__(self, channel_id: int = 1, name: str = "thread", parent=None, guild_name: str = "Hermes Server"):
        self.id, self.name = channel_id, name
        self.parent = parent
        self.parent_id = getattr(parent, "id", None)
        # Reuse the parent's guild when it has a truthy one; otherwise build
        # a guild from guild_name.
        inherited_guild = getattr(parent, "guild", None)
        if not inherited_guild:
            inherited_guild = SimpleNamespace(name=guild_name)
        self.guild = inherited_guild
        self.topic = None
@pytest.fixture
def adapter(monkeypatch):
    """Provide a DiscordAdapter with fake channel types and a mocked dispatcher.

    ``discord.DMChannel``, ``discord.Thread`` and ``discord.ForumChannel`` on
    the adapter module are replaced with the local fakes, the client exposes
    only a user with id 999, and ``handle_message`` is an AsyncMock.
    """
    discord_ns = discord_platform.discord
    fake_types = (
        ("DMChannel", FakeDMChannel),
        ("Thread", FakeThread),
        ("ForumChannel", FakeForumChannel),
    )
    for attr_name, fake_cls in fake_types:
        monkeypatch.setattr(discord_ns, attr_name, fake_cls, raising=False)

    discord_adapter = DiscordAdapter(PlatformConfig(enabled=True, token="fake-token"))
    bot_user = SimpleNamespace(id=999)
    discord_adapter._client = SimpleNamespace(user=bot_user)
    discord_adapter.handle_message = AsyncMock()
    return discord_adapter
def make_message(*, channel, content: str, mentions=None):
    """Return a namespace shaped like an incoming message in *channel*.

    Args:
        channel: Stored as ``channel`` unchanged.
        content: Stored as ``content``.
        mentions: Copied into a new list; an empty list when falsy.

    The message always has id 123, no attachments, no reference, a
    timezone-aware UTC timestamp and a fixed author named "Jezza" with id 42.
    """
    mentioned = list(mentions) if mentions else []
    sender = SimpleNamespace(id=42, display_name="Jezza", name="Jezza")
    fields = {
        "id": 123,
        "content": content,
        "mentions": mentioned,
        "attachments": [],
        "reference": None,
        "created_at": datetime.now(timezone.utc),
        "channel": channel,
        "author": sender,
    }
    return SimpleNamespace(**fields)
@pytest.mark.asyncio
async def test_discord_defaults_to_require_mention(adapter, monkeypatch):
    """With neither mention-related env var set, an unmentioned server message is dropped."""
    for env_name in ("DISCORD_REQUIRE_MENTION", "DISCORD_FREE_RESPONSE_CHANNELS"):
        monkeypatch.delenv(env_name, raising=False)

    server_channel = FakeTextChannel(channel_id=123)
    incoming = make_message(channel=server_channel, content="hello from channel")
    await adapter._handle_message(incoming)

    # No mention and no override, so nothing may reach the dispatcher.
    adapter.handle_message.assert_not_awaited()
@pytest.mark.asyncio
async def test_discord_free_response_in_server_channels(adapter, monkeypatch):
    """DISCORD_REQUIRE_MENTION=false lets an unmentioned server message through as a group chat."""
    monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")

    server_channel = FakeTextChannel(channel_id=123)
    incoming = make_message(channel=server_channel, content="hello from channel")
    await adapter._handle_message(incoming)

    adapter.handle_message.assert_awaited_once()
    dispatched = adapter.handle_message.await_args.args[0]
    assert dispatched.text == "hello from channel"
    origin = dispatched.source
    assert origin.chat_id == "123"
    assert origin.chat_type == "group"
@pytest.mark.asyncio
async def test_discord_free_response_in_threads(adapter, monkeypatch):
    """With mentions not required, a thread message is dispatched with thread metadata."""
    monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")

    discussion = FakeThread(channel_id=456, name="Ghost reader skill")
    incoming = make_message(channel=discussion, content="hello from thread")
    await adapter._handle_message(incoming)

    adapter.handle_message.assert_awaited_once()
    dispatched = adapter.handle_message.await_args.args[0]
    assert dispatched.text == "hello from thread"
    origin = dispatched.source
    # Both the chat id and the thread id are the thread's own id.
    assert origin.chat_id == "456"
    assert origin.thread_id == "456"
    assert origin.chat_type == "thread"
@pytest.mark.asyncio
async def test_discord_forum_threads_are_handled_as_threads(adapter, monkeypatch):
    """A thread whose parent is a forum channel is reported as a thread with a three-part chat name."""
    monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")

    parent_forum = FakeForumChannel(channel_id=222, name="support-forum")
    forum_post = FakeThread(channel_id=456, name="Can Hermes reply here?", parent=parent_forum)
    incoming = make_message(channel=forum_post, content="hello from forum post")
    await adapter._handle_message(incoming)

    adapter.handle_message.assert_awaited_once()
    dispatched = adapter.handle_message.await_args.args[0]
    assert dispatched.text == "hello from forum post"
    origin = dispatched.source
    assert origin.chat_id == "456"
    assert origin.thread_id == "456"
    assert origin.chat_type == "thread"
    # Expected name format: guild / forum / thread.
    assert origin.chat_name == "Hermes Server / support-forum / Can Hermes reply here?"
@pytest.mark.asyncio
async def test_discord_can_still_require_mentions_when_enabled(adapter, monkeypatch):
    """DISCORD_REQUIRE_MENTION=true drops a server message that lacks a mention."""
    monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true")

    server_channel = FakeTextChannel(channel_id=789)
    incoming = make_message(channel=server_channel, content="ignored without mention")
    await adapter._handle_message(incoming)

    adapter.handle_message.assert_not_awaited()
@pytest.mark.asyncio
async def test_discord_free_response_channel_overrides_mention_requirement(adapter, monkeypatch):
    """A channel listed in DISCORD_FREE_RESPONSE_CHANNELS is answered even when mentions are required."""
    env_overrides = (
        ("DISCORD_REQUIRE_MENTION", "true"),
        ("DISCORD_FREE_RESPONSE_CHANNELS", "789,999"),
    )
    for env_name, env_value in env_overrides:
        monkeypatch.setenv(env_name, env_value)

    listed_channel = FakeTextChannel(channel_id=789)
    incoming = make_message(channel=listed_channel, content="allowed without mention")
    await adapter._handle_message(incoming)

    adapter.handle_message.assert_awaited_once()
    dispatched = adapter.handle_message.await_args.args[0]
    assert dispatched.text == "allowed without mention"
@pytest.mark.asyncio
async def test_discord_forum_parent_in_free_response_list_allows_forum_thread(adapter, monkeypatch):
    """Listing a forum's id as free-response also admits messages from that forum's threads."""
    env_overrides = (
        ("DISCORD_REQUIRE_MENTION", "true"),
        ("DISCORD_FREE_RESPONSE_CHANNELS", "222"),
    )
    for env_name, env_value in env_overrides:
        monkeypatch.setenv(env_name, env_value)

    # Only the parent forum (222) is listed; the thread itself (333) is not.
    parent_forum = FakeForumChannel(channel_id=222, name="support-forum")
    forum_post = FakeThread(channel_id=333, name="Forum topic", parent=parent_forum)
    incoming = make_message(channel=forum_post, content="allowed from forum thread")
    await adapter._handle_message(incoming)

    adapter.handle_message.assert_awaited_once()
    dispatched = adapter.handle_message.await_args.args[0]
    assert dispatched.text == "allowed from forum thread"
    assert dispatched.source.chat_id == "333"
@pytest.mark.asyncio
async def test_discord_accepts_and_strips_bot_mentions_when_required(adapter, monkeypatch):
    """A message that mentions the bot is accepted and the mention markup is removed from the text."""
    monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true")

    me = adapter._client.user
    mention_text = f"<@{me.id}> hello with mention"
    incoming = make_message(
        channel=FakeTextChannel(channel_id=321),
        content=mention_text,
        mentions=[me],
    )
    await adapter._handle_message(incoming)

    adapter.handle_message.assert_awaited_once()
    dispatched = adapter.handle_message.await_args.args[0]
    assert dispatched.text == "hello with mention"
@pytest.mark.asyncio
async def test_discord_dms_ignore_mention_requirement(adapter, monkeypatch):
    """Direct messages are dispatched without a mention even when mentions are required."""
    monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true")

    private_channel = FakeDMChannel(channel_id=654)
    incoming = make_message(channel=private_channel, content="dm without mention")
    await adapter._handle_message(incoming)

    adapter.handle_message.assert_awaited_once()
    dispatched = adapter.handle_message.await_args.args[0]
    assert dispatched.text == "dm without mention"
    assert dispatched.source.chat_type == "dm"
@pytest.mark.asyncio
async def test_discord_auto_thread_enabled_by_default(adapter, monkeypatch):
    """With DISCORD_AUTO_THREAD unset, a server message is moved into an auto-created thread."""
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
    monkeypatch.delenv("DISCORD_AUTO_THREAD", raising=False)

    # Stub thread creation so the adapter receives a known fake thread.
    created_thread = FakeThread(channel_id=999, name="auto-thread")
    thread_factory = AsyncMock(return_value=created_thread)
    adapter._auto_create_thread = thread_factory

    incoming = make_message(channel=FakeTextChannel(channel_id=123), content="hello")
    await adapter._handle_message(incoming)

    adapter._auto_create_thread.assert_awaited_once()
    adapter.handle_message.assert_awaited_once()
    dispatched = adapter.handle_message.await_args.args[0]
    assert dispatched.source.chat_type == "thread"
    assert dispatched.source.thread_id == "999"
@pytest.mark.asyncio
async def test_discord_auto_thread_can_be_disabled(adapter, monkeypatch):
    """DISCORD_AUTO_THREAD=false keeps the message in its channel and creates no thread."""
    env_overrides = (
        ("DISCORD_AUTO_THREAD", "false"),
        ("DISCORD_REQUIRE_MENTION", "false"),
    )
    for env_name, env_value in env_overrides:
        monkeypatch.setenv(env_name, env_value)

    adapter._auto_create_thread = AsyncMock()

    incoming = make_message(channel=FakeTextChannel(channel_id=123), content="hello")
    await adapter._handle_message(incoming)

    adapter._auto_create_thread.assert_not_awaited()
    adapter.handle_message.assert_awaited_once()
    dispatched = adapter.handle_message.await_args.args[0]
    assert dispatched.source.chat_type == "group"
@pytest.mark.asyncio
async def test_discord_bot_thread_skips_mention_requirement(adapter, monkeypatch):
    """A thread already recorded in _bot_participated_threads needs no mention."""
    monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)
    env_overrides = (
        ("DISCORD_REQUIRE_MENTION", "true"),
        ("DISCORD_AUTO_THREAD", "false"),
    )
    for env_name, env_value in env_overrides:
        monkeypatch.setenv(env_name, env_value)

    # Mark thread 456 as one the bot has taken part in before.
    adapter._bot_participated_threads.add("456")

    known_thread = FakeThread(channel_id=456, name="existing thread")
    incoming = make_message(channel=known_thread, content="follow-up without mention")
    await adapter._handle_message(incoming)

    adapter.handle_message.assert_awaited_once()
    dispatched = adapter.handle_message.await_args.args[0]
    assert dispatched.text == "follow-up without mention"
    assert dispatched.source.chat_type == "thread"
@pytest.mark.asyncio
async def test_discord_unknown_thread_still_requires_mention(adapter, monkeypatch):
    """A thread absent from _bot_participated_threads is ignored without a mention."""
    monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)
    env_overrides = (
        ("DISCORD_REQUIRE_MENTION", "true"),
        ("DISCORD_AUTO_THREAD", "false"),
    )
    for env_name, env_value in env_overrides:
        monkeypatch.setenv(env_name, env_value)

    # Thread 789 is deliberately NOT registered as participated.
    unfamiliar_thread = FakeThread(channel_id=789, name="some thread")
    incoming = make_message(channel=unfamiliar_thread, content="hello from unknown thread")
    await adapter._handle_message(incoming)

    adapter.handle_message.assert_not_awaited()
@pytest.mark.asyncio
async def test_discord_auto_thread_tracks_participation(adapter, monkeypatch):
    """The id of an auto-created thread is added to _bot_participated_threads."""
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
    monkeypatch.delenv("DISCORD_AUTO_THREAD", raising=False)

    created_thread = FakeThread(channel_id=555, name="auto-thread")
    adapter._auto_create_thread = AsyncMock(return_value=created_thread)

    incoming = make_message(channel=FakeTextChannel(channel_id=123), content="start a thread")
    await adapter._handle_message(incoming)

    tracked = adapter._bot_participated_threads
    assert "555" in tracked
@pytest.mark.asyncio
async def test_discord_thread_participation_tracked_on_dispatch(adapter, monkeypatch):
    """Handling a message inside an existing thread records that thread's id."""
    env_overrides = (
        ("DISCORD_REQUIRE_MENTION", "false"),
        ("DISCORD_AUTO_THREAD", "false"),
    )
    for env_name, env_value in env_overrides:
        monkeypatch.setenv(env_name, env_value)

    existing_thread = FakeThread(channel_id=777, name="manually created thread")
    incoming = make_message(channel=existing_thread, content="hello in thread")
    await adapter._handle_message(incoming)

    tracked = adapter._bot_participated_threads
    assert "777" in tracked

View file

@ -0,0 +1,23 @@
"""Import-safety tests for the Discord gateway adapter."""
import builtins
import importlib
import sys
class TestDiscordImportSafety:
    """Check that the Discord adapter module imports without discord.py."""

    def test_module_imports_even_when_discord_dependency_is_missing(self, monkeypatch):
        """Importing the adapter with ``discord`` blocked must succeed in degraded mode.

        Every import of ``discord`` (or a submodule) is made to raise
        ImportError, the adapter module is re-imported from scratch, and the
        module is expected to expose ``DISCORD_AVAILABLE is False`` and
        ``discord is None``.
        """
        module_name = "gateway.platforms.discord"
        original_import = builtins.__import__

        def fake_import(name, globals=None, locals=None, fromlist=(), level=0):
            if name == "discord" or name.startswith("discord."):
                raise ImportError("discord unavailable for test")
            return original_import(name, globals, locals, fromlist, level)

        # Force a fresh import; monkeypatch restores a previously loaded
        # module at teardown, but records nothing when it was absent.
        monkeypatch.delitem(sys.modules, module_name, raising=False)
        monkeypatch.setattr(builtins, "__import__", fake_import)

        try:
            module = importlib.import_module(module_name)

            assert module.DISCORD_AVAILABLE is False
            assert module.discord is None
        finally:
            # Never leave the degraded module cached: if the adapter was not
            # loaded before this test, nothing else would remove it and later
            # importers would silently get the discord-less variant.
            sys.modules.pop(module_name, None)

View file

@ -0,0 +1,9 @@
import inspect
from gateway.platforms.discord import DiscordAdapter
def test_discord_media_methods_accept_metadata_kwarg():
    """Each media-sending method on DiscordAdapter must declare a ``metadata`` parameter."""
    media_methods = ("send_voice", "send_image_file", "send_image")
    for method_name in media_methods:
        method = getattr(DiscordAdapter, method_name)
        accepted_params = inspect.signature(method).parameters
        # The method name is the assertion message, identifying the offender.
        assert "metadata" in accepted_params, method_name

View file

@ -0,0 +1,44 @@
"""Tests for Discord Opus codec loading — must use ctypes.util.find_library."""
import inspect
class TestOpusFindLibrary:
    """Opus loading must try ctypes.util.find_library first, with platform fallback.

    These tests inspect the source text of the adapter rather than executing
    it, so they check for the presence and relative order of key substrings.
    """

    def test_uses_find_library_first(self):
        """find_library must be the primary lookup strategy."""
        from gateway.platforms.discord import DiscordAdapter

        source = inspect.getsource(DiscordAdapter.connect)
        assert "find_library" in source, \
            "Opus loading must use ctypes.util.find_library"

    def test_homebrew_fallback_is_conditional(self):
        """Homebrew paths must only be tried when find_library returns None."""
        from gateway.platforms.discord import DiscordAdapter

        source = inspect.getsource(DiscordAdapter.connect)
        # Homebrew fallback must exist
        assert "/opt/homebrew" in source or "homebrew" in source, \
            "Opus loading should have macOS Homebrew fallback"
        # find_library must appear BEFORE any Homebrew path.
        # str.find is used instead of str.index so a missing marker produces
        # a readable assertion failure rather than a bare ValueError, and the
        # Homebrew position is taken from whichever marker the check above
        # actually matched.
        fl_idx = source.find("find_library")
        assert fl_idx != -1, \
            "Opus loading must use ctypes.util.find_library"
        hb_marker = "/opt/homebrew" if "/opt/homebrew" in source else "homebrew"
        hb_idx = source.find(hb_marker)
        assert fl_idx < hb_idx, \
            "find_library must be tried before Homebrew fallback paths"
        # Fallback must be guarded by platform check
        assert "sys.platform" in source or "darwin" in source, \
            "Homebrew fallback must be guarded by macOS platform check"

    def test_opus_decode_error_logged(self):
        """Opus decode failure must log the error, not silently return."""
        from gateway.platforms.discord import VoiceReceiver

        source = inspect.getsource(VoiceReceiver._on_packet)
        assert "logger" in source, \
            "_on_packet must log Opus decode errors"
        # Must not have bare `except Exception:\n return`
        lines = source.split("\n")
        for i, line in enumerate(lines):
            if "except Exception" in line and i + 1 < len(lines):
                next_line = lines[i + 1].strip()
                assert next_line != "return", \
                    f"_on_packet has bare 'except Exception: return' at line {i+1}"

View file

@ -0,0 +1,80 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
import sys
import pytest
from gateway.config import PlatformConfig
def _ensure_discord_mock():
    """Register stand-in ``discord`` modules in ``sys.modules`` when needed.

    Returns immediately when ``sys.modules`` already holds a ``discord``
    entry that has a ``__file__`` attribute. Otherwise builds MagicMock-based
    ``discord``, ``discord.ext`` and ``discord.ext.commands`` objects and
    registers each with ``setdefault``, so existing entries are kept.
    """
    loaded = sys.modules.get("discord")
    if loaded is not None and hasattr(loaded, "__file__"):
        return

    def _identity(fn):
        return fn

    def _any_args_decorator(*_args, **_kwargs):
        # Decorator factory that ignores its arguments and returns fn unchanged.
        return _identity

    def _kwargs_only_decorator(**_kwargs):
        # Same as above but, like the attributes it backs, keyword-only.
        return _identity

    fake_discord = MagicMock()
    fake_discord.Intents.default.return_value = MagicMock()
    for mock_backed in ("Client", "File", "Embed"):
        setattr(fake_discord, mock_backed, MagicMock)
    # Real (empty) classes so isinstance() checks against them are valid.
    for channel_kind in ("DMChannel", "Thread", "ForumChannel"):
        setattr(fake_discord, channel_kind, type(channel_kind, (), {}))
    fake_discord.Interaction = object
    fake_discord.ui = SimpleNamespace(View=object, button=_any_args_decorator, Button=object)
    fake_discord.ButtonStyle = SimpleNamespace(success=1, primary=2, danger=3, green=1, blurple=2, red=3)
    fake_discord.Color = SimpleNamespace(orange=lambda: 1, green=lambda: 2, blue=lambda: 3, red=lambda: 4)
    fake_discord.app_commands = SimpleNamespace(
        describe=_kwargs_only_decorator,
        choices=_kwargs_only_decorator,
        Choice=lambda **fields: SimpleNamespace(**fields),
    )

    fake_commands = MagicMock()
    fake_commands.Bot = MagicMock
    fake_ext = MagicMock()
    fake_ext.commands = fake_commands

    replacements = {
        "discord": fake_discord,
        "discord.ext": fake_ext,
        "discord.ext.commands": fake_commands,
    }
    for module_name, fake_module in replacements.items():
        sys.modules.setdefault(module_name, fake_module)
# Run before the adapter import below so a `discord` entry (real or mocked)
# is already in sys.modules when it executes; presumably the adapter module
# imports `discord` at import time.
_ensure_discord_mock()
from gateway.platforms.discord import DiscordAdapter # noqa: E402
@pytest.mark.asyncio
async def test_send_retries_without_reference_when_reply_target_is_system_message():
    """send() must retry without a reply reference when the first attempt is rejected.

    The fake channel raises a "Cannot reply to a system message" error on its
    first send and succeeds on the second; the adapter is expected to report
    success, having sent once with the reference and once without it.
    """
    discord_adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))

    reply_target = SimpleNamespace(id=99)
    delivered = SimpleNamespace(id=1234)
    attempts = []

    async def fake_send(*, content, reference=None):
        attempts.append({"content": content, "reference": reference})
        is_first_attempt = len(attempts) == 1
        if is_first_attempt:
            raise RuntimeError(
                "400 Bad Request (error code: 50035): Invalid Form Body\n"
                "In message_reference: Cannot reply to a system message"
            )
        return delivered

    fetch_message = AsyncMock(return_value=reply_target)
    send = AsyncMock(side_effect=fake_send)
    fake_channel = SimpleNamespace(fetch_message=fetch_message, send=send)

    discord_adapter._client = SimpleNamespace(
        get_channel=lambda _chat_id: fake_channel,
        fetch_channel=AsyncMock(),
    )

    outcome = await discord_adapter.send("555", "hello", reply_to="99")

    assert outcome.success is True
    assert outcome.message_id == "1234"
    assert fake_channel.fetch_message.await_count == 1
    assert fake_channel.send.await_count == 2
    # First attempt carried the reference, the retry dropped it.
    first_attempt, second_attempt = attempts[0], attempts[1]
    assert first_attempt["reference"] is reply_target
    assert second_attempt["reference"] is None

View file

@ -0,0 +1,499 @@
"""Tests for native Discord slash command fast-paths (thread creation & auto-thread)."""
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import sys
import pytest
from gateway.config import PlatformConfig
def _ensure_discord_mock():
    """Register stand-in ``discord`` modules in ``sys.modules`` when needed.

    Returns immediately when ``sys.modules`` already holds a ``discord``
    entry that has a ``__file__`` attribute. Otherwise builds MagicMock-based
    ``discord``, ``discord.ext`` and ``discord.ext.commands`` objects and
    registers each with ``setdefault``, so existing entries are kept.
    """
    loaded = sys.modules.get("discord")
    if loaded is not None and hasattr(loaded, "__file__"):
        return

    def _identity(fn):
        return fn

    def _kwargs_only_decorator(**_kwargs):
        # Keyword-only decorator factory that returns fn unchanged.
        return _identity

    fake_discord = MagicMock()
    fake_discord.Intents.default.return_value = MagicMock()
    # Real (empty) classes so isinstance() checks against them are valid.
    for channel_kind in ("DMChannel", "Thread", "ForumChannel"):
        setattr(fake_discord, channel_kind, type(channel_kind, (), {}))
    fake_discord.Interaction = object
    fake_discord.app_commands = SimpleNamespace(
        describe=_kwargs_only_decorator,
        choices=_kwargs_only_decorator,
        Choice=lambda **fields: SimpleNamespace(**fields),
    )

    fake_commands = MagicMock()
    fake_commands.Bot = MagicMock
    fake_ext = MagicMock()
    fake_ext.commands = fake_commands

    replacements = {
        "discord": fake_discord,
        "discord.ext": fake_ext,
        "discord.ext.commands": fake_commands,
    }
    for module_name, fake_module in replacements.items():
        sys.modules.setdefault(module_name, fake_module)
# Run before the adapter import below so a `discord` entry (real or mocked)
# is already in sys.modules when it executes; presumably the adapter module
# imports `discord` at import time.
_ensure_discord_mock()
from gateway.platforms.discord import DiscordAdapter # noqa: E402
class FakeTree:
    """Stand-in for the client's command tree that records registered callbacks."""

    def __init__(self):
        # Maps a command name to the callable registered under it.
        self.commands = {}

    def command(self, *, name, description):
        """Return a decorator that stores the decorated callable under *name*.

        *description* is accepted to match the call signature but is not kept.
        The decorator returns the callable unchanged.
        """
        def _register(callback):
            self.commands[name] = callback
            return callback

        return _register
@pytest.fixture
def adapter():
    """Provide a ``DiscordAdapter`` whose ``_client`` is replaced by a fake.

    The fake client exposes a ``FakeTree`` as ``tree``, a ``get_channel`` that
    always returns ``None``, an ``AsyncMock`` ``fetch_channel`` and a ``user``
    with id 99999 named "HermesBot".
    """
    discord_adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
    bot_user = SimpleNamespace(id=99999, name="HermesBot")
    discord_adapter._client = SimpleNamespace(
        tree=FakeTree(),
        get_channel=lambda _id: None,
        fetch_channel=AsyncMock(),
        user=bot_user,
    )
    return discord_adapter
# ------------------------------------------------------------------
# /thread slash command registration
# ------------------------------------------------------------------
@pytest.mark.asyncio
async def test_registers_native_thread_slash_command(adapter):
    """The registered ``thread`` command defers ephemerally, then delegates.

    The command callback is looked up in the fake tree after
    ``_register_slash_commands()`` and invoked directly.
    """
    adapter._handle_thread_create_slash = AsyncMock()
    adapter._register_slash_commands()

    defer = AsyncMock()
    interaction = SimpleNamespace(response=SimpleNamespace(defer=defer))
    thread_command = adapter._client.tree.commands["thread"]

    await thread_command(interaction, name="Planning", message="", auto_archive_duration=1440)

    defer.assert_awaited_once_with(ephemeral=True)
    # Arguments are forwarded positionally: (interaction, name, message, duration).
    adapter._handle_thread_create_slash.assert_awaited_once_with(interaction, "Planning", "", 1440)
# ------------------------------------------------------------------
# _handle_thread_create_slash — success, session dispatch, failure
# ------------------------------------------------------------------
@pytest.mark.asyncio
async def test_handle_thread_create_slash_reports_success(adapter):
    """Creating a thread posts the seed message and reports a link to the user.

    The interaction's channel exposes a ``parent``; the thread is expected to
    be created on that parent channel.
    """
    thread_send = AsyncMock()
    new_thread = SimpleNamespace(id=555, name="Planning", send=thread_send)
    create_thread = AsyncMock(return_value=new_thread)
    followup_send = AsyncMock()
    parent = SimpleNamespace(create_thread=create_thread, send=AsyncMock())
    interaction = SimpleNamespace(
        channel=SimpleNamespace(parent=parent),
        channel_id=123,
        user=SimpleNamespace(display_name="Jezza", id=42),
        guild=SimpleNamespace(name="TestGuild"),
        followup=SimpleNamespace(send=followup_send),
    )

    await adapter._handle_thread_create_slash(interaction, "Planning", "Kickoff", 1440)

    create_thread.assert_awaited_once_with(
        name="Planning",
        auto_archive_duration=1440,
        reason="Requested by Jezza via /thread",
    )
    thread_send.assert_awaited_once_with("Kickoff")

    # The user gets an ephemeral follow-up containing a mention of the new thread.
    followup_send.assert_awaited()
    positional, keyword = followup_send.await_args
    assert "<#555>" in positional[0]
    assert keyword["ephemeral"] is True
@pytest.mark.asyncio
async def test_handle_thread_create_slash_dispatches_session_when_message_provided(adapter):
    """A non-empty message must be handed to ``_dispatch_thread_session``."""
    new_thread = SimpleNamespace(id=555, name="Planning", send=AsyncMock())
    parent = SimpleNamespace(create_thread=AsyncMock(return_value=new_thread))
    requester = SimpleNamespace(display_name="Jezza", id=42)
    interaction = SimpleNamespace(
        channel=SimpleNamespace(parent=parent),
        channel_id=123,
        user=requester,
        guild=SimpleNamespace(name="TestGuild"),
        followup=SimpleNamespace(send=AsyncMock()),
    )
    adapter._dispatch_thread_session = AsyncMock()

    await adapter._handle_thread_create_slash(interaction, "Planning", "Hello Hermes", 1440)

    # The thread id is expected as a string, followed by thread name and message.
    expected_args = (interaction, "555", "Planning", "Hello Hermes")
    adapter._dispatch_thread_session.assert_awaited_once_with(*expected_args)
@pytest.mark.asyncio
async def test_handle_thread_create_slash_no_dispatch_without_message(adapter):
    """An empty message must not trigger ``_dispatch_thread_session``."""
    new_thread = SimpleNamespace(id=555, name="Planning", send=AsyncMock())
    parent = SimpleNamespace(create_thread=AsyncMock(return_value=new_thread))
    requester = SimpleNamespace(display_name="Jezza", id=42)
    interaction = SimpleNamespace(
        channel=SimpleNamespace(parent=parent),
        channel_id=123,
        user=requester,
        guild=SimpleNamespace(name="TestGuild"),
        followup=SimpleNamespace(send=AsyncMock()),
    )
    adapter._dispatch_thread_session = AsyncMock()

    empty_message = ""
    await adapter._handle_thread_create_slash(interaction, "Planning", empty_message, 1440)

    adapter._dispatch_thread_session.assert_not_awaited()
@pytest.mark.asyncio
async def test_handle_thread_create_slash_falls_back_to_seed_message(adapter):
    """If direct thread creation raises, the thread is created from a seed message.

    ``channel.create_thread`` fails, so the message text is sent to the channel
    and ``create_thread`` is expected to be awaited on the returned message.
    """
    new_thread = SimpleNamespace(id=555, name="Planning")
    seed_create_thread = AsyncMock(return_value=new_thread)
    seed = SimpleNamespace(id=777, create_thread=seed_create_thread)
    channel_send = AsyncMock(return_value=seed)
    followup_send = AsyncMock()
    # Note: this channel has no ``parent`` attribute, unlike the success case.
    channel = SimpleNamespace(
        create_thread=AsyncMock(side_effect=RuntimeError("direct failed")),
        send=channel_send,
    )
    interaction = SimpleNamespace(
        channel=channel,
        channel_id=123,
        user=SimpleNamespace(display_name="Jezza", id=42),
        guild=SimpleNamespace(name="TestGuild"),
        followup=SimpleNamespace(send=followup_send),
    )

    await adapter._handle_thread_create_slash(interaction, "Planning", "Kickoff", 1440)

    channel_send.assert_awaited_once_with("Kickoff")
    seed_create_thread.assert_awaited_once_with(
        name="Planning",
        auto_archive_duration=1440,
        reason="Requested by Jezza via /thread",
    )
    followup_send.assert_awaited()
@pytest.mark.asyncio
async def test_handle_thread_create_slash_reports_failure(adapter):
    """When both creation paths raise, the user gets one ephemeral error message.

    The reported text is expected to include the error raised by the fallback
    ``send`` call ("nope").
    """
    followup_send = AsyncMock()
    failing_channel = SimpleNamespace(
        create_thread=AsyncMock(side_effect=RuntimeError("direct failed")),
        send=AsyncMock(side_effect=RuntimeError("nope")),
    )
    interaction = SimpleNamespace(
        channel=failing_channel,
        channel_id=123,
        user=SimpleNamespace(display_name="Jezza", id=42),
        followup=SimpleNamespace(send=followup_send),
    )

    await adapter._handle_thread_create_slash(interaction, "Planning", "", 1440)

    followup_send.assert_awaited_once()
    positional, keyword = followup_send.await_args
    report = positional[0]
    assert "Failed to create thread:" in report
    assert "nope" in report
    assert keyword["ephemeral"] is True
# ------------------------------------------------------------------
# _dispatch_thread_session — builds correct event and routes it
# ------------------------------------------------------------------
@pytest.mark.asyncio
async def test_dispatch_thread_session_builds_thread_event(adapter):
    """The dispatched event is addressed to the thread itself.

    Expects exactly one event whose source has ``chat_type == "thread"`` and
    both ``chat_id`` and ``thread_id`` equal to the thread id.
    """
    received = []

    async def record_event(event):
        received.append(event)

    adapter.handle_message = record_event
    interaction = SimpleNamespace(
        user=SimpleNamespace(display_name="Jezza", id=42),
        guild=SimpleNamespace(name="TestGuild"),
    )

    await adapter._dispatch_thread_session(interaction, "555", "Planning", "Hello!")

    assert len(received) == 1
    dispatched = received[0]
    assert dispatched.text == "Hello!"
    assert dispatched.source.chat_id == "555"
    assert dispatched.source.chat_type == "thread"
    assert dispatched.source.thread_id == "555"
    assert "TestGuild" in dispatched.source.chat_name
# ------------------------------------------------------------------
# _build_slash_event — preserve thread context for native slash commands
# ------------------------------------------------------------------
def test_build_slash_event_preserves_thread_context(adapter):
    """A slash command issued inside a thread yields a thread-scoped event."""
    # _FakeThreadChannel is defined further down in this module; it is only
    # looked up when the test runs, so the forward reference is fine.
    thread_channel = _FakeThreadChannel(channel_id=555, name="Planning")
    interaction = SimpleNamespace(
        channel=thread_channel,
        channel_id=555,
        user=SimpleNamespace(display_name="Jezza", id=42),
    )

    slash_event = adapter._build_slash_event(interaction, "/status")

    assert slash_event.text == "/status"
    assert slash_event.source.chat_id == "555"
    assert slash_event.source.chat_type == "thread"
    assert slash_event.source.thread_id == "555"
    # The interaction carries no guild; the name comes from the fake channel.
    assert "TestGuild" in slash_event.source.chat_name
def test_build_slash_event_uses_group_context_for_channels(adapter):
    """A slash command in a regular channel yields a group event with no thread id."""
    text_channel = _FakeTextChannel(channel_id=123, name="general")
    interaction = SimpleNamespace(
        channel=text_channel,
        channel_id=123,
        user=SimpleNamespace(display_name="Jezza", id=42),
    )

    source = adapter._build_slash_event(interaction, "/status").source

    assert source.chat_id == "123"
    assert source.chat_type == "group"
    assert source.thread_id is None
    assert source.chat_name == "TestGuild / #general"
# ------------------------------------------------------------------
# Auto-thread: _auto_create_thread
# ------------------------------------------------------------------
@pytest.mark.asyncio
async def test_auto_create_thread_uses_message_content_as_name(adapter):
    """Short message content is used verbatim as the thread name."""
    new_thread = SimpleNamespace(id=999, name="Hello world")
    create_thread = AsyncMock(return_value=new_thread)
    message = SimpleNamespace(
        content="Hello world, how are you?",
        create_thread=create_thread,
    )

    created = await adapter._auto_create_thread(message)

    assert created is new_thread
    create_thread.assert_awaited_once()
    passed = create_thread.await_args.kwargs
    assert passed["name"] == "Hello world, how are you?"
    assert passed["auto_archive_duration"] == 1440
@pytest.mark.asyncio
async def test_auto_create_thread_truncates_long_names(adapter):
    """A 200-character message yields a name of at most 80 chars ending in '...'."""
    new_thread = SimpleNamespace(id=999, name="truncated")
    create_thread = AsyncMock(return_value=new_thread)
    message = SimpleNamespace(content="a" * 200, create_thread=create_thread)

    created = await adapter._auto_create_thread(message)

    assert created is new_thread
    thread_name = create_thread.await_args.kwargs["name"]
    assert len(thread_name) <= 80
    assert thread_name.endswith("...")
@pytest.mark.asyncio
async def test_auto_create_thread_returns_none_on_failure(adapter):
    """An exception from ``create_thread`` is turned into a ``None`` result."""
    failing_create = AsyncMock(side_effect=RuntimeError("no perms"))
    message = SimpleNamespace(content="Hello", create_thread=failing_create)

    outcome = await adapter._auto_create_thread(message)

    assert outcome is None
# ------------------------------------------------------------------
# Auto-thread integration in _handle_message
# ------------------------------------------------------------------
import discord as _discord_mod # noqa: E402 — mock or real, used below
class _FakeTextChannel:
    """Plain channel object that is neither a ``discord.Thread`` nor a ``discord.DMChannel``."""

    def __init__(self, channel_id=100, name="general", guild_name="TestGuild"):
        # The guild always gets id 1; only its name is configurable.
        self.guild = SimpleNamespace(name=guild_name, id=1)
        self.id, self.name = channel_id, name
        self.topic = None
class _FakeThreadChannel(_discord_mod.Thread):
    """Thread channel fake: ``isinstance(ch, discord.Thread)`` is ``True``."""

    def __init__(self, channel_id=200, name="existing-thread", guild_name="TestGuild", parent_id=100):
        # super().__init__ is deliberately not called — the mocked Thread is
        # just an empty type.
        self.id, self.name = channel_id, name
        self.topic = None
        self.guild = SimpleNamespace(name=guild_name, id=1)
        # The parent channel is always named "general" and carries its own
        # guild namespace with the same name and id.
        parent_guild = SimpleNamespace(name=guild_name, id=1)
        self.parent = SimpleNamespace(id=parent_id, name="general", guild=parent_guild)
def _fake_message(channel, *, content="Hello", author_id=42, display_name="Jezza"):
    """Build a message-like namespace posted in *channel* by a non-bot author.

    The message always has id 12345, empty ``attachments`` and ``mentions``
    lists, and ``None`` for ``reference`` and ``created_at``.
    """
    sender = SimpleNamespace(id=author_id, display_name=display_name, bot=False)
    fields = {
        "author": sender,
        "content": content,
        "channel": channel,
        "attachments": [],
        "mentions": [],
        "reference": None,
        "created_at": None,
        "id": 12345,
    }
    return SimpleNamespace(**fields)
@pytest.mark.asyncio
async def test_auto_thread_creates_thread_and_redirects(adapter, monkeypatch):
    """With DISCORD_AUTO_THREAD=true the event is routed to a newly created thread."""
    for variable, value in (
        ("DISCORD_AUTO_THREAD", "true"),
        ("DISCORD_REQUIRE_MENTION", "false"),
    ):
        monkeypatch.setenv(variable, value)

    adapter._auto_create_thread = AsyncMock(return_value=SimpleNamespace(id=999, name="Hello"))
    received = []

    async def record_event(event):
        received.append(event)

    adapter.handle_message = record_event
    incoming = _fake_message(_FakeTextChannel(), content="Hello world")

    await adapter._handle_message(incoming)

    adapter._auto_create_thread.assert_awaited_once_with(incoming)
    assert len(received) == 1
    routed = received[0]
    # The event must point at the new thread, not the originating channel (id 100).
    assert routed.source.chat_id == "999"
    assert routed.source.chat_type == "thread"
    assert routed.source.thread_id == "999"
@pytest.mark.asyncio
async def test_auto_thread_enabled_by_default_slash_commands(adapter, monkeypatch):
    """With DISCORD_AUTO_THREAD unset, auto-threading still happens (default: true)."""
    monkeypatch.delenv("DISCORD_AUTO_THREAD", raising=False)
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")

    auto_thread = _FakeThreadChannel(channel_id=999, name="auto-thread")
    adapter._auto_create_thread = AsyncMock(return_value=auto_thread)
    received = []

    async def record_event(event):
        received.append(event)

    adapter.handle_message = record_event

    await adapter._handle_message(_fake_message(_FakeTextChannel()))

    adapter._auto_create_thread.assert_awaited_once()
    assert len(received) == 1
    source = received[0].source
    assert source.chat_id == "999"  # routed into the auto-created thread
    assert source.chat_type == "thread"
@pytest.mark.asyncio
async def test_auto_thread_can_be_disabled(adapter, monkeypatch):
    """DISCORD_AUTO_THREAD=false leaves the message in its original channel."""
    for variable in ("DISCORD_AUTO_THREAD", "DISCORD_REQUIRE_MENTION"):
        monkeypatch.setenv(variable, "false")

    adapter._auto_create_thread = AsyncMock()
    received = []

    async def record_event(event):
        received.append(event)

    adapter.handle_message = record_event

    await adapter._handle_message(_fake_message(_FakeTextChannel()))

    adapter._auto_create_thread.assert_not_awaited()
    assert len(received) == 1
    # "100" is the default id of _FakeTextChannel, i.e. no redirection happened.
    assert received[0].source.chat_id == "100"
@pytest.mark.asyncio
async def test_auto_thread_skips_threads_and_dms(adapter, monkeypatch):
    """A message that already lives in a thread must not spawn another thread.

    Only the thread case is exercised here; there is no DM scenario in this test.
    """
    for variable, value in (
        ("DISCORD_AUTO_THREAD", "true"),
        ("DISCORD_REQUIRE_MENTION", "false"),
    ):
        monkeypatch.setenv(variable, value)

    adapter._auto_create_thread = AsyncMock()
    received = []

    async def record_event(event):
        received.append(event)

    adapter.handle_message = record_event
    in_thread = _fake_message(_FakeThreadChannel())

    await adapter._handle_message(in_thread)

    adapter._auto_create_thread.assert_not_awaited()
# ------------------------------------------------------------------
# Config bridge
# ------------------------------------------------------------------
def test_discord_auto_thread_config_bridge(monkeypatch, tmp_path):
    """``discord.auto_thread`` in config.yaml is bridged to ``DISCORD_AUTO_THREAD``.

    Writes a config containing ``discord.auto_thread: true`` into a temporary
    ``.hermes`` directory, points ``HERMES_HOME`` and ``Path.home()`` at it,
    calls ``load_gateway_config()`` and expects the environment variable to be
    the string ``"true"`` afterwards.
    """
    import os
    from pathlib import Path

    import yaml

    # Write a config.yaml the loader will find.
    hermes_dir = tmp_path / ".hermes"
    hermes_dir.mkdir()
    (hermes_dir / "config.yaml").write_text(
        yaml.dump({"discord": {"auto_thread": True}}),
        encoding="utf-8",
    )

    # monkeypatch.delenv(..., raising=False) records nothing when the variable
    # is not set, so the value the loader writes into os.environ would outlive
    # this test and leak into later ones. Setting the variable first makes
    # monkeypatch remember the original state and restore it on teardown.
    monkeypatch.setenv("DISCORD_AUTO_THREAD", "")
    monkeypatch.delenv("DISCORD_AUTO_THREAD")
    monkeypatch.setenv("HERMES_HOME", str(hermes_dir))
    monkeypatch.setattr(Path, "home", lambda: tmp_path)

    # Imported only after the environment is patched, as in the original test.
    from gateway.config import load_gateway_config

    load_gateway_config()

    assert os.getenv("DISCORD_AUTO_THREAD") == "true"

View file

@ -0,0 +1,99 @@
"""Tests for Discord system message filtering (thread renames, pins, etc.)."""
import pytest
import unittest
from unittest.mock import MagicMock
discord = pytest.importorskip("discord")
def _make_author(*, bot: bool = False, is_self: bool = False):
    """Return a ``MagicMock`` standing in for a Discord message author.

    ``is_self`` selects the id (99999 instead of 12345); ``bot`` selects the
    name ("TestBot" instead of "TestUser"), which is also used as the
    display name.
    """
    label = "TestBot" if bot else "TestUser"
    author = MagicMock()
    author.id = 99999 if is_self else 12345
    author.bot = bot
    # ``name`` must be assigned after construction: as a constructor keyword
    # it would name the mock itself instead of setting an attribute.
    author.name = label
    author.display_name = author.name
    return author
def _make_message(*, author=None, content="hello", msg_type=None):
    """Return a ``MagicMock`` standing in for a Discord message.

    A falsy *author* is replaced by ``_make_author()`` and a ``None``
    *msg_type* by ``discord.MessageType.default``. The message sits in
    channel 222 ("test-channel") of a guild named "TestServer" and has empty
    ``attachments`` and ``mentions`` lists.
    """
    guild = MagicMock()
    guild.name = "TestServer"

    channel = MagicMock()
    channel.id = 222
    channel.name = "test-channel"
    channel.guild = guild

    message = MagicMock()
    message.author = author if author else _make_author()
    message.content = content
    message.attachments = []
    message.mentions = []
    # The default type is only looked up when the caller did not supply one.
    message.type = discord.MessageType.default if msg_type is None else msg_type
    message.channel = channel
    return message
class TestDiscordSystemMessageFilter(unittest.TestCase):
    """Discord system messages (thread renames, pins, etc.) must be ignored."""

    def _run_filter(self, message, client_user=None):
        """Apply a local copy of the ``on_message`` guard to *message*.

        The guard being mirrored is:

            if message.type not in (discord.MessageType.default, discord.MessageType.reply):
                return  # ignored

        Returns ``True`` if the message would be processed and ``False`` if it
        would be dropped. Note that this method re-implements the guard; it
        does not call into the adapter.
        """
        # The bot's own messages are dropped regardless of their type.
        if message.author == client_user:
            return False
        # Only ordinary messages and replies count as user content.
        user_message_types = (discord.MessageType.default, discord.MessageType.reply)
        return message.type in user_message_types

    def _passes(self, msg_type):
        """Run the filter on a fresh message of *msg_type* from a regular user."""
        return self._run_filter(_make_message(msg_type=msg_type))

    def test_default_messages_accepted(self):
        """Regular user messages (type=default) pass the filter."""
        self.assertTrue(self._passes(discord.MessageType.default))

    def test_reply_messages_accepted(self):
        """Replies (type=reply) pass the filter — users reply to bot messages."""
        self.assertTrue(self._passes(discord.MessageType.reply))

    def test_thread_rename_ignored(self):
        """Thread rename notifications are dropped."""
        self.assertFalse(self._passes(discord.MessageType.channel_name_change))

    def test_pins_add_ignored(self):
        """Pin notifications are dropped."""
        self.assertFalse(self._passes(discord.MessageType.pins_add))

    def test_new_member_ignored(self):
        """Member-join notifications are dropped."""
        self.assertFalse(self._passes(discord.MessageType.new_member))

    def test_premium_guild_subscription_ignored(self):
        """Boost notifications are dropped."""
        self.assertFalse(self._passes(discord.MessageType.premium_guild_subscription))

    def test_recipient_add_ignored(self):
        """Group-DM recipient-add notifications are dropped."""
        self.assertFalse(self._passes(discord.MessageType.recipient_add))

    def test_own_default_messages_still_ignored(self):
        """The bot's own messages are dropped even when their type is default."""
        me = _make_author(is_self=True)
        own_message = _make_message(author=me, msg_type=discord.MessageType.default)
        self.assertFalse(self._run_filter(own_message, client_user=me))
# Allow running this module directly through the unittest runner.
if __name__ == "__main__":
    unittest.main()

View file

@ -0,0 +1,83 @@
"""Tests for Discord thread participation persistence.
Verifies that _bot_participated_threads survives adapter restarts by
being persisted to ~/.hermes/discord_threads.json.
"""
import json
import os
from unittest.mock import patch
import pytest
class TestDiscordThreadPersistence:
    """Tracked thread IDs are written to disk and read back on construction."""

    @staticmethod
    def _home(directory):
        """Return a context manager that sets ``HERMES_HOME`` to *directory*."""
        return patch.dict(os.environ, {"HERMES_HOME": str(directory)})

    def _make_adapter(self, tmp_path):
        """Construct a ``DiscordAdapter`` while ``HERMES_HOME`` is *tmp_path*.

        The environment override only covers construction; callers that
        track threads afterwards must set ``HERMES_HOME`` again themselves.
        """
        from gateway.config import PlatformConfig
        from gateway.platforms.discord import DiscordAdapter

        platform_config = PlatformConfig(enabled=True, token="test-token")
        with self._home(tmp_path):
            return DiscordAdapter(config=platform_config)

    def test_starts_empty_when_no_state_file(self, tmp_path):
        fresh = self._make_adapter(tmp_path)
        assert fresh._bot_participated_threads == set()

    def test_track_thread_persists_to_disk(self, tmp_path):
        adapter = self._make_adapter(tmp_path)
        with self._home(tmp_path):
            for thread_id in ("111", "222"):
                adapter._track_thread(thread_id)

        state_file = tmp_path / "discord_threads.json"
        assert state_file.exists()
        assert set(json.loads(state_file.read_text())) == {"111", "222"}

    def test_threads_survive_restart(self, tmp_path):
        """Threads tracked by one adapter instance are visible to the next."""
        first = self._make_adapter(tmp_path)
        with self._home(tmp_path):
            first._track_thread("aaa")
            first._track_thread("bbb")

        second = self._make_adapter(tmp_path)
        assert "aaa" in second._bot_participated_threads
        assert "bbb" in second._bot_participated_threads

    def test_duplicate_track_does_not_double_save(self, tmp_path):
        adapter = self._make_adapter(tmp_path)
        with self._home(tmp_path):
            adapter._track_thread("111")
            adapter._track_thread("111")  # second call is a no-op

        persisted = json.loads((tmp_path / "discord_threads.json").read_text())
        assert persisted.count("111") == 1

    def test_caps_at_max_tracked_threads(self, tmp_path):
        adapter = self._make_adapter(tmp_path)
        # Lower the cap on this instance only, then track twice as many ids.
        adapter._MAX_TRACKED_THREADS = 5
        with self._home(tmp_path):
            for number in range(10):
                adapter._track_thread(str(number))

        assert len(adapter._bot_participated_threads) == 5

    def test_corrupted_state_file_falls_back_to_empty(self, tmp_path):
        (tmp_path / "discord_threads.json").write_text("not valid json{{{")
        adapter = self._make_adapter(tmp_path)
        assert adapter._bot_participated_threads == set()

    def test_missing_hermes_home_does_not_crash(self, tmp_path):
        """Loading tolerates a ``HERMES_HOME`` that does not exist."""
        missing_home = tmp_path / "nonexistent" / "deep"
        with self._home(missing_home):
            from gateway.platforms.discord import DiscordAdapter

            # Loading must yield an empty set rather than raise.
            loaded = DiscordAdapter._load_participated_threads()
            assert loaded == set()

View file

@ -0,0 +1,157 @@
"""
Tests for document cache utilities in gateway/platforms/base.py.
Covers: get_document_cache_dir, cache_document_from_bytes,
cleanup_document_cache, SUPPORTED_DOCUMENT_TYPES.
"""
import os
import time
from pathlib import Path
import pytest
from gateway.platforms.base import (
SUPPORTED_DOCUMENT_TYPES,
cache_document_from_bytes,
cleanup_document_cache,
get_document_cache_dir,
)
# ---------------------------------------------------------------------------
# Fixture: redirect DOCUMENT_CACHE_DIR to a temp directory for every test
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def _redirect_cache(tmp_path, monkeypatch):
    """Replace the module-level ``DOCUMENT_CACHE_DIR`` with ``tmp_path / "doc_cache"``.

    Applied automatically to every test in this module; monkeypatch restores
    the original value afterwards.
    """
    isolated_dir = tmp_path / "doc_cache"
    monkeypatch.setattr("gateway.platforms.base.DOCUMENT_CACHE_DIR", isolated_dir)
# ---------------------------------------------------------------------------
# TestGetDocumentCacheDir
# ---------------------------------------------------------------------------
class TestGetDocumentCacheDir:
    """``get_document_cache_dir`` returns a directory that exists on disk."""

    def test_creates_directory(self, tmp_path):
        created = get_document_cache_dir()
        assert created.exists()
        assert created.is_dir()

    def test_returns_existing_directory(self):
        # Calling twice must give the same path, and it must still exist.
        first_call, second_call = get_document_cache_dir(), get_document_cache_dir()
        assert first_call == second_call
        assert first_call.exists()
# ---------------------------------------------------------------------------
# TestCacheDocumentFromBytes
# ---------------------------------------------------------------------------
class TestCacheDocumentFromBytes:
    """``cache_document_from_bytes`` stores data under a sanitized, unique name."""

    def test_basic_caching(self):
        payload = b"hello world"
        stored = cache_document_from_bytes(payload, "test.txt")
        assert os.path.exists(stored)
        assert Path(stored).read_bytes() == payload

    def test_filename_preserved_in_path(self):
        stored = cache_document_from_bytes(b"data", "report.pdf")
        assert "report.pdf" in os.path.basename(stored)

    def test_empty_filename_uses_fallback(self):
        stored = cache_document_from_bytes(b"data", "")
        assert "document" in os.path.basename(stored)

    def test_unique_filenames(self):
        # Two documents with the same name must not share a cache path.
        first = cache_document_from_bytes(b"a", "same.txt")
        second = cache_document_from_bytes(b"b", "same.txt")
        assert first != second

    def test_path_traversal_blocked(self):
        """Directory components are stripped — only the leaf name survives."""
        stored = cache_document_from_bytes(b"data", "../../etc/passwd")
        leaf = os.path.basename(stored)
        assert "passwd" in leaf
        # No parent-directory marker may remain in the stored name.
        assert ".." not in leaf
        # The file must live inside the cache directory.
        cache_root = get_document_cache_dir().resolve()
        assert Path(stored).resolve().is_relative_to(cache_root)

    def test_null_bytes_stripped(self):
        stored = cache_document_from_bytes(b"data", "file\x00.pdf")
        leaf = os.path.basename(stored)
        assert "\x00" not in leaf
        assert "file.pdf" in leaf

    def test_dot_dot_filename_handled(self):
        """A filename that is literally '..' falls back to 'document'."""
        stored = cache_document_from_bytes(b"data", "..")
        assert "document" in os.path.basename(stored)

    def test_none_filename_uses_fallback(self):
        stored = cache_document_from_bytes(b"data", None)
        assert "document" in os.path.basename(stored)
# ---------------------------------------------------------------------------
# TestCleanupDocumentCache
# ---------------------------------------------------------------------------
class TestCleanupDocumentCache:
    """``cleanup_document_cache`` removes files older than the age limit and counts them."""

    def test_removes_old_files(self, tmp_path):
        stale = get_document_cache_dir() / "old.txt"
        stale.write_text("old")
        # Backdate the file to 48 hours ago, well beyond the 24 hour limit.
        two_days_ago = time.time() - 48 * 3600
        os.utime(stale, (two_days_ago, two_days_ago))

        removed_count = cleanup_document_cache(max_age_hours=24)

        assert removed_count == 1
        assert not stale.exists()

    def test_keeps_recent_files(self):
        fresh = get_document_cache_dir() / "recent.txt"
        fresh.write_text("fresh")

        removed_count = cleanup_document_cache(max_age_hours=24)

        assert removed_count == 0
        assert fresh.exists()

    def test_returns_removed_count(self):
        cache_root = get_document_cache_dir()
        two_days_ago = time.time() - 48 * 3600
        for index in range(3):
            stale = cache_root / f"old_{index}.txt"
            stale.write_text("x")
            os.utime(stale, (two_days_ago, two_days_ago))

        assert cleanup_document_cache(max_age_hours=24) == 3

    def test_empty_cache_dir(self):
        # Nothing has been cached, so nothing can be removed.
        assert cleanup_document_cache(max_age_hours=24) == 0
# ---------------------------------------------------------------------------
# TestSupportedDocumentTypes
# ---------------------------------------------------------------------------
class TestSupportedDocumentTypes:
    """Shape checks for the ``SUPPORTED_DOCUMENT_TYPES`` extension-to-MIME mapping."""

    def test_all_extensions_have_mime_types(self):
        for extension, mime_type in SUPPORTED_DOCUMENT_TYPES.items():
            assert extension.startswith("."), f"{extension} missing leading dot"
            assert "/" in mime_type, f"{mime_type} is not a valid MIME type"

    @pytest.mark.parametrize(
        "ext",
        [".pdf", ".md", ".txt", ".docx", ".xlsx", ".pptx"],
    )
    def test_expected_extensions_present(self, ext):
        assert ext in SUPPORTED_DOCUMENT_TYPES

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,317 @@
"""
Tests for extract_local_files() auto-detection of bare local file paths
in model response text for native media delivery.
Covers: path matching, code-block exclusion, URL rejection, tilde expansion,
deduplication, text cleanup, and extension routing.
Based on PR #1636 by sudoingX (salvaged + hardened).
"""
import os
from unittest.mock import patch
import pytest
from gateway.platforms.base import BasePlatformAdapter
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _extract(content: str, existing_files: set[str] | None = None):
    """Call ``extract_local_files`` on *content* with filesystem lookups faked.

    ``os.path.isfile`` is patched to report a path as present only if it is in
    *existing_files* (compared in expanded form); when *existing_files* is
    ``None`` every path counts as present. ``os.path.expanduser`` is patched
    so that a leading ``~/`` becomes ``/home/user/``.

    Returns whatever ``BasePlatformAdapter.extract_local_files`` returns.
    """
    def _isfile(candidate):
        return existing_files is None or candidate in existing_files

    def _expanduser(candidate):
        # Drop the "~" and keep the rest of the path, including its slash.
        return "/home/user" + candidate[1:] if candidate.startswith("~/") else candidate

    isfile_patch = patch("os.path.isfile", side_effect=_isfile)
    expanduser_patch = patch("os.path.expanduser", side_effect=_expanduser)
    with isfile_patch, expanduser_patch:
        return BasePlatformAdapter.extract_local_files(content)
# ---------------------------------------------------------------------------
# Basic detection
# ---------------------------------------------------------------------------
class TestBasicDetection:
    """Bare media paths in plain text are detected and removed from the text."""

    def test_absolute_path_image(self):
        found, remaining = _extract("Here is the screenshot /root/screenshots/game.png enjoy")
        assert found == ["/root/screenshots/game.png"]
        assert "/root/screenshots/game.png" not in remaining
        assert "Here is the screenshot" in remaining

    def test_tilde_path_image(self):
        found, remaining = _extract("Check out ~/photos/cat.jpg for the cat")
        # The returned path is the expanded form; the raw form is stripped.
        assert found == ["/home/user/photos/cat.jpg"]
        assert "~/photos/cat.jpg" not in remaining

    def test_video_extensions(self):
        for ext in (".mp4", ".mov", ".avi", ".mkv", ".webm"):
            found, _ = _extract(f"Video at /tmp/clip{ext} here")
            assert len(found) == 1, f"Failed for {ext}"
            assert found[0] == f"/tmp/clip{ext}"

    def test_image_extensions(self):
        for ext in (".png", ".jpg", ".jpeg", ".gif", ".webp"):
            found, _ = _extract(f"Image at /tmp/pic{ext} here")
            assert len(found) == 1, f"Failed for {ext}"
            assert found[0] == f"/tmp/pic{ext}"

    def test_case_insensitive_extension(self):
        found, _ = _extract("See /tmp/PHOTO.PNG and /tmp/vid.MP4 now")
        assert len(found) == 2

    def test_multiple_paths(self):
        found, remaining = _extract("First /tmp/a.png then /tmp/b.jpg and /tmp/c.mp4 done")
        assert len(found) == 3
        for expected in ("/tmp/a.png", "/tmp/b.jpg", "/tmp/c.mp4"):
            assert expected in found
        for path in found:
            assert path not in remaining

    def test_path_at_line_start(self):
        found, _ = _extract("/var/data/image.png")
        assert found == ["/var/data/image.png"]

    def test_path_at_end_of_line(self):
        found, _ = _extract("saved to /var/data/image.png")
        assert found == ["/var/data/image.png"]

    def test_path_with_dots_in_directory(self):
        found, _ = _extract("See /opt/my.app/assets/logo.png here")
        assert found == ["/opt/my.app/assets/logo.png"]

    def test_path_with_hyphens(self):
        found, _ = _extract("File at /tmp/my-screenshot-2024.png done")
        assert found == ["/tmp/my-screenshot-2024.png"]
# ---------------------------------------------------------------------------
# Non-existent files are skipped
# ---------------------------------------------------------------------------
class TestIsfileGuard:
    """Only paths that exist on disk are extracted; others stay in the text."""

    def test_nonexistent_path_skipped(self):
        """A path that does not exist is neither returned nor stripped."""
        nothing_exists = set()
        found, remaining = _extract("See /tmp/nope.png here", existing_files=nothing_exists)
        assert found == []
        assert "/tmp/nope.png" in remaining

    def test_only_existing_paths_extracted(self):
        """With a mix of real and missing paths, only the real one is returned."""
        found, remaining = _extract(
            "A /tmp/real.png and /tmp/fake.jpg end",
            existing_files={"/tmp/real.png"},
        )
        assert found == ["/tmp/real.png"]
        assert "/tmp/real.png" not in remaining
        assert "/tmp/fake.jpg" in remaining
# ---------------------------------------------------------------------------
# URL false-positive prevention
# ---------------------------------------------------------------------------
class TestURLRejection:
    """Path-like segments inside URLs must not be treated as local files."""

    def test_https_url_not_matched(self):
        """A path embedded in an HTTPS URL is not extracted and the URL is kept."""
        # _extract makes every path "exist" here, so a pass shows the match
        # itself was rejected rather than filtered out by the isfile check.
        found, remaining = _extract("Visit https://example.com/images/photo.png for details")
        assert found == []
        assert "https://example.com/images/photo.png" in remaining

    def test_http_url_not_matched(self):
        found, _ = _extract("See http://cdn.example.com/assets/banner.jpg here")
        assert found == []

    def test_file_url_not_matched(self):
        # file:///home/... also has "://" directly in front of the path.
        found, _ = _extract("Open file:///home/user/doc.png in browser")
        assert found == []
# ---------------------------------------------------------------------------
# Code block exclusion
# ---------------------------------------------------------------------------
class TestCodeBlockExclusion:
    """Paths inside fenced blocks or inline code spans are left untouched."""

    def test_fenced_code_block_skipped(self):
        response = "Here's how:\n```python\nimg = open('/tmp/image.png')\n```\nDone."
        found, remaining = _extract(response)
        assert found == []
        assert "/tmp/image.png" in remaining  # still present in the text

    def test_inline_code_skipped(self):
        response = "Use the path `/tmp/image.png` in your config"
        found, remaining = _extract(response)
        assert found == []
        assert "`/tmp/image.png`" in remaining

    def test_path_outside_code_block_still_matched(self):
        response = (
            "```\ncode: /tmp/inside.png\n```\n"
            "But this one is real: /tmp/outside.png"
        )
        found, _ = _extract(response, existing_files={"/tmp/outside.png"})
        assert found == ["/tmp/outside.png"]

    def test_mixed_inline_code_and_bare_path(self):
        response = "Config uses `/etc/app/bg.png` but output is /tmp/result.jpg"
        found, remaining = _extract(response, existing_files={"/tmp/result.jpg"})
        assert found == ["/tmp/result.jpg"]
        assert "`/etc/app/bg.png`" in remaining
        assert "/tmp/result.jpg" not in remaining

    def test_multiline_fenced_block(self):
        response = (
            "```bash\n"
            "cp /source/a.png /dest/b.png\n"
            "mv /source/c.mp4 /dest/d.mp4\n"
            "```\n"
            "Files are ready."
        )
        found, _ = _extract(response)
        assert found == []
# ---------------------------------------------------------------------------
# Deduplication
# ---------------------------------------------------------------------------
class TestDeduplication:
    """A file mentioned more than once is reported a single time."""

    def test_duplicate_paths_deduplicated(self):
        """The same literal path appearing twice is returned once."""
        reply = "See /tmp/img.png and also /tmp/img.png again"
        found, _remaining = _extract(reply)
        assert found == ["/tmp/img.png"]

    def test_tilde_and_expanded_same_file(self):
        """~/photos/a.png and /home/user/photos/a.png are the same file."""
        reply = "See ~/photos/a.png and /home/user/photos/a.png here"
        on_disk = {"/home/user/photos/a.png"}
        found, _remaining = _extract(reply, existing_files=on_disk)
        assert len(found) == 1
        assert found[0] == "/home/user/photos/a.png"
# ---------------------------------------------------------------------------
# Text cleanup
# ---------------------------------------------------------------------------
class TestTextCleanup:
    """Checks on the cleaned text returned next to the extracted paths."""

    def test_path_removed_from_text(self):
        """The path disappears while the words around it survive."""
        _found, remaining = _extract("Before /tmp/x.png after")
        assert "Before" in remaining
        assert "after" in remaining
        assert "/tmp/x.png" not in remaining

    def test_excessive_blank_lines_collapsed(self):
        """No run of three newlines is left after the path is removed."""
        reply = "Before\n\n\n/tmp/x.png\n\n\nAfter"
        _found, remaining = _extract(reply)
        assert "\n\n\n" not in remaining

    def test_no_paths_text_unchanged(self):
        """Text without any path comes back exactly as given."""
        reply = "This is a normal response with no file paths."
        found, remaining = _extract(reply)
        assert found == []
        assert remaining == reply

    def test_tilde_form_cleaned_from_text(self):
        """The raw ~/... form should be removed, not the expanded /home/user/... form."""
        reply = "Output saved to ~/result.png for review"
        found, remaining = _extract(reply)
        assert found == ["/home/user/result.png"]
        assert "~/result.png" not in remaining

    def test_only_path_in_text(self):
        """If the response is just a path, cleaned text is empty."""
        found, remaining = _extract("/tmp/screenshot.png")
        assert found == ["/tmp/screenshot.png"]
        assert remaining == ""
# ---------------------------------------------------------------------------
# Edge cases
# ---------------------------------------------------------------------------
class TestEdgeCases:
    """Boundary inputs: empty text, non-media files, odd path shapes, punctuation."""

    def test_empty_string(self):
        """Empty input gives no paths and empty cleaned text."""
        found, remaining = _extract("")
        assert found == []
        assert remaining == ""

    def test_no_media_extensions(self):
        """Non-media extensions should not be matched."""
        reply = "See /tmp/data.csv and /tmp/script.py and /tmp/notes.txt"
        found, _remaining = _extract(reply)
        assert found == []

    def test_path_with_spaces_not_matched(self):
        """Paths with spaces are intentionally not matched (avoids false positives)."""
        found, _remaining = _extract("File at /tmp/my file.png here")
        assert found == []

    def test_windows_path_not_matched(self):
        """Windows-style paths should not match."""
        found, _remaining = _extract("See C:\\Users\\test\\image.png")
        assert found == []

    def test_relative_path_not_matched(self):
        """Relative paths like ./image.png should not match."""
        found, _remaining = _extract("File at ./screenshots/image.png here")
        assert found == []

    def test_bare_filename_not_matched(self):
        """Just 'image.png' without a path should not match."""
        found, _remaining = _extract("Open image.png to see")
        assert found == []

    def test_path_followed_by_punctuation(self):
        """Path followed by comma, period, paren should still match."""
        trailing_marks = (",", ".", ")", ":", ";")
        for suffix in trailing_marks:
            found, _remaining = _extract(f"See /tmp/img.png{suffix} details")
            assert len(found) == 1, f"Failed with suffix '{suffix}'"

    def test_path_in_parentheses(self):
        """A parenthesised path is returned without the closing paren."""
        found, _remaining = _extract("(see /tmp/img.png)")
        assert found == ["/tmp/img.png"]

    def test_path_in_quotes(self):
        """A double-quoted path is returned without the quotes."""
        found, _remaining = _extract('The file is "/tmp/img.png" right here')
        assert found == ["/tmp/img.png"]

    def test_deep_nested_path(self):
        """A path with many directory levels is returned whole."""
        found, _remaining = _extract("At /a/b/c/d/e/f/g/h/image.png end")
        assert found == ["/a/b/c/d/e/f/g/h/image.png"]
if __name__ == "__main__":
    # pytest.main() returns the run's exit code instead of exiting itself.
    # Pass it to SystemExit so that running this file directly ends with a
    # non-zero status when a test fails.
    raise SystemExit(pytest.main([__file__, "-v"]))

View file

@ -0,0 +1,167 @@
"""Tests for memory flush stale-overwrite prevention (#2670).
Verifies that:
1. Cron sessions are skipped (no flush for headless cron runs)
2. Current memory state is injected into the flush prompt so the
flush agent can see what's already saved and avoid overwrites
3. The flush still works normally when memory files don't exist
"""
import pytest
from pathlib import Path
from unittest.mock import MagicMock, patch, call
def _make_runner():
    """Return a GatewayRunner created without ``__init__``, with stubbed attributes."""
    from gateway.run import GatewayRunner

    stub = object.__new__(GatewayRunner)
    # Each of these attributes gets its own fresh, empty dict.
    for attr_name in (
        "_honcho_managers",
        "_honcho_configs",
        "_running_agents",
        "_pending_messages",
        "_pending_approvals",
        "adapters",
    ):
        setattr(stub, attr_name, {})
    stub.hooks = MagicMock()
    stub.session_store = MagicMock()
    return stub
# Four-message transcript (user / assistant alternating) reused by the tests.
_TRANSCRIPT_4_MSGS = [
    {"role": speaker, "content": words}
    for speaker, words in (
        ("user", "hello"),
        ("assistant", "hi there"),
        ("user", "remember my name is Alice"),
        ("assistant", "Got it, Alice!"),
    )
]
class TestCronSessionBypass:
    """Cron sessions should never trigger a memory flush."""

    def test_cron_session_skipped(self):
        """A cron_* session id never reaches the transcript loader."""
        gateway = _make_runner()
        load_transcript = gateway.session_store.load_transcript
        gateway._flush_memories_for_session("cron_job123_20260323_120000")
        # session_store.load_transcript must stay untouched
        load_transcript.assert_not_called()

    def test_cron_session_with_honcho_key_skipped(self):
        """Passing a second argument does not change the cron bypass."""
        gateway = _make_runner()
        load_transcript = gateway.session_store.load_transcript
        gateway._flush_memories_for_session("cron_daily_20260323", "some-honcho-key")
        load_transcript.assert_not_called()

    def test_non_cron_session_proceeds(self):
        """Non-cron sessions should still attempt the flush."""
        gateway = _make_runner()
        load_transcript = gateway.session_store.load_transcript
        load_transcript.return_value = []
        gateway._flush_memories_for_session("session_abc123")
        load_transcript.assert_called_once_with("session_abc123")
class TestMemoryInjection:
    """The flush prompt should include current memory state from disk."""

    @staticmethod
    def _flush_with_memory_dir(memory_dir, session_id):
        """Run one flush with MEMORY_DIR set to *memory_dir*; return the stub agent."""
        gateway = _make_runner()
        gateway.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS
        stub_agent = MagicMock()
        # Stand-in module that intercepts `from tools.memory_tool import MEMORY_DIR`
        # inside the function under test.
        fake_memory_tool = MagicMock(MEMORY_DIR=memory_dir)
        with (
            patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}),
            patch("gateway.run._resolve_gateway_model", return_value="test-model"),
            patch("run_agent.AIAgent", return_value=stub_agent),
            patch.dict("sys.modules", {"tools.memory_tool": fake_memory_tool}),
        ):
            gateway._flush_memories_for_session(session_id)
        return stub_agent

    def test_memory_content_injected_into_flush_prompt(self, tmp_path):
        """When memory files exist, their content appears in the flush prompt."""
        memories = tmp_path / "memories"
        memories.mkdir()
        (memories / "MEMORY.md").write_text("Agent knows Python\n§\nUser prefers dark mode")
        (memories / "USER.md").write_text("Name: Alice\n§\nTimezone: PST")

        stub_agent = self._flush_with_memory_dir(memories, "session_123")

        stub_agent.run_conversation.assert_called_once()
        sent_kwargs = stub_agent.run_conversation.call_args.kwargs
        prompt = sent_kwargs.get("user_message", "")
        # Entries from both memory files must be present in the prompt
        assert "Agent knows Python" in prompt
        assert "User prefers dark mode" in prompt
        assert "Name: Alice" in prompt
        assert "Timezone: PST" in prompt
        # ...together with the stale-overwrite warning
        assert "Do NOT overwrite or remove entries" in prompt
        assert "current live state of memory" in prompt

    def test_flush_works_without_memory_files(self, tmp_path):
        """When no memory files exist, flush still runs without the guard."""
        vacant = tmp_path / "no_memories"
        vacant.mkdir()

        stub_agent = self._flush_with_memory_dir(vacant, "session_456")

        # The agent is still invoked, only the memory guard section is absent
        stub_agent.run_conversation.assert_called_once()
        prompt = stub_agent.run_conversation.call_args.kwargs.get("user_message", "")
        assert "Do NOT overwrite or remove entries" not in prompt
        assert "Review the conversation above" in prompt

    def test_empty_memory_files_no_injection(self, tmp_path):
        """Empty memory files should not trigger the guard section."""
        memories = tmp_path / "memories"
        memories.mkdir()
        (memories / "MEMORY.md").write_text("")
        (memories / "USER.md").write_text(" \n ")  # nothing but whitespace

        stub_agent = self._flush_with_memory_dir(memories, "session_789")

        stub_agent.run_conversation.assert_called_once()
        prompt = stub_agent.run_conversation.call_args.kwargs.get("user_message", "")
        # No memory content, hence no guard section
        assert "current live state of memory" not in prompt
class TestFlushPromptStructure:
    """Verify the flush prompt retains its core instructions."""

    def test_core_instructions_present(self):
        """The flush prompt should still contain the original guidance."""
        runner = _make_runner()
        runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS
        tmp_agent = MagicMock()
        with (
            patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}),
            patch("gateway.run._resolve_gateway_model", return_value="test-model"),
            patch("run_agent.AIAgent", return_value=tmp_agent),
            # The stand-in module imports normally; its MEMORY_DIR just points
            # at a path that does not exist, so we test without memory files.
            patch.dict("sys.modules", {"tools.memory_tool": MagicMock(MEMORY_DIR=Path("/nonexistent"))}),
        ):
            runner._flush_memories_for_session("session_struct")

        # Fail with a clear message if the agent never ran. Without this check
        # call_args is None and the lookup below raises an unrelated
        # AttributeError.
        tmp_agent.run_conversation.assert_called_once()
        flush_prompt = tmp_agent.run_conversation.call_args.kwargs.get("user_message", "")
        assert "automatically reset" in flush_prompt
        assert "Save any important facts" in flush_prompt
        assert "consider saving it as a skill" in flush_prompt
        assert "Do NOT respond to the user" in flush_prompt

View file

@ -0,0 +1,106 @@
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from gateway.config import GatewayConfig, Platform, PlatformConfig
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, SendResult
from gateway.run import GatewayRunner
from gateway.session import SessionSource, build_session_key
class StubAdapter(BasePlatformAdapter):
    """Adapter registered as Telegram whose transport methods return fixed values."""

    def __init__(self):
        stub_config = PlatformConfig(enabled=True, token="***")
        super().__init__(stub_config, Platform.TELEGRAM)

    async def connect(self):
        """Always report success."""
        return True

    async def disconnect(self):
        """Nothing to tear down."""
        return None

    async def send(self, chat_id, content, reply_to=None, metadata=None):
        """Ignore the arguments and return a successful SendResult with id "1"."""
        outcome = SendResult(success=True, message_id="1")
        return outcome

    async def send_typing(self, chat_id, metadata=None):
        """No typing indicator is emitted."""
        return None

    async def get_chat_info(self, chat_id):
        """Return a dict holding only the chat id."""
        info = {"id": chat_id}
        return info
def _source(chat_id="123456", chat_type="dm"):
    """Build a Telegram SessionSource for the given chat id and chat type."""
    fields = {
        "platform": Platform.TELEGRAM,
        "chat_id": chat_id,
        "chat_type": chat_type,
    }
    return SessionSource(**fields)
@pytest.mark.asyncio
async def test_cancel_background_tasks_cancels_inflight_message_processing():
    """cancel_background_tasks() leaves no tasks, active sessions or pending messages."""
    stub = StubAdapter()
    never_set = asyncio.Event()

    async def block_forever(_event):
        # Waits on an event that this test never sets.
        await never_set.wait()
        return None

    stub.set_message_handler(block_forever)
    incoming = MessageEvent(text="work", source=_source(), message_id="1")
    await stub.handle_message(incoming)
    await asyncio.sleep(0)

    key = build_session_key(incoming.source)
    assert key in stub._active_sessions
    assert stub._background_tasks

    await stub.cancel_background_tasks()

    assert stub._background_tasks == set()
    assert stub._active_sessions == {}
    assert stub._pending_messages == {}
@pytest.mark.asyncio
async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks():
    """stop() interrupts the running agent, disconnects the adapter and empties runner state."""
    telegram_cfg = PlatformConfig(enabled=True, token="***")
    gateway = object.__new__(GatewayRunner)
    gateway.config = GatewayConfig(platforms={Platform.TELEGRAM: telegram_cfg})
    gateway._running = True
    gateway._shutdown_event = asyncio.Event()
    gateway._exit_reason = None
    gateway._pending_messages = {"session": "pending text"}
    gateway._pending_approvals = {"session": {"command": "rm -rf /tmp/x"}}
    gateway._shutdown_all_gateway_honcho = lambda: None

    stub = StubAdapter()
    never_set = asyncio.Event()

    async def block_forever(_event):
        # Waits on an event that this test never sets.
        await never_set.wait()
        return None

    stub.set_message_handler(block_forever)
    incoming = MessageEvent(text="work", source=_source(), message_id="1")
    await stub.handle_message(incoming)
    await asyncio.sleep(0)

    # Replace disconnect() with a mock so the await can be asserted afterwards.
    fake_disconnect = AsyncMock()
    stub.disconnect = fake_disconnect

    key = build_session_key(incoming.source)
    agent = MagicMock()
    gateway._running_agents = {key: agent}
    gateway.adapters = {Platform.TELEGRAM: stub}

    with patch("gateway.status.remove_pid_file"), patch("gateway.status.write_runtime_status"):
        await gateway.stop()

    agent.interrupt.assert_called_once_with("Gateway shutting down")
    fake_disconnect.assert_awaited_once()
    assert gateway.adapters == {}
    assert gateway._running_agents == {}
    assert gateway._pending_messages == {}
    assert gateway._pending_approvals == {}
    assert gateway._shutdown_event.is_set() is True

View file

@ -0,0 +1,622 @@
"""Tests for the Home Assistant gateway adapter.
Tests real logic: state change formatting, event filtering pipeline,
cooldown behavior, config integration, and adapter initialization.
"""
import time
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from gateway.config import (
GatewayConfig,
Platform,
PlatformConfig,
)
from gateway.platforms.homeassistant import (
HomeAssistantAdapter,
check_ha_requirements,
)
# ---------------------------------------------------------------------------
# check_ha_requirements
# ---------------------------------------------------------------------------
class TestCheckRequirements:
    """check_ha_requirements() across HASS_TOKEN and AIOHTTP_AVAILABLE combinations."""

    def test_returns_false_without_token(self, monkeypatch):
        """With HASS_TOKEN removed from the environment the check is False."""
        monkeypatch.delenv("HASS_TOKEN", raising=False)
        assert check_ha_requirements() is False

    # Pin the flag to True so the outcome does not depend on whether aiohttp
    # is installed in the environment running the tests.
    @patch("gateway.platforms.homeassistant.AIOHTTP_AVAILABLE", True)
    def test_returns_true_with_token(self, monkeypatch):
        """With a token set and AIOHTTP_AVAILABLE True the check is True."""
        monkeypatch.setenv("HASS_TOKEN", "test-token")
        assert check_ha_requirements() is True

    @patch("gateway.platforms.homeassistant.AIOHTTP_AVAILABLE", False)
    def test_returns_false_without_aiohttp(self, monkeypatch):
        """A token alone is not enough when AIOHTTP_AVAILABLE is False."""
        monkeypatch.setenv("HASS_TOKEN", "test-token")
        assert check_ha_requirements() is False
# ---------------------------------------------------------------------------
# _format_state_change - pure function, all domain branches
# ---------------------------------------------------------------------------
class TestFormatStateChange:
    """Exercise HomeAssistantAdapter._format_state_change for a range of entity ids."""

    @staticmethod
    def fmt(entity_id, old_state, new_state):
        """Forward the three arguments to the adapter's formatter."""
        formatter = HomeAssistantAdapter._format_state_change
        return formatter(entity_id, old_state, new_state)

    def test_climate_includes_temperatures(self):
        before = {"state": "off"}
        after = {
            "state": "heat",
            "attributes": {
                "friendly_name": "Main Thermostat",
                "current_temperature": 21.5,
                "temperature": 23,
            },
        }
        text = self.fmt("climate.thermostat", before, after)
        assert "Main Thermostat" in text
        assert "'off'" in text
        assert "'heat'" in text
        assert "21.5" in text
        assert "23" in text

    def test_sensor_includes_unit(self):
        before = {"state": "22.5"}
        after = {
            "state": "25.1",
            "attributes": {
                "friendly_name": "Living Room Temp",
                "unit_of_measurement": "C",
            },
        }
        text = self.fmt("sensor.temperature", before, after)
        assert "22.5C" in text
        assert "25.1C" in text
        assert "Living Room Temp" in text

    def test_sensor_without_unit(self):
        after = {"state": "10", "attributes": {"friendly_name": "Counter"}}
        text = self.fmt("sensor.count", {"state": "5"}, after)
        assert "5" in text
        assert "10" in text

    def test_binary_sensor_on(self):
        after = {"state": "on", "attributes": {"friendly_name": "Hallway Motion"}}
        text = self.fmt("binary_sensor.motion", {"state": "off"}, after)
        assert "triggered" in text
        assert "Hallway Motion" in text

    def test_binary_sensor_off(self):
        after = {"state": "off", "attributes": {"friendly_name": "Front Door"}}
        text = self.fmt("binary_sensor.door", {"state": "on"}, after)
        assert "cleared" in text

    def test_light_turned_on(self):
        after = {"state": "on", "attributes": {"friendly_name": "Bedroom Light"}}
        text = self.fmt("light.bedroom", {"state": "off"}, after)
        assert "turned on" in text

    def test_switch_turned_off(self):
        after = {"state": "off", "attributes": {"friendly_name": "Heater"}}
        text = self.fmt("switch.heater", {"state": "on"}, after)
        assert "turned off" in text

    def test_fan_domain_uses_light_switch_branch(self):
        after = {"state": "on", "attributes": {"friendly_name": "Ceiling Fan"}}
        text = self.fmt("fan.ceiling", {"state": "off"}, after)
        assert "turned on" in text

    def test_alarm_panel(self):
        after = {"state": "armed_away", "attributes": {"friendly_name": "Home Alarm"}}
        text = self.fmt("alarm_control_panel.home", {"state": "disarmed"}, after)
        assert "Home Alarm" in text
        assert "armed_away" in text
        assert "disarmed" in text

    def test_generic_domain_includes_entity_id(self):
        after = {"state": "on", "attributes": {"friendly_name": "Morning Routine"}}
        text = self.fmt("automation.morning", {"state": "off"}, after)
        assert "automation.morning" in text
        assert "Morning Routine" in text

    def test_same_state_returns_none(self):
        after = {"state": "22", "attributes": {"friendly_name": "Temp"}}
        outcome = self.fmt("sensor.temp", {"state": "22"}, after)
        assert outcome is None

    def test_empty_new_state_returns_none(self):
        outcome = self.fmt("light.x", {"state": "on"}, {})
        assert outcome is None

    def test_no_old_state_uses_unknown(self):
        after = {"state": "on", "attributes": {"friendly_name": "New Light"}}
        text = self.fmt("light.new", None, after)
        assert text is not None
        assert "New Light" in text

    def test_uses_entity_id_when_no_friendly_name(self):
        after = {"state": "2", "attributes": {}}
        text = self.fmt("sensor.unnamed", {"state": "1"}, after)
        assert "sensor.unnamed" in text
# ---------------------------------------------------------------------------
# Adapter initialization from config
# ---------------------------------------------------------------------------
class TestAdapterInit:
    """Attributes HomeAssistantAdapter derives from PlatformConfig and the environment."""

    def test_url_and_token_from_config_extra(self, monkeypatch):
        """With the HASS_* variables unset, token and URL come from the config."""
        for env_name in ("HASS_URL", "HASS_TOKEN"):
            monkeypatch.delenv(env_name, raising=False)
        cfg = PlatformConfig(
            enabled=True,
            token="config-token",
            extra={"url": "http://192.168.1.50:8123"},
        )
        ha = HomeAssistantAdapter(cfg)
        assert ha._hass_token == "config-token"
        assert ha._hass_url == "http://192.168.1.50:8123"

    def test_url_fallback_to_env(self, monkeypatch):
        """Without a url in extra, HASS_URL from the environment is used."""
        monkeypatch.setenv("HASS_URL", "http://env-host:8123")
        monkeypatch.setenv("HASS_TOKEN", "env-tok")
        ha = HomeAssistantAdapter(PlatformConfig(enabled=True, token="env-tok"))
        assert ha._hass_url == "http://env-host:8123"

    def test_trailing_slash_stripped(self):
        """A trailing slash on the configured URL is removed."""
        cfg = PlatformConfig(
            enabled=True,
            token="t",
            extra={"url": "http://ha.local:8123/"},
        )
        ha = HomeAssistantAdapter(cfg)
        assert ha._hass_url == "http://ha.local:8123"

    def test_watch_filters_parsed(self):
        """List-valued filters become sets; the cooldown is stored as given."""
        filters = {
            "watch_domains": ["climate", "binary_sensor"],
            "watch_entities": ["sensor.special"],
            "ignore_entities": ["sensor.uptime", "sensor.cpu"],
            "cooldown_seconds": 120,
        }
        ha = HomeAssistantAdapter(PlatformConfig(enabled=True, token="***", extra=filters))
        assert ha._watch_domains == {"climate", "binary_sensor"}
        assert ha._watch_entities == {"sensor.special"}
        assert ha._ignore_entities == {"sensor.uptime", "sensor.cpu"}
        assert ha._watch_all is False
        assert ha._cooldown_seconds == 120

    def test_watch_all_parsed(self):
        """watch_all=True in extra is reflected on the adapter."""
        cfg = PlatformConfig(enabled=True, token="***", extra={"watch_all": True})
        ha = HomeAssistantAdapter(cfg)
        assert ha._watch_all is True

    def test_defaults_when_no_extra(self, monkeypatch):
        """No extra: empty filter sets, watch_all False, cooldown of 30."""
        monkeypatch.setenv("HASS_TOKEN", "tok")
        ha = HomeAssistantAdapter(PlatformConfig(enabled=True, token="***"))
        assert ha._watch_domains == set()
        assert ha._watch_entities == set()
        assert ha._ignore_entities == set()
        assert ha._watch_all is False
        assert ha._cooldown_seconds == 30
# ---------------------------------------------------------------------------
# Event filtering pipeline (_handle_ha_event)
#
# We mock handle_message (not our code, it's the base class pipeline) to
# capture the MessageEvent that _handle_ha_event produces.
# ---------------------------------------------------------------------------
def _make_adapter(**extra) -> HomeAssistantAdapter:
    """Create an adapter (token "tok", *extra* as config extras) with handle_message mocked."""
    ha = HomeAssistantAdapter(PlatformConfig(enabled=True, token="tok", extra=extra))
    # Swap in an AsyncMock so tests can inspect what would have been forwarded.
    ha.handle_message = AsyncMock()
    return ha
def _make_event(entity_id, old_state, new_state, old_attrs=None, new_attrs=None):
    """Build a state-change event dict.

    A falsy *old_attrs* becomes ``{}``; a falsy *new_attrs* becomes
    ``{"friendly_name": entity_id}``.
    """
    previous = {"state": old_state, "attributes": old_attrs or {}}
    current = {
        "state": new_state,
        "attributes": new_attrs or {"friendly_name": entity_id},
    }
    payload = {"entity_id": entity_id, "old_state": previous, "new_state": current}
    return {"data": payload}
class TestEventFilteringPipeline:
    """Which events _handle_ha_event forwards to handle_message, and with what content."""

    @pytest.mark.asyncio
    async def test_ignored_entity_not_forwarded(self):
        ha = _make_adapter(watch_all=True, ignore_entities=["sensor.uptime"])
        incoming = _make_event("sensor.uptime", "100", "101")
        await ha._handle_ha_event(incoming)
        ha.handle_message.assert_not_called()

    @pytest.mark.asyncio
    async def test_unwatched_domain_not_forwarded(self):
        ha = _make_adapter(watch_domains=["climate"])
        incoming = _make_event("light.bedroom", "off", "on")
        await ha._handle_ha_event(incoming)
        ha.handle_message.assert_not_called()

    @pytest.mark.asyncio
    async def test_watched_domain_forwarded(self):
        ha = _make_adapter(watch_domains=["climate"], cooldown_seconds=0)
        incoming = _make_event(
            "climate.thermostat",
            "off",
            "heat",
            new_attrs={"friendly_name": "Thermostat", "current_temperature": 20, "temperature": 22},
        )
        await ha._handle_ha_event(incoming)
        ha.handle_message.assert_called_once()
        # Inspect the MessageEvent that was handed to handle_message
        forwarded = ha.handle_message.call_args[0][0]
        assert "Thermostat" in forwarded.text
        assert "heat" in forwarded.text
        assert forwarded.source.platform == Platform.HOMEASSISTANT
        assert forwarded.source.chat_id == "ha_events"

    @pytest.mark.asyncio
    async def test_watched_entity_forwarded(self):
        ha = _make_adapter(watch_entities=["sensor.important"], cooldown_seconds=0)
        incoming = _make_event(
            "sensor.important",
            "10",
            "20",
            new_attrs={"friendly_name": "Important Sensor", "unit_of_measurement": "W"},
        )
        await ha._handle_ha_event(incoming)
        ha.handle_message.assert_called_once()
        forwarded = ha.handle_message.call_args[0][0]
        assert "10W" in forwarded.text
        assert "20W" in forwarded.text

    @pytest.mark.asyncio
    async def test_no_filters_blocks_everything(self):
        """Without watch_domains, watch_entities, or watch_all, events are dropped."""
        ha = _make_adapter(cooldown_seconds=0)
        incoming = _make_event("cover.blinds", "closed", "open")
        await ha._handle_ha_event(incoming)
        ha.handle_message.assert_not_called()

    @pytest.mark.asyncio
    async def test_watch_all_passes_everything(self):
        """With watch_all=True and no specific filters, all events pass through."""
        ha = _make_adapter(watch_all=True, cooldown_seconds=0)
        incoming = _make_event("cover.blinds", "closed", "open")
        await ha._handle_ha_event(incoming)
        ha.handle_message.assert_called_once()

    @pytest.mark.asyncio
    async def test_same_state_not_forwarded(self):
        ha = _make_adapter(watch_all=True, cooldown_seconds=0)
        incoming = _make_event("light.x", "on", "on")
        await ha._handle_ha_event(incoming)
        ha.handle_message.assert_not_called()

    @pytest.mark.asyncio
    async def test_empty_entity_id_skipped(self):
        ha = _make_adapter(watch_all=True)
        incoming = {"data": {"entity_id": ""}}
        await ha._handle_ha_event(incoming)
        ha.handle_message.assert_not_called()

    @pytest.mark.asyncio
    async def test_message_event_has_correct_source(self):
        ha = _make_adapter(watch_all=True, cooldown_seconds=0)
        incoming = _make_event(
            "light.test",
            "off",
            "on",
            new_attrs={"friendly_name": "Test Light"},
        )
        await ha._handle_ha_event(incoming)
        forwarded = ha.handle_message.call_args[0][0]
        assert forwarded.source.user_name == "Home Assistant"
        assert forwarded.source.chat_type == "channel"
        assert forwarded.message_id.startswith("ha_light.test_")
# ---------------------------------------------------------------------------
# Cooldown behavior
# ---------------------------------------------------------------------------
class TestCooldown:
    """Per-entity cooldown applied by _handle_ha_event."""

    @pytest.mark.asyncio
    async def test_cooldown_blocks_rapid_events(self):
        ha = _make_adapter(watch_all=True, cooldown_seconds=60)
        attrs = {"friendly_name": "Temp"}

        first = _make_event("sensor.temp", "20", "21", new_attrs=attrs)
        await ha._handle_ha_event(first)
        assert ha.handle_message.call_count == 1

        # An immediate follow-up for the same entity is not forwarded
        second = _make_event("sensor.temp", "21", "22", new_attrs={"friendly_name": "Temp"})
        await ha._handle_ha_event(second)
        assert ha.handle_message.call_count == 1  # unchanged

    @pytest.mark.asyncio
    async def test_cooldown_expires(self):
        ha = _make_adapter(watch_all=True, cooldown_seconds=1)

        first = _make_event("sensor.temp", "20", "21", new_attrs={"friendly_name": "Temp"})
        await ha._handle_ha_event(first)
        assert ha.handle_message.call_count == 1

        # Backdate the recorded timestamp past the cooldown instead of sleeping
        ha._last_event_time["sensor.temp"] = time.time() - 2

        second = _make_event("sensor.temp", "21", "22", new_attrs={"friendly_name": "Temp"})
        await ha._handle_ha_event(second)
        assert ha.handle_message.call_count == 2

    @pytest.mark.asyncio
    async def test_different_entities_independent_cooldowns(self):
        ha = _make_adapter(watch_all=True, cooldown_seconds=60)

        event_a = _make_event("sensor.a", "1", "2", new_attrs={"friendly_name": "A"})
        event_b = _make_event("sensor.b", "3", "4", new_attrs={"friendly_name": "B"})
        await ha._handle_ha_event(event_a)
        await ha._handle_ha_event(event_b)
        # Two distinct entities, so both are forwarded
        assert ha.handle_message.call_count == 2

        # A repeat for sensor.a is not forwarded
        repeat_a = _make_event("sensor.a", "2", "3", new_attrs={"friendly_name": "A"})
        await ha._handle_ha_event(repeat_a)
        assert ha.handle_message.call_count == 2  # unchanged

    @pytest.mark.asyncio
    async def test_zero_cooldown_passes_all(self):
        ha = _make_adapter(watch_all=True, cooldown_seconds=0)
        for step in range(5):
            change = _make_event(
                "sensor.temp",
                str(step),
                str(step + 1),
                new_attrs={"friendly_name": "Temp"},
            )
            await ha._handle_ha_event(change)
        assert ha.handle_message.call_count == 5
# ---------------------------------------------------------------------------
# Config integration (env overrides, round-trip)
# ---------------------------------------------------------------------------
class TestConfigIntegration:
    """GatewayConfig handling of the Home Assistant platform entry."""

    def test_env_override_creates_ha_platform(self, monkeypatch):
        """HASS_TOKEN / HASS_URL in the environment produce an enabled HA platform."""
        monkeypatch.setenv("HASS_TOKEN", "env-token")
        monkeypatch.setenv("HASS_URL", "http://10.0.0.5:8123")
        # Remove the tokens of the other platforms
        for env_name in ("TELEGRAM_BOT_TOKEN", "DISCORD_BOT_TOKEN", "SLACK_BOT_TOKEN"):
            monkeypatch.delenv(env_name, raising=False)

        from gateway.config import load_gateway_config

        loaded = load_gateway_config()
        assert Platform.HOMEASSISTANT in loaded.platforms
        ha_cfg = loaded.platforms[Platform.HOMEASSISTANT]
        assert ha_cfg.enabled is True
        assert ha_cfg.token == "env-token"
        assert ha_cfg.extra["url"] == "http://10.0.0.5:8123"

    def test_no_env_no_platform(self, monkeypatch):
        """With none of the variables set, no HA platform is present."""
        unset = (
            "HASS_TOKEN",
            "HASS_URL",
            "TELEGRAM_BOT_TOKEN",
            "DISCORD_BOT_TOKEN",
            "SLACK_BOT_TOKEN",
        )
        for env_name in unset:
            monkeypatch.delenv(env_name, raising=False)

        from gateway.config import load_gateway_config

        loaded = load_gateway_config()
        assert Platform.HOMEASSISTANT not in loaded.platforms

    def test_config_roundtrip_preserves_extra(self):
        """to_dict() followed by from_dict() keeps enabled, token and extra values."""
        ha_original = PlatformConfig(
            enabled=True,
            token="tok",
            extra={
                "url": "http://ha:8123",
                "watch_domains": ["climate"],
                "cooldown_seconds": 45,
            },
        )
        original = GatewayConfig(platforms={Platform.HOMEASSISTANT: ha_original})

        as_dict = original.to_dict()
        rebuilt = GatewayConfig.from_dict(as_dict)

        ha_cfg = rebuilt.platforms[Platform.HOMEASSISTANT]
        assert ha_cfg.enabled is True
        assert ha_cfg.token == "tok"
        assert ha_cfg.extra["watch_domains"] == ["climate"]
        assert ha_cfg.extra["cooldown_seconds"] == 45

    def test_connected_platforms_includes_ha(self):
        """get_connected_platforms() lists the enabled HA entry, not the disabled Telegram one."""
        platforms = {
            Platform.HOMEASSISTANT: PlatformConfig(enabled=True, token="tok"),
            Platform.TELEGRAM: PlatformConfig(enabled=False, token="t"),
        }
        connected = GatewayConfig(platforms=platforms).get_connected_platforms()
        assert Platform.HOMEASSISTANT in connected
        assert Platform.TELEGRAM not in connected
# ---------------------------------------------------------------------------
# send() via REST API
# ---------------------------------------------------------------------------
class TestSendViaRestApi:
    """send() uses REST API (not WebSocket) to avoid race conditions."""

    @staticmethod
    def _mock_aiohttp_session(response_status=200, response_text="OK"):
        """Return a mock session whose ``post()`` yields a mock response.

        ``aiohttp.ClientSession()`` is constructed synchronously and then used
        as ``async with session:``; ``session.post(...)`` hands back a context
        manager rather than a coroutine. Both objects are therefore plain
        MagicMocks, with AsyncMock used only for ``__aenter__`` / ``__aexit__``
        (and for the awaitable ``response.text``).
        """
        response = MagicMock()
        response.status = response_status
        response.text = AsyncMock(return_value=response_text)
        response.__aenter__ = AsyncMock(return_value=response)
        response.__aexit__ = AsyncMock(return_value=False)

        session = MagicMock()
        session.post = MagicMock(return_value=response)
        session.__aenter__ = AsyncMock(return_value=session)
        session.__aexit__ = AsyncMock(return_value=False)
        return session

    @staticmethod
    async def _send_with_session(ha, session, content):
        """Call ``ha.send("ha_events", content)`` with the module's aiohttp patched."""
        with patch("gateway.platforms.homeassistant.aiohttp") as mock_aiohttp:
            mock_aiohttp.ClientSession = MagicMock(return_value=session)
            mock_aiohttp.ClientTimeout = lambda total: total
            return await ha.send("ha_events", content)

    @pytest.mark.asyncio
    async def test_send_success(self):
        ha = _make_adapter()
        session = self._mock_aiohttp_session(200)

        outcome = await self._send_with_session(ha, session, "Test notification")

        assert outcome.success is True
        # Check the URL, JSON payload and auth header passed to post()
        posted = session.post.call_args
        assert "/api/services/persistent_notification/create" in posted[0][0]
        assert posted[1]["json"]["title"] == "Hermes Agent"
        assert posted[1]["json"]["message"] == "Test notification"
        assert "Bearer tok" in posted[1]["headers"]["Authorization"]

    @pytest.mark.asyncio
    async def test_send_http_error(self):
        ha = _make_adapter()
        session = self._mock_aiohttp_session(401, "Unauthorized")

        outcome = await self._send_with_session(ha, session, "Test")

        assert outcome.success is False
        assert "401" in outcome.error

    @pytest.mark.asyncio
    async def test_send_truncates_long_message(self):
        ha = _make_adapter()
        session = self._mock_aiohttp_session(200)
        oversized = "x" * 10000

        await self._send_with_session(ha, session, oversized)

        delivered = session.post.call_args[1]["json"]["message"]
        assert len(delivered) == 4096

    @pytest.mark.asyncio
    async def test_send_does_not_use_websocket(self):
        """send() must use REST API, not the WS connection (race condition fix)."""
        ha = _make_adapter()
        ha._ws = AsyncMock()  # stands in for an active WS
        session = self._mock_aiohttp_session(200)

        await self._send_with_session(ha, session, "Test")

        # Neither WS send nor WS receive may have been touched
        ha._ws.send_json.assert_not_called()
        ha._ws.receive_json.assert_not_called()
# ---------------------------------------------------------------------------
# Toolset integration
# ---------------------------------------------------------------------------
class TestToolsetIntegration:
    """The four ha_* tools are reachable through the toolset definitions."""

    # Tool names every test in this class looks for.
    _HA_TOOLS = ("ha_list_entities", "ha_get_state", "ha_call_service", "ha_list_services")

    def test_homeassistant_toolset_resolves(self):
        """The "homeassistant" toolset resolves to exactly the four ha_* tools."""
        from toolsets import resolve_toolset

        resolved = resolve_toolset("homeassistant")
        assert set(resolved) == set(self._HA_TOOLS)

    def test_gateway_toolset_includes_ha_tools(self):
        """The "hermes-gateway" toolset contains each ha_* tool."""
        from toolsets import resolve_toolset

        gateway_tools = resolve_toolset("hermes-gateway")
        for tool_name in self._HA_TOOLS:
            assert tool_name in gateway_tools

    def test_hermes_core_tools_includes_ha(self):
        """_HERMES_CORE_TOOLS contains each ha_* tool."""
        from toolsets import _HERMES_CORE_TOOLS

        for tool_name in self._HA_TOOLS:
            assert tool_name in _HERMES_CORE_TOOLS
# ---------------------------------------------------------------------------
# WebSocket URL construction
# ---------------------------------------------------------------------------
class TestWsUrlConstruction:
    """Scheme swap from the adapter's HTTP(S) base URL to a WS(S) URL."""

    @staticmethod
    def _ws_url_for(base_url):
        """Build an adapter configured with *base_url* and swap its scheme."""
        adapter = HomeAssistantAdapter(
            PlatformConfig(enabled=True, token="t", extra={"url": base_url})
        )
        http_swapped = adapter._hass_url.replace("http://", "ws://")
        return http_swapped.replace("https://", "wss://")

    def test_http_to_ws(self):
        """http:// becomes ws://."""
        assert self._ws_url_for("http://ha:8123") == "ws://ha:8123"

    def test_https_to_wss(self):
        """https:// becomes wss://."""
        assert self._ws_url_for("https://ha.example.com") == "wss://ha.example.com"

View file

@ -0,0 +1,131 @@
"""Tests for gateway-owned Honcho lifecycle helpers."""
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from gateway.config import Platform
from gateway.platforms.base import MessageEvent
from gateway.session import SessionSource
def _make_runner():
    """Return a GatewayRunner allocated without __init__ and given stub state.

    Every bookkeeping attribute starts as its own empty dict; ``hooks`` is a
    MagicMock whose ``emit`` is an AsyncMock so it can be awaited.
    """
    from gateway.run import GatewayRunner

    runner = object.__new__(GatewayRunner)
    for attr_name in (
        "_honcho_managers",
        "_honcho_configs",
        "_running_agents",
        "_pending_messages",
        "_pending_approvals",
        "adapters",
    ):
        setattr(runner, attr_name, {})

    hooks = MagicMock()
    hooks.emit = AsyncMock()
    runner.hooks = hooks
    return runner
def _make_event(text="/reset"):
    """Build a Telegram MessageEvent from user "alice" carrying *text*."""
    telegram_source = SessionSource(
        platform=Platform.TELEGRAM,
        chat_id="chat-1",
        user_id="user-1",
        user_name="alice",
    )
    return MessageEvent(text=text, source=telegram_source)
class TestGatewayHonchoLifecycle:
    """Gateway-side creation, reuse and teardown of Honcho managers."""

    def test_gateway_reuses_honcho_manager_for_session_key(self):
        """Two lookups for one session key return the same manager and config."""
        runner = _make_runner()
        honcho_cfg = SimpleNamespace(
            enabled=True,
            api_key="honcho-key",
            ai_peer="hermes",
            peer_name="alice",
            context_tokens=123,
            peer_memory_mode=lambda peer: "hybrid",
        )
        shared_manager = MagicMock()

        with (
            patch("honcho_integration.client.HonchoClientConfig.from_global_config", return_value=honcho_cfg),
            patch("honcho_integration.client.get_honcho_client", return_value=MagicMock()),
            patch("honcho_integration.session.HonchoSessionManager", return_value=shared_manager) as manager_cls,
        ):
            lookups = [
                runner._get_or_create_gateway_honcho("session-key")
                for _ in range(2)
            ]

        for found_manager, found_cfg in lookups:
            assert found_manager is shared_manager
            assert found_cfg is honcho_cfg
        # The manager class must have been instantiated only for the first lookup.
        manager_cls.assert_called_once()

    def test_gateway_skips_honcho_manager_when_disabled(self):
        """With enabled=False no client or manager is built and manager is None."""
        runner = _make_runner()
        honcho_cfg = SimpleNamespace(
            enabled=False,
            api_key="honcho-key",
            ai_peer="hermes",
            peer_name="alice",
        )

        with (
            patch("honcho_integration.client.HonchoClientConfig.from_global_config", return_value=honcho_cfg),
            patch("honcho_integration.client.get_honcho_client") as client_factory,
            patch("honcho_integration.session.HonchoSessionManager") as manager_cls,
        ):
            found_manager, found_cfg = runner._get_or_create_gateway_honcho("session-key")

        assert found_manager is None
        assert found_cfg is honcho_cfg
        client_factory.assert_not_called()
        manager_cls.assert_not_called()

    @pytest.mark.asyncio
    async def test_reset_shuts_down_gateway_honcho_manager(self):
        """The reset command shuts Honcho down and flushes the old session."""
        runner = _make_runner()
        event = _make_event()
        runner._shutdown_gateway_honcho = MagicMock()
        runner._async_flush_memories = AsyncMock()

        store = MagicMock()
        store._generate_session_key.return_value = "gateway-key"
        store._entries = {
            "gateway-key": SimpleNamespace(session_id="old-session"),
        }
        store.reset_session.return_value = SimpleNamespace(session_id="new-session")
        runner.session_store = store

        reply = await runner._handle_reset_command(event)

        runner._shutdown_gateway_honcho.assert_called_once_with("gateway-key")
        runner._async_flush_memories.assert_called_once_with("old-session", "gateway-key")
        assert "Session reset" in reply

    def test_flush_memories_reuses_gateway_session_key_and_skips_honcho_sync(self):
        """The flush agent gets the gateway session key and runs with sync_honcho=False."""
        runner = _make_runner()
        runner.session_store = MagicMock()
        runner.session_store.load_transcript.return_value = [
            {"role": "user", "content": "a"},
            {"role": "assistant", "content": "b"},
            {"role": "user", "content": "c"},
            {"role": "assistant", "content": "d"},
        ]
        flush_agent = MagicMock()

        with (
            patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "test-key"}),
            patch("gateway.run._resolve_gateway_model", return_value="model-name"),
            patch("run_agent.AIAgent", return_value=flush_agent) as agent_cls,
        ):
            runner._flush_memories_for_session("old-session", "gateway-key")

        agent_cls.assert_called_once()
        _, ctor_kwargs = agent_cls.call_args
        assert ctor_kwargs["session_id"] == "old-session"
        assert ctor_kwargs["honcho_session_key"] == "gateway-key"

        flush_agent.run_conversation.assert_called_once()
        _, run_kwargs = flush_agent.run_conversation.call_args
        assert run_kwargs["sync_honcho"] is False

View file

@ -0,0 +1,217 @@
"""Tests for gateway/hooks.py — event hook system."""
import asyncio
from pathlib import Path
from unittest.mock import patch
import pytest
from gateway.hooks import HookRegistry
def _create_hook(hooks_dir, hook_name, events, handler_code):
"""Helper to create a hook directory with HOOK.yaml and handler.py."""
hook_dir = hooks_dir / hook_name
hook_dir.mkdir(parents=True)
(hook_dir / "HOOK.yaml").write_text(
f"name: {hook_name}\n"
f"description: Test hook\n"
f"events: {events}\n"
)
(hook_dir / "handler.py").write_text(handler_code)
return hook_dir
class TestHookRegistryInit:
    """State of a HookRegistry immediately after construction."""

    def test_empty_registry(self):
        """A new registry has no loaded hooks and no handlers."""
        registry = HookRegistry()
        assert registry.loaded_hooks == []
        assert registry._handlers == {}
class TestDiscoverAndLoad:
    """discover_and_load() against valid, malformed and missing hook directories."""

    @staticmethod
    def _discover(hooks_dir):
        """Run discovery with HOOKS_DIR patched to *hooks_dir*; return the registry."""
        registry = HookRegistry()
        with patch("gateway.hooks.HOOKS_DIR", hooks_dir):
            registry.discover_and_load()
        return registry

    def test_loads_valid_hook(self, tmp_path):
        """A directory with both files and one event is loaded under its name."""
        _create_hook(tmp_path, "my-hook", '["agent:start"]',
                     "def handle(event_type, context):\n pass\n")

        loaded = self._discover(tmp_path).loaded_hooks

        assert len(loaded) == 1
        only_hook = loaded[0]
        assert only_hook["name"] == "my-hook"
        assert "agent:start" in only_hook["events"]

    def test_skips_missing_hook_yaml(self, tmp_path):
        """A directory holding only handler.py is ignored."""
        bad_dir = tmp_path / "bad-hook"
        bad_dir.mkdir()
        (bad_dir / "handler.py").write_text("def handle(e, c): pass\n")

        assert len(self._discover(tmp_path).loaded_hooks) == 0

    def test_skips_missing_handler_py(self, tmp_path):
        """A directory holding only HOOK.yaml is ignored."""
        bad_dir = tmp_path / "bad-hook"
        bad_dir.mkdir()
        (bad_dir / "HOOK.yaml").write_text("name: bad\nevents: ['agent:start']\n")

        assert len(self._discover(tmp_path).loaded_hooks) == 0

    def test_skips_no_events(self, tmp_path):
        """A hook whose events list is empty is ignored."""
        empty_dir = tmp_path / "empty-hook"
        empty_dir.mkdir()
        (empty_dir / "HOOK.yaml").write_text("name: empty\nevents: []\n")
        (empty_dir / "handler.py").write_text("def handle(e, c): pass\n")

        assert len(self._discover(tmp_path).loaded_hooks) == 0

    def test_skips_no_handle_function(self, tmp_path):
        """A handler.py that defines no ``handle`` function is ignored."""
        odd_dir = tmp_path / "no-handle"
        odd_dir.mkdir()
        (odd_dir / "HOOK.yaml").write_text("name: no-handle\nevents: ['agent:start']\n")
        (odd_dir / "handler.py").write_text("def something_else(): pass\n")

        assert len(self._discover(tmp_path).loaded_hooks) == 0

    def test_nonexistent_hooks_dir(self, tmp_path):
        """Pointing HOOKS_DIR at a missing directory loads nothing."""
        registry = self._discover(tmp_path / "nonexistent")
        assert len(registry.loaded_hooks) == 0

    def test_multiple_hooks(self, tmp_path):
        """Two valid hook directories yield two loaded hooks."""
        _create_hook(tmp_path, "hook-a", '["agent:start"]',
                     "def handle(e, c): pass\n")
        _create_hook(tmp_path, "hook-b", '["session:start", "session:reset"]',
                     "def handle(e, c): pass\n")

        assert len(self._discover(tmp_path).loaded_hooks) == 2
class TestEmit:
    """emit() dispatch to sync, async, wildcard, failing and absent handlers."""

    @staticmethod
    def _loaded_registry(hooks_dir):
        """Return a registry after discovery with HOOKS_DIR patched to *hooks_dir*."""
        registry = HookRegistry()
        with patch("gateway.hooks.HOOKS_DIR", hooks_dir):
            registry.discover_and_load()
        return registry

    @pytest.mark.asyncio
    async def test_emit_calls_sync_handler(self, tmp_path):
        """A plain ``def handle`` receives the emitted event type."""
        seen = []
        _create_hook(tmp_path, "sync-hook", '["agent:start"]',
                     "results = []\n"
                     "def handle(event_type, context):\n"
                     " results.append(event_type)\n")
        registry = self._loaded_registry(tmp_path)

        # Swap the handler module's ``results`` global for our own list so
        # the test can observe what the handler appended.
        registry._handlers["agent:start"][0].__globals__["results"] = seen

        await registry.emit("agent:start", {"test": True})
        assert "agent:start" in seen

    @pytest.mark.asyncio
    async def test_emit_calls_async_handler(self, tmp_path):
        """An ``async def handle`` receives the emitted event type."""
        seen = []
        async_dir = tmp_path / "async-hook"
        async_dir.mkdir()
        (async_dir / "HOOK.yaml").write_text(
            "name: async-hook\nevents: ['agent:end']\n"
        )
        (async_dir / "handler.py").write_text(
            "import asyncio\n"
            "results = []\n"
            "async def handle(event_type, context):\n"
            " results.append(event_type)\n"
        )
        registry = self._loaded_registry(tmp_path)
        registry._handlers["agent:end"][0].__globals__["results"] = seen

        await registry.emit("agent:end", {})
        assert "agent:end" in seen

    @pytest.mark.asyncio
    async def test_wildcard_matching(self, tmp_path):
        """A handler registered for "command:*" fires for "command:reset"."""
        seen = []
        _create_hook(tmp_path, "wildcard-hook", '["command:*"]',
                     "results = []\n"
                     "def handle(event_type, context):\n"
                     " results.append(event_type)\n")
        registry = self._loaded_registry(tmp_path)
        registry._handlers["command:*"][0].__globals__["results"] = seen

        await registry.emit("command:reset", {})
        assert "command:reset" in seen

    @pytest.mark.asyncio
    async def test_no_handlers_for_event(self, tmp_path):
        """Emitting an event nobody registered for returns None without raising."""
        registry = HookRegistry()
        outcome = await registry.emit("unknown:event", {})
        assert outcome is None
        assert not registry._handlers.get("unknown:event")

    @pytest.mark.asyncio
    async def test_handler_error_does_not_propagate(self, tmp_path):
        """A handler that raises ValueError does not make emit() raise."""
        _create_hook(tmp_path, "bad-hook", '["agent:start"]',
                     "def handle(event_type, context):\n"
                     " raise ValueError('boom')\n")
        registry = self._loaded_registry(tmp_path)
        assert len(registry._handlers.get("agent:start", [])) == 1

        # Must complete normally even though the handler throws.
        outcome = await registry.emit("agent:start", {})
        assert outcome is None

    @pytest.mark.asyncio
    async def test_emit_default_context(self, tmp_path):
        """Omitting the context argument hands the handler an empty dict."""
        received = []
        _create_hook(tmp_path, "ctx-hook", '["agent:start"]',
                     "captured = []\n"
                     "def handle(event_type, context):\n"
                     " captured.append(context)\n")
        registry = self._loaded_registry(tmp_path)
        registry._handlers["agent:start"][0].__globals__["captured"] = received

        await registry.emit("agent:start")  # no context arg
        assert received[0] == {}

View file

@ -0,0 +1,150 @@
"""Tests verifying interrupt key consistency between adapter and gateway.
Regression test for a bug where monitor_for_interrupt() in _run_agent used
source.chat_id to query the adapter, but the adapter stores interrupts under
the full session key (build_session_key output). This mismatch meant
interrupts were never detected, causing subagents to ignore new messages.
"""
import asyncio
import pytest
from gateway.config import Platform, PlatformConfig
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType, SendResult
from gateway.session import SessionSource, build_session_key
class StubAdapter(BasePlatformAdapter):
    """Minimal adapter for interrupt tests."""

    def __init__(self):
        stub_config = PlatformConfig(enabled=True, token="test")
        super().__init__(stub_config, Platform.TELEGRAM)

    async def connect(self):
        """Report a successful connection without doing anything."""
        return True

    async def disconnect(self):
        """No-op."""
        return None

    async def send(self, chat_id, content, reply_to=None, metadata=None):
        """Report success with a fixed message id; nothing is sent."""
        return SendResult(success=True, message_id="1")

    async def send_typing(self, chat_id, metadata=None):
        """No-op."""
        return None

    async def get_chat_info(self, chat_id):
        """Echo *chat_id* back under the "id" key."""
        return {"id": chat_id}
def _source(chat_id="123456", chat_type="dm", thread_id=None):
    """Build a Telegram SessionSource from the given chat fields."""
    fields = {
        "platform": Platform.TELEGRAM,
        "chat_id": chat_id,
        "chat_type": chat_type,
        "thread_id": thread_id,
    }
    return SessionSource(**fields)
class TestInterruptKeyConsistency:
    """Ensure adapter interrupt methods are queried with session_key, not chat_id."""

    def test_session_key_differs_from_chat_id_for_dm(self):
        """Session key for a DM is namespaced and includes the DM chat_id."""
        dm_source = _source("123456", "dm")
        key = build_session_key(dm_source)
        assert key != dm_source.chat_id
        assert key == "agent:main:telegram:dm:123456"

    def test_session_key_differs_from_chat_id_for_group(self):
        """Session key for a group chat includes prefix, unlike raw chat_id."""
        group_source = _source("-1001234", "group")
        key = build_session_key(group_source)
        assert key != group_source.chat_id
        assert "agent:main:" in key
        assert group_source.chat_id in key

    @pytest.mark.asyncio
    async def test_has_pending_interrupt_requires_session_key(self):
        """has_pending_interrupt returns True only when queried with session_key."""
        adapter = StubAdapter()
        dm_source = _source("123456", "dm")
        key = build_session_key(dm_source)

        # Register an already-set interrupt under the session key.
        flag = asyncio.Event()
        flag.set()
        adapter._active_sessions[key] = flag

        # Session key lookup succeeds
        assert adapter.has_pending_interrupt(key) is True
        # Raw chat_id lookup fails (this was the bug)
        assert adapter.has_pending_interrupt(dm_source.chat_id) is False

    @pytest.mark.asyncio
    async def test_get_pending_message_requires_session_key(self):
        """get_pending_message returns the event only with session_key."""
        adapter = StubAdapter()
        dm_source = _source("123456", "dm")
        key = build_session_key(dm_source)

        pending = MessageEvent(text="hello", source=dm_source, message_id="42")
        adapter._pending_messages[key] = pending

        # Raw chat_id lookup yields None (the bug)
        assert adapter.get_pending_message(dm_source.chat_id) is None
        # Session key lookup yields the stored event
        assert adapter.get_pending_message(key) is pending

    @pytest.mark.asyncio
    async def test_handle_message_stores_under_session_key(self):
        """handle_message stores pending messages under session_key, not chat_id."""
        adapter = StubAdapter()
        adapter.set_message_handler(lambda event: asyncio.sleep(0, result=None))
        group_source = _source("-1001234", "group")
        key = build_session_key(group_source)

        # An entry in _active_sessions marks the session as busy.
        adapter._active_sessions[key] = asyncio.Event()

        # A second message arrives while that session is active.
        followup = MessageEvent(text="interrupt!", source=group_source, message_id="2")
        await adapter.handle_message(followup)

        queue = adapter._pending_messages
        assert key in queue
        assert group_source.chat_id not in queue
        # The interrupt flag for the session was raised.
        assert adapter._active_sessions[key].is_set()

    @pytest.mark.asyncio
    async def test_photo_followup_is_queued_without_interrupt(self):
        """Photo follow-ups should queue behind the active run instead of interrupting it."""
        adapter = StubAdapter()
        adapter.set_message_handler(lambda event: asyncio.sleep(0, result=None))
        group_source = _source("-1001234", "group")
        key = build_session_key(group_source)
        flag = asyncio.Event()
        adapter._active_sessions[key] = flag

        photo = MessageEvent(
            text="caption",
            source=group_source,
            message_type=MessageType.PHOTO,
            message_id="2",
            media_urls=["/tmp/photo-a.jpg"],
            media_types=["image/jpeg"],
        )
        await adapter.handle_message(photo)

        queued = adapter._pending_messages[key]
        assert queued is photo
        assert queued.media_urls == ["/tmp/photo-a.jpg"]
        assert flag.is_set() is False

View file

@ -0,0 +1,448 @@
"""Tests for Matrix platform adapter."""
import json
import re
import pytest
from unittest.mock import MagicMock, patch, AsyncMock
from gateway.config import Platform, PlatformConfig
# ---------------------------------------------------------------------------
# Platform & Config
# ---------------------------------------------------------------------------
class TestMatrixPlatformEnum:
    """Presence of the Matrix member on the Platform enum."""

    def test_matrix_enum_exists(self):
        """Platform.MATRIX carries the value "matrix"."""
        assert Platform.MATRIX.value == "matrix"

    def test_matrix_in_platform_list(self):
        """Iterating Platform yields a member whose value is "matrix"."""
        all_values = [member.value for member in Platform]
        assert "matrix" in all_values
class TestMatrixConfigLoading:
    """How MATRIX_* environment variables are applied to a GatewayConfig."""

    @staticmethod
    def _load_config():
        """Return a fresh GatewayConfig after _apply_env_overrides has run on it."""
        from gateway.config import GatewayConfig, _apply_env_overrides

        config = GatewayConfig()
        _apply_env_overrides(config)
        return config

    @staticmethod
    def _set_token_and_homeserver(monkeypatch):
        """Set the access-token and homeserver variables shared by several tests."""
        monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_abc123")
        monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org")

    def test_apply_env_overrides_with_access_token(self, monkeypatch):
        """Token plus homeserver enables Matrix and stores both values."""
        self._set_token_and_homeserver(monkeypatch)

        config = self._load_config()

        assert Platform.MATRIX in config.platforms
        matrix_cfg = config.platforms[Platform.MATRIX]
        assert matrix_cfg.enabled is True
        assert matrix_cfg.token == "syt_abc123"
        assert matrix_cfg.extra.get("homeserver") == "https://matrix.example.org"

    def test_apply_env_overrides_with_password(self, monkeypatch):
        """Password login (no token) enables Matrix and keeps password and user id in extra."""
        monkeypatch.delenv("MATRIX_ACCESS_TOKEN", raising=False)
        monkeypatch.setenv("MATRIX_PASSWORD", "secret123")
        monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org")
        monkeypatch.setenv("MATRIX_USER_ID", "@bot:example.org")

        config = self._load_config()

        assert Platform.MATRIX in config.platforms
        matrix_cfg = config.platforms[Platform.MATRIX]
        assert matrix_cfg.enabled is True
        assert matrix_cfg.extra.get("password") == "secret123"
        assert matrix_cfg.extra.get("user_id") == "@bot:example.org"

    def test_matrix_not_loaded_without_creds(self, monkeypatch):
        """With token, password and homeserver all unset, Matrix is absent."""
        for var_name in ("MATRIX_ACCESS_TOKEN", "MATRIX_PASSWORD", "MATRIX_HOMESERVER"):
            monkeypatch.delenv(var_name, raising=False)

        config = self._load_config()

        assert Platform.MATRIX not in config.platforms

    def test_matrix_encryption_flag(self, monkeypatch):
        """MATRIX_ENCRYPTION="true" is stored as the boolean True."""
        self._set_token_and_homeserver(monkeypatch)
        monkeypatch.setenv("MATRIX_ENCRYPTION", "true")

        matrix_cfg = self._load_config().platforms[Platform.MATRIX]

        assert matrix_cfg.extra.get("encryption") is True

    def test_matrix_encryption_default_off(self, monkeypatch):
        """With MATRIX_ENCRYPTION unset, encryption is stored as False."""
        self._set_token_and_homeserver(monkeypatch)
        monkeypatch.delenv("MATRIX_ENCRYPTION", raising=False)

        matrix_cfg = self._load_config().platforms[Platform.MATRIX]

        assert matrix_cfg.extra.get("encryption") is False

    def test_matrix_home_room(self, monkeypatch):
        """MATRIX_HOME_ROOM and its name become the Matrix home channel."""
        self._set_token_and_homeserver(monkeypatch)
        monkeypatch.setenv("MATRIX_HOME_ROOM", "!room123:example.org")
        monkeypatch.setenv("MATRIX_HOME_ROOM_NAME", "Bot Room")

        home = self._load_config().get_home_channel(Platform.MATRIX)

        assert home is not None
        assert home.chat_id == "!room123:example.org"
        assert home.name == "Bot Room"

    def test_matrix_user_id_stored_in_extra(self, monkeypatch):
        """MATRIX_USER_ID is kept under extra["user_id"]."""
        self._set_token_and_homeserver(monkeypatch)
        monkeypatch.setenv("MATRIX_USER_ID", "@hermes:example.org")

        matrix_cfg = self._load_config().platforms[Platform.MATRIX]

        assert matrix_cfg.extra.get("user_id") == "@hermes:example.org"
# ---------------------------------------------------------------------------
# Adapter helpers
# ---------------------------------------------------------------------------
def _make_adapter():
    """Create a MatrixAdapter with mocked config."""
    from gateway.platforms.matrix import MatrixAdapter

    extra_settings = {
        "homeserver": "https://matrix.example.org",
        "user_id": "@bot:example.org",
    }
    return MatrixAdapter(
        PlatformConfig(enabled=True, token="syt_test_token", extra=extra_settings)
    )
# ---------------------------------------------------------------------------
# mxc:// URL conversion
# ---------------------------------------------------------------------------
class TestMatrixMxcToHttp:
    """Conversion of mxc:// media URIs by the adapter's _mxc_to_http."""

    def setup_method(self):
        self.adapter = _make_adapter()

    def test_basic_mxc_conversion(self):
        """mxc://server/media_id should become an authenticated HTTP URL."""
        expected = "https://matrix.example.org/_matrix/client/v1/media/download/matrix.org/abc123"
        assert self.adapter._mxc_to_http("mxc://matrix.org/abc123") == expected

    def test_mxc_with_different_server(self):
        """mxc:// from a different server should still use our homeserver."""
        converted = self.adapter._mxc_to_http("mxc://other.server/media456")
        assert converted.startswith("https://matrix.example.org/")
        assert "other.server/media456" in converted

    def test_non_mxc_url_passthrough(self):
        """Non-mxc URLs should be returned unchanged."""
        plain_url = "https://example.com/image.png"
        assert self.adapter._mxc_to_http(plain_url) == plain_url

    def test_mxc_uses_client_v1_endpoint(self):
        """Should use /_matrix/client/v1/media/download/ not the deprecated path."""
        converted = self.adapter._mxc_to_http("mxc://example.com/test123")
        assert "/_matrix/client/v1/media/download/" in converted
        assert "/_matrix/media/v3/download/" not in converted
# ---------------------------------------------------------------------------
# DM detection
# ---------------------------------------------------------------------------
class TestMatrixDmDetection:
    """The adapter's _dm_rooms cache and how _refresh_dm_cache fills it."""

    def setup_method(self):
        self.adapter = _make_adapter()

    def test_room_in_m_direct_is_dm(self):
        """A room listed in m.direct should be detected as DM."""
        adapter = self.adapter
        adapter._joined_rooms = {"!dm_room:ex.org", "!group_room:ex.org"}
        adapter._dm_rooms = {
            "!dm_room:ex.org": True,
            "!group_room:ex.org": False,
        }

        assert adapter._dm_rooms.get("!dm_room:ex.org") is True
        assert adapter._dm_rooms.get("!group_room:ex.org") is False

    def test_unknown_room_not_in_cache(self):
        """Unknown rooms should not be in the DM cache."""
        self.adapter._dm_rooms = {}
        assert self.adapter._dm_rooms.get("!unknown:ex.org") is None

    @pytest.mark.asyncio
    async def test_refresh_dm_cache_with_m_direct(self):
        """_refresh_dm_cache should populate _dm_rooms from m.direct data."""
        adapter = self.adapter
        adapter._joined_rooms = {"!room_a:ex.org", "!room_b:ex.org", "!room_c:ex.org"}

        # Account-data response listing rooms a and b; room c is left out.
        account_data = MagicMock()
        account_data.content = {
            "@alice:ex.org": ["!room_a:ex.org"],
            "@bob:ex.org": ["!room_b:ex.org"],
        }
        client = MagicMock()
        client.get_account_data = AsyncMock(return_value=account_data)
        adapter._client = client

        await adapter._refresh_dm_cache()

        assert adapter._dm_rooms["!room_a:ex.org"] is True
        assert adapter._dm_rooms["!room_b:ex.org"] is True
        assert adapter._dm_rooms["!room_c:ex.org"] is False
# ---------------------------------------------------------------------------
# Reply fallback stripping
# ---------------------------------------------------------------------------
class TestMatrixReplyFallbackStripping:
    """Test that Matrix reply fallback lines ('> ' prefix) are stripped."""

    def setup_method(self):
        adapter = _make_adapter()
        adapter._user_id = "@bot:example.org"
        adapter._startup_ts = 0.0
        adapter._dm_rooms = {}
        adapter._message_handler = AsyncMock()
        self.adapter = adapter

    def _strip_fallback(self, body: str, has_reply: bool = True) -> str:
        """Simulate the reply fallback stripping logic from _on_room_message.

        When *has_reply* is true and *body* opens with "> ", the leading run
        of quoted lines and one following blank line are dropped. If nothing
        would remain, *body* is returned unchanged.
        """
        reply_to = "some_event_id" if has_reply else None
        if not reply_to or not body.startswith("> "):
            return body

        lines = body.split("\n")
        total = len(lines)

        # Advance past every leading quoted line ("> ..." or a bare ">").
        start = 0
        while start < total and (lines[start].startswith("> ") or lines[start] == ">"):
            start += 1
        # A single blank line separates the fallback from the real content.
        if start < total and lines[start] == "":
            start += 1

        remainder = lines[start:]
        if not remainder:
            return body
        return "\n".join(remainder)

    def test_simple_reply_fallback(self):
        """One quoted line plus separator leaves only the reply text."""
        stripped = self._strip_fallback("> <@alice:ex.org> Original message\n\nActual reply")
        assert stripped == "Actual reply"

    def test_multiline_reply_fallback(self):
        """Several quoted lines are all removed."""
        stripped = self._strip_fallback("> <@alice:ex.org> Line 1\n> Line 2\n\nMy response")
        assert stripped == "My response"

    def test_no_reply_fallback_preserved(self):
        """A message that is not a reply passes through untouched."""
        message = "Just a normal message"
        assert self._strip_fallback(message, has_reply=False) == "Just a normal message"

    def test_quote_without_reply_preserved(self):
        """'> ' lines without a reply_to context should be preserved."""
        quoted = "> This is a blockquote"
        assert self._strip_fallback(quoted, has_reply=False) == "> This is a blockquote"

    def test_empty_fallback_separator(self):
        """The blank line between fallback and actual content should be stripped."""
        stripped = self._strip_fallback("> <@alice:ex.org> hi\n>\n\nResponse")
        assert stripped == "Response"

    def test_multiline_response_after_fallback(self):
        """Line breaks inside the actual reply are kept."""
        stripped = self._strip_fallback("> <@alice:ex.org> Original\n\nLine 1\nLine 2\nLine 3")
        assert stripped == "Line 1\nLine 2\nLine 3"
# ---------------------------------------------------------------------------
# Thread detection
# ---------------------------------------------------------------------------
class TestMatrixThreadDetection:
    """Thread-root extraction from an event's ``m.relates_to`` content."""

    @staticmethod
    def _extract_thread_id(relates_to):
        """Simulate the extraction logic from _on_room_message.

        Args:
            relates_to: The ``m.relates_to`` mapping of an event (may be empty).

        Returns:
            The value under "event_id" when "rel_type" is "m.thread";
            otherwise None.
        """
        if relates_to.get("rel_type") == "m.thread":
            return relates_to.get("event_id")
        return None

    def test_thread_id_from_m_relates_to(self):
        """m.relates_to with rel_type=m.thread should extract the event_id."""
        relates_to = {
            "rel_type": "m.thread",
            "event_id": "$thread_root_event",
            "is_falling_back": True,
            "m.in_reply_to": {"event_id": "$some_event"},
        }
        assert self._extract_thread_id(relates_to) == "$thread_root_event"

    def test_no_thread_for_reply(self):
        """m.in_reply_to without m.thread should not set thread_id."""
        relates_to = {
            "m.in_reply_to": {"event_id": "$reply_event"},
        }
        assert self._extract_thread_id(relates_to) is None

    def test_no_thread_for_edit(self):
        """m.replace relation should not set thread_id."""
        relates_to = {
            "rel_type": "m.replace",
            "event_id": "$edited_event",
        }
        assert self._extract_thread_id(relates_to) is None

    def test_empty_relates_to(self):
        """Empty m.relates_to should not set thread_id."""
        assert self._extract_thread_id({}) is None
# ---------------------------------------------------------------------------
# Format message
# ---------------------------------------------------------------------------
class TestMatrixFormatMessage:
    """What the adapter's format_message does to outgoing markdown."""

    def setup_method(self):
        self.adapter = _make_adapter()

    def test_image_markdown_stripped(self):
        """![alt](url) should be converted to just the URL."""
        formatted = self.adapter.format_message("![cat](https://img.example.com/cat.png)")
        assert formatted == "https://img.example.com/cat.png"

    def test_regular_markdown_preserved(self):
        """Standard markdown should be preserved (Matrix supports it)."""
        markdown = "**bold** and *italic* and `code`"
        assert self.adapter.format_message(markdown) == markdown

    def test_plain_text_unchanged(self):
        """Text without markup comes back unchanged."""
        plain = "Hello, world!"
        assert self.adapter.format_message(plain) == plain

    def test_multiple_images_stripped(self):
        """Every image in the message is reduced to its URL."""
        formatted = self.adapter.format_message(
            "![a](http://a.com/1.png) and ![b](http://b.com/2.png)"
        )
        assert "![" not in formatted
        for image_url in ("http://a.com/1.png", "http://b.com/2.png"):
            assert image_url in formatted
# ---------------------------------------------------------------------------
# Markdown to HTML conversion
# ---------------------------------------------------------------------------
class TestMatrixMarkdownToHtml:
    """HTML tags emitted by the adapter's _markdown_to_html."""

    def setup_method(self):
        self.adapter = _make_adapter()

    def test_bold_conversion(self):
        """**bold** should produce <strong> tags."""
        html = self.adapter._markdown_to_html("**bold**")
        assert any(tag in html for tag in ("<strong>", "<b>"))
        assert "bold" in html

    def test_italic_conversion(self):
        """*italic* should produce <em> tags."""
        html = self.adapter._markdown_to_html("*italic*")
        assert any(tag in html for tag in ("<em>", "<i>"))

    def test_inline_code(self):
        """`code` should produce <code> tags."""
        html = self.adapter._markdown_to_html("`code`")
        assert "<code>" in html

    def test_plain_text_returns_html(self):
        """Plain text should still be returned (possibly with <br> or <p>)."""
        html = self.adapter._markdown_to_html("Hello world")
        assert "Hello world" in html
# ---------------------------------------------------------------------------
# Helper: display name extraction
# ---------------------------------------------------------------------------
class TestMatrixDisplayName:
    """Name resolution performed by the adapter's _get_display_name."""

    def setup_method(self):
        self.adapter = _make_adapter()

    def test_get_display_name_from_room_users(self):
        """Should get display name from room's users dict."""
        member = MagicMock()
        member.display_name = "Alice"
        room = MagicMock()
        room.users = {"@alice:ex.org": member}

        assert self.adapter._get_display_name(room, "@alice:ex.org") == "Alice"

    def test_get_display_name_fallback_to_localpart(self):
        """Should extract localpart from @user:server format."""
        room = MagicMock()
        room.users = {}

        assert self.adapter._get_display_name(room, "@bob:example.org") == "bob"

    def test_get_display_name_no_room(self):
        """Should handle None room gracefully."""
        assert self.adapter._get_display_name(None, "@charlie:ex.org") == "charlie"
# ---------------------------------------------------------------------------
# Requirements check
# ---------------------------------------------------------------------------
class TestMatrixRequirements:
    """Results of check_matrix_requirements under different environments."""

    def test_check_requirements_with_token(self, monkeypatch):
        """With token and homeserver set, the result tracks whether ``nio`` imports."""
        monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_test")
        monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org")
        from gateway.platforms.matrix import check_matrix_requirements

        try:
            import nio  # noqa: F401
            assert check_matrix_requirements() is True
        except ImportError:
            # nio is not installed, so the check is expected to fail.
            assert check_matrix_requirements() is False

    def test_check_requirements_without_creds(self, monkeypatch):
        """With token, password and homeserver all unset, the check is False."""
        for var_name in ("MATRIX_ACCESS_TOKEN", "MATRIX_PASSWORD", "MATRIX_HOMESERVER"):
            monkeypatch.delenv(var_name, raising=False)
        from gateway.platforms.matrix import check_matrix_requirements

        assert check_matrix_requirements() is False

    def test_check_requirements_without_homeserver(self, monkeypatch):
        """A token without a homeserver makes the check False."""
        monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_test")
        monkeypatch.delenv("MATRIX_HOMESERVER", raising=False)
        from gateway.platforms.matrix import check_matrix_requirements

        assert check_matrix_requirements() is False

View file

@ -0,0 +1,673 @@
"""Tests for Mattermost platform adapter."""
import json
import time
import pytest
from unittest.mock import MagicMock, patch, AsyncMock
from gateway.config import Platform, PlatformConfig
# ---------------------------------------------------------------------------
# Platform & Config
# ---------------------------------------------------------------------------
class TestMattermostPlatformEnum:
    """Presence of the Mattermost member on the Platform enum."""

    def test_mattermost_enum_exists(self):
        """Platform.MATTERMOST carries the value "mattermost"."""
        assert Platform.MATTERMOST.value == "mattermost"

    def test_mattermost_in_platform_list(self):
        """Iterating Platform yields a member whose value is "mattermost"."""
        all_values = [member.value for member in Platform]
        assert "mattermost" in all_values
class TestMattermostConfigLoading:
    """How MATTERMOST_* environment variables are applied to a GatewayConfig."""

    @staticmethod
    def _load_config():
        """Return a fresh GatewayConfig after _apply_env_overrides has run on it."""
        from gateway.config import GatewayConfig, _apply_env_overrides

        config = GatewayConfig()
        _apply_env_overrides(config)
        return config

    @staticmethod
    def _set_token_and_url(monkeypatch):
        """Set the token and URL variables shared by several tests."""
        monkeypatch.setenv("MATTERMOST_TOKEN", "mm-tok-abc123")
        monkeypatch.setenv("MATTERMOST_URL", "https://mm.example.com")

    def test_apply_env_overrides_mattermost(self, monkeypatch):
        """Token plus URL enables Mattermost and stores both values."""
        self._set_token_and_url(monkeypatch)

        config = self._load_config()

        assert Platform.MATTERMOST in config.platforms
        mm_cfg = config.platforms[Platform.MATTERMOST]
        assert mm_cfg.enabled is True
        assert mm_cfg.token == "mm-tok-abc123"
        assert mm_cfg.extra.get("url") == "https://mm.example.com"

    def test_mattermost_not_loaded_without_token(self, monkeypatch):
        """With token and URL unset, Mattermost is absent from the config."""
        for var_name in ("MATTERMOST_TOKEN", "MATTERMOST_URL"):
            monkeypatch.delenv(var_name, raising=False)

        config = self._load_config()

        assert Platform.MATTERMOST not in config.platforms

    def test_connected_platforms_includes_mattermost(self, monkeypatch):
        """A configured Mattermost shows up in get_connected_platforms()."""
        self._set_token_and_url(monkeypatch)

        connected = self._load_config().get_connected_platforms()

        assert Platform.MATTERMOST in connected

    def test_mattermost_home_channel(self, monkeypatch):
        """MATTERMOST_HOME_CHANNEL and its name become the home channel."""
        self._set_token_and_url(monkeypatch)
        monkeypatch.setenv("MATTERMOST_HOME_CHANNEL", "ch_abc123")
        monkeypatch.setenv("MATTERMOST_HOME_CHANNEL_NAME", "General")

        home = self._load_config().get_home_channel(Platform.MATTERMOST)

        assert home is not None
        assert home.chat_id == "ch_abc123"
        assert home.name == "General"

    def test_mattermost_url_warning_without_url(self, monkeypatch):
        """MATTERMOST_TOKEN set but MATTERMOST_URL missing should still load."""
        monkeypatch.setenv("MATTERMOST_TOKEN", "mm-tok-abc123")
        monkeypatch.delenv("MATTERMOST_URL", raising=False)

        config = self._load_config()

        assert Platform.MATTERMOST in config.platforms
        # The missing URL is recorded as an empty string.
        assert config.platforms[Platform.MATTERMOST].extra.get("url") == ""
# ---------------------------------------------------------------------------
# Adapter format / truncate
# ---------------------------------------------------------------------------
def _make_adapter():
    """Return a MattermostAdapter built from an enabled test PlatformConfig."""
    from gateway.platforms.mattermost import MattermostAdapter

    return MattermostAdapter(
        PlatformConfig(
            enabled=True,
            token="test-token",
            extra={"url": "https://mm.example.com"},
        )
    )
class TestMattermostFormatMessage:
    """format_message rewrites image markdown and leaves other text alone."""

    def setup_method(self):
        self.adapter = _make_adapter()

    def _assert_unchanged(self, text):
        """Assert that format_message returns *text* exactly as given."""
        assert self.adapter.format_message(text) == text

    def test_image_markdown_to_url(self):
        """![alt](url) should be converted to just the URL."""
        formatted = self.adapter.format_message("![cat](https://img.example.com/cat.png)")
        assert formatted == "https://img.example.com/cat.png"

    def test_image_markdown_strips_alt_text(self):
        """Image syntax embedded in a sentence loses its markup but keeps the URL."""
        formatted = self.adapter.format_message("Here: ![my image](https://x.com/a.jpg) done")
        assert "![" not in formatted
        assert "https://x.com/a.jpg" in formatted

    def test_regular_markdown_preserved(self):
        """Regular markdown (bold, italic, code) should be kept as-is."""
        self._assert_unchanged("**bold** and *italic* and `code`")

    def test_regular_links_preserved(self):
        """Non-image links should be preserved."""
        self._assert_unchanged("[click](https://example.com)")

    def test_plain_text_unchanged(self):
        """Text without any markdown passes through untouched."""
        self._assert_unchanged("Hello, world!")

    def test_multiple_images(self):
        """Every image in the message is reduced to its URL."""
        formatted = self.adapter.format_message(
            "![a](http://a.com/1.png) text ![b](http://b.com/2.png)"
        )
        assert "![" not in formatted
        assert "http://a.com/1.png" in formatted
        assert "http://b.com/2.png" in formatted
class TestMattermostTruncateMessage:
    """truncate_message splits text into chunks no longer than the limit."""

    def setup_method(self):
        self.adapter = _make_adapter()

    def test_short_message_single_chunk(self):
        """A message below the limit comes back as one unmodified chunk."""
        text = "Hello, world!"
        pieces = self.adapter.truncate_message(text, 4000)
        assert len(pieces) == 1
        assert pieces[0] == text

    def test_long_message_splits(self):
        """A 5000-character message is split and every chunk respects the limit."""
        text = "a " * 2500  # 5000 chars
        pieces = self.adapter.truncate_message(text, 4000)
        assert len(pieces) >= 2
        for piece in pieces:
            assert len(piece) <= 4000

    def test_custom_max_length(self):
        """The max_length keyword is honoured for every chunk."""
        pieces = self.adapter.truncate_message("Hello " * 20, max_length=50)
        for piece in pieces:
            assert len(piece) <= 50

    def test_exactly_at_limit(self):
        """A message whose length equals the limit is not split."""
        pieces = self.adapter.truncate_message("x" * 4000, 4000)
        assert len(pieces) == 1
# ---------------------------------------------------------------------------
# Send
# ---------------------------------------------------------------------------
class TestMattermostSend:
    """send() behaviour against a mocked HTTP session."""

    def setup_method(self):
        self.adapter = _make_adapter()
        self.adapter._session = MagicMock()

    def _stub_post(self, status, body, text=""):
        """Make the session's post() return an async-context-manager response mock."""
        response = AsyncMock()
        response.status = status
        response.json = AsyncMock(return_value=body)
        response.text = AsyncMock(return_value=text)
        # `async with session.post(...)` must hand back the same response object.
        response.__aenter__ = AsyncMock(return_value=response)
        response.__aexit__ = AsyncMock(return_value=False)
        self.adapter._session.post = MagicMock(return_value=response)
        return response

    @pytest.mark.asyncio
    async def test_send_calls_api_post(self):
        """send() should POST to /api/v4/posts with channel_id and message."""
        self._stub_post(200, {"id": "post123"})

        outcome = await self.adapter.send("channel_1", "Hello!")

        assert outcome.success is True
        assert outcome.message_id == "post123"
        positional, keyword = self.adapter._session.post.call_args
        # First positional argument is the request URL.
        assert "/api/v4/posts" in positional[0]
        body = keyword["json"]
        assert body["channel_id"] == "channel_1"
        assert body["message"] == "Hello!"

    @pytest.mark.asyncio
    async def test_send_empty_content_succeeds(self):
        """Empty content should return success without calling the API."""
        outcome = await self.adapter.send("channel_1", "")
        assert outcome.success is True

    @pytest.mark.asyncio
    async def test_send_with_thread_reply(self):
        """When reply_mode is 'thread', reply_to should become root_id."""
        self._stub_post(200, {"id": "post456"})
        self.adapter._reply_mode = "thread"

        outcome = await self.adapter.send("channel_1", "Reply!", reply_to="root_post")

        assert outcome.success is True
        _, keyword = self.adapter._session.post.call_args
        assert keyword["json"]["root_id"] == "root_post"

    @pytest.mark.asyncio
    async def test_send_without_thread_no_root_id(self):
        """When reply_mode is 'off', reply_to should NOT set root_id."""
        self._stub_post(200, {"id": "post789"})
        self.adapter._reply_mode = "off"

        outcome = await self.adapter.send("channel_1", "Reply!", reply_to="root_post")

        assert outcome.success is True
        _, keyword = self.adapter._session.post.call_args
        assert "root_id" not in keyword["json"]

    @pytest.mark.asyncio
    async def test_send_api_failure(self):
        """When API returns error, send should return failure."""
        self._stub_post(500, {}, text="Internal Server Error")

        outcome = await self.adapter.send("channel_1", "Hello!")

        assert outcome.success is False
# ---------------------------------------------------------------------------
# WebSocket event parsing
# ---------------------------------------------------------------------------
class TestMattermostWebSocketParsing:
    """_handle_ws_event turns 'posted' WebSocket events into message events."""

    def setup_method(self):
        self.adapter = _make_adapter()
        self.adapter._bot_user_id = "bot_user_id"
        # Mock handle_message to capture the MessageEvent without processing
        self.adapter.handle_message = AsyncMock()

    @staticmethod
    def _posted(post, **data_fields):
        """Build a 'posted' event whose data.post is the JSON-encoded *post* dict."""
        # The post travels as a JSON string inside the event (double encoding).
        return {"event": "posted", "data": {"post": json.dumps(post), **data_fields}}

    def _captured_event(self):
        """Return the MessageEvent passed to the mocked handle_message."""
        positional, _ = self.adapter.handle_message.call_args
        return positional[0]

    @pytest.mark.asyncio
    async def test_parse_posted_event(self):
        """'posted' events should extract message from double-encoded post JSON."""
        post = {
            "id": "post_abc",
            "user_id": "user_123",
            "channel_id": "chan_456",
            "message": "@bot_user_id Hello from Matrix!",
        }
        await self.adapter._handle_ws_event(
            self._posted(post, channel_type="O", sender_name="@alice")
        )
        assert self.adapter.handle_message.called
        captured = self._captured_event()
        assert captured.text == "@bot_user_id Hello from Matrix!"
        assert captured.message_id == "post_abc"

    @pytest.mark.asyncio
    async def test_ignore_own_messages(self):
        """Messages from the bot's own user_id should be ignored."""
        post = {
            "id": "post_self",
            "user_id": "bot_user_id",  # same as bot
            "channel_id": "chan_456",
            "message": "Bot echo",
        }
        await self.adapter._handle_ws_event(self._posted(post, channel_type="O"))
        assert not self.adapter.handle_message.called

    @pytest.mark.asyncio
    async def test_ignore_non_posted_events(self):
        """Non-'posted' events should be ignored."""
        typing_event = {"event": "typing", "data": {"user_id": "user_123"}}
        await self.adapter._handle_ws_event(typing_event)
        assert not self.adapter.handle_message.called

    @pytest.mark.asyncio
    async def test_ignore_system_posts(self):
        """Posts with a 'type' field (system messages) should be ignored."""
        post = {
            "id": "sys_post",
            "user_id": "user_123",
            "channel_id": "chan_456",
            "message": "user joined",
            "type": "system_join_channel",
        }
        await self.adapter._handle_ws_event(self._posted(post, channel_type="O"))
        assert not self.adapter.handle_message.called

    @pytest.mark.asyncio
    async def test_channel_type_mapping(self):
        """channel_type 'D' should map to 'dm'."""
        post = {
            "id": "post_dm",
            "user_id": "user_123",
            "channel_id": "chan_dm",
            "message": "DM message",
        }
        await self.adapter._handle_ws_event(
            self._posted(post, channel_type="D", sender_name="@bob")
        )
        assert self.adapter.handle_message.called
        assert self._captured_event().source.chat_type == "dm"

    @pytest.mark.asyncio
    async def test_thread_id_from_root_id(self):
        """Post with root_id should have thread_id set."""
        post = {
            "id": "post_reply",
            "user_id": "user_123",
            "channel_id": "chan_456",
            "message": "@bot_user_id Thread reply",
            "root_id": "root_post_123",
        }
        await self.adapter._handle_ws_event(
            self._posted(post, channel_type="O", sender_name="@alice")
        )
        assert self.adapter.handle_message.called
        assert self._captured_event().source.thread_id == "root_post_123"

    @pytest.mark.asyncio
    async def test_invalid_post_json_ignored(self):
        """Invalid JSON in data.post should be silently ignored."""
        # Built by hand: the payload is deliberately not valid JSON.
        malformed = {
            "event": "posted",
            "data": {
                "post": "not-valid-json{{{",
                "channel_type": "O",
            },
        }
        await self.adapter._handle_ws_event(malformed)
        assert not self.adapter.handle_message.called
# ---------------------------------------------------------------------------
# File upload (send_image)
# ---------------------------------------------------------------------------
class TestMattermostFileUpload:
    """send_image against a mocked session: download, upload, then post."""

    def setup_method(self):
        self.adapter = _make_adapter()
        self.adapter._session = MagicMock()

    @staticmethod
    def _context_response(**attributes):
        """Return an AsyncMock usable in `async with`, carrying the given attributes."""
        response = AsyncMock()
        for attr_name, attr_value in attributes.items():
            setattr(response, attr_name, attr_value)
        response.__aenter__ = AsyncMock(return_value=response)
        response.__aexit__ = AsyncMock(return_value=False)
        return response

    @pytest.mark.asyncio
    async def test_send_image_downloads_and_uploads(self):
        """send_image should download the URL, upload via /api/v4/files, then post."""
        # GET: the image download.
        download = self._context_response(
            status=200,
            read=AsyncMock(return_value=b"\x89PNG\x00fake-image-data"),
            content_type="image/png",
        )
        # First POST: the file upload.
        upload = self._context_response(
            status=200,
            json=AsyncMock(return_value={
                "file_infos": [{"id": "file_abc123"}]
            }),
            text=AsyncMock(return_value=""),
        )
        # Second POST: creating the post that references the file.
        create_post = self._context_response(
            status=200,
            json=AsyncMock(return_value={"id": "post_with_file"}),
            text=AsyncMock(return_value=""),
        )

        self.adapter._session.get = MagicMock(return_value=download)
        queued = [upload, create_post]

        def next_post_response(*args, **kwargs):
            # Hand responses out in order; the last one is repeated for extra calls.
            return queued.pop(0) if len(queued) > 1 else queued[0]

        self.adapter._session.post = MagicMock(side_effect=next_post_response)

        outcome = await self.adapter.send_image(
            "channel_1", "https://img.example.com/cat.png", caption="A cat"
        )

        assert outcome.success is True
        assert outcome.message_id == "post_with_file"
# ---------------------------------------------------------------------------
# Dedup cache
# ---------------------------------------------------------------------------
class TestMattermostDedup:
    """The seen-posts cache suppresses repeated deliveries of the same post."""

    def setup_method(self):
        self.adapter = _make_adapter()
        self.adapter._bot_user_id = "bot_user_id"
        # Mock handle_message to capture calls without processing
        self.adapter.handle_message = AsyncMock()

    @staticmethod
    def _posted(post_id, message):
        """Build a 'posted' event from @alice in an open channel."""
        post = {
            "id": post_id,
            "user_id": "user_123",
            "channel_id": "chan_456",
            "message": message,
        }
        return {
            "event": "posted",
            "data": {
                "post": json.dumps(post),
                "channel_type": "O",
                "sender_name": "@alice",
            },
        }

    @pytest.mark.asyncio
    async def test_duplicate_post_ignored(self):
        """The same post_id within the TTL window should be ignored."""
        repeated = self._posted("post_dup", "@bot_user_id Hello!")

        # First delivery is handled.
        await self.adapter._handle_ws_event(repeated)
        assert self.adapter.handle_message.call_count == 1

        # Second delivery of the same post_id leaves the call count untouched.
        await self.adapter._handle_ws_event(repeated)
        assert self.adapter.handle_message.call_count == 1

    @pytest.mark.asyncio
    async def test_different_post_ids_both_processed(self):
        """Different post IDs should both be processed."""
        for index, post_id in enumerate(["post_a", "post_b"]):
            await self.adapter._handle_ws_event(
                self._posted(post_id, f"@bot_user_id Message {index}")
            )
        assert self.adapter.handle_message.call_count == 2

    def test_prune_seen_clears_expired(self):
        """_prune_seen should remove entries older than _SEEN_TTL."""
        current = time.time()
        stale = current - 600  # 10 min ago
        # Exceed _SEEN_MAX with stale entries, then add a single fresh one.
        for index in range(self.adapter._SEEN_MAX + 10):
            self.adapter._seen_posts[f"old_{index}"] = stale
        self.adapter._seen_posts["fresh"] = current

        self.adapter._prune_seen()

        assert "fresh" in self.adapter._seen_posts
        assert len(self.adapter._seen_posts) < self.adapter._SEEN_MAX

    def test_seen_cache_tracks_post_ids(self):
        """Posts are tracked in _seen_posts dict."""
        seen = self.adapter._seen_posts
        seen["test_post"] = time.time()
        assert "test_post" in self.adapter._seen_posts
# ---------------------------------------------------------------------------
# Requirements check
# ---------------------------------------------------------------------------
class TestMattermostRequirements:
    """check_mattermost_requirements needs both a token and a server URL."""

    @staticmethod
    def _check():
        """Import and run the requirements check against the current environment."""
        from gateway.platforms.mattermost import check_mattermost_requirements

        return check_mattermost_requirements()

    def test_check_requirements_with_token_and_url(self, monkeypatch):
        """Token and URL together satisfy the check."""
        monkeypatch.setenv("MATTERMOST_TOKEN", "test-token")
        monkeypatch.setenv("MATTERMOST_URL", "https://mm.example.com")
        assert self._check() is True

    def test_check_requirements_without_token(self, monkeypatch):
        """With neither variable set the check fails."""
        for var_name in ("MATTERMOST_TOKEN", "MATTERMOST_URL"):
            monkeypatch.delenv(var_name, raising=False)
        assert self._check() is False

    def test_check_requirements_without_url(self, monkeypatch):
        """A token without a URL is not enough."""
        monkeypatch.setenv("MATTERMOST_TOKEN", "test-token")
        monkeypatch.delenv("MATTERMOST_URL", raising=False)
        assert self._check() is False
# ---------------------------------------------------------------------------
# Media type propagation (MIME types, not bare strings)
# ---------------------------------------------------------------------------
class TestMattermostMediaTypes:
    """Verify that media_types contains actual MIME types (e.g. 'image/png')
    rather than bare category strings ('image'), so downstream
    ``mtype.startswith("image/")`` checks in run.py work correctly."""

    def setup_method(self):
        self.adapter = _make_adapter()
        self.adapter._bot_user_id = "bot_user_id"
        self.adapter.handle_message = AsyncMock()

    def _make_event(self, file_ids):
        """Build a 'posted' event whose post references the given file ids."""
        post = {
            "id": "post_media",
            "user_id": "user_123",
            "channel_id": "chan_456",
            "message": "@bot_user_id file attached",
            "file_ids": file_ids,
        }
        data = {
            "post": json.dumps(post),
            "channel_type": "O",
            "sender_name": "@alice",
        }
        return {"event": "posted", "data": data}

    def _stub_attachment(self, file_info, payload):
        """Mock the file-info lookup and the download that returns *payload*."""
        self.adapter._api_get = AsyncMock(return_value=file_info)
        download = AsyncMock()
        download.status = 200
        download.read = AsyncMock(return_value=payload)
        download.__aenter__ = AsyncMock(return_value=download)
        download.__aexit__ = AsyncMock(return_value=False)
        self.adapter._session = MagicMock()
        self.adapter._session.get = MagicMock(return_value=download)

    def _captured_media_types(self):
        """Return media_types of the message handed to the mocked handle_message."""
        positional, _ = self.adapter.handle_message.call_args
        return positional[0].media_types

    @pytest.mark.asyncio
    async def test_image_media_type_is_full_mime(self):
        """An image attachment should produce 'image/png', not 'image'."""
        self._stub_attachment({"name": "photo.png", "mime_type": "image/png"}, b"\x89PNG fake")
        with patch("gateway.platforms.base.cache_image_from_bytes", return_value="/tmp/photo.png"):
            await self.adapter._handle_ws_event(self._make_event(["file1"]))
        media_types = self._captured_media_types()
        assert media_types == ["image/png"]
        assert media_types[0].startswith("image/")

    @pytest.mark.asyncio
    async def test_audio_media_type_is_full_mime(self):
        """An audio attachment should produce 'audio/ogg', not 'audio'."""
        self._stub_attachment({"name": "voice.ogg", "mime_type": "audio/ogg"}, b"OGG fake")
        with patch("gateway.platforms.base.cache_audio_from_bytes", return_value="/tmp/voice.ogg"):
            with patch("gateway.platforms.base.cache_image_from_bytes"):
                with patch("gateway.platforms.base.cache_document_from_bytes"):
                    await self.adapter._handle_ws_event(self._make_event(["file2"]))
        media_types = self._captured_media_types()
        assert media_types == ["audio/ogg"]
        assert media_types[0].startswith("audio/")

    @pytest.mark.asyncio
    async def test_document_media_type_is_full_mime(self):
        """A document attachment should produce 'application/pdf', not 'document'."""
        self._stub_attachment({"name": "report.pdf", "mime_type": "application/pdf"}, b"PDF fake")
        with patch("gateway.platforms.base.cache_document_from_bytes", return_value="/tmp/report.pdf"):
            with patch("gateway.platforms.base.cache_image_from_bytes"):
                await self.adapter._handle_ws_event(self._make_event(["file3"]))
        media_types = self._captured_media_types()
        assert media_types == ["application/pdf"]
        assert not media_types[0].startswith("image/")
        assert not media_types[0].startswith("audio/")

View file

@ -0,0 +1,184 @@
"""
Tests for MEDIA tag extraction from tool results.
Verifies that MEDIA tags (e.g., from TTS tool) are only extracted from
messages in the CURRENT turn, not from the full conversation history.
This prevents voice messages from accumulating and being sent multiple
times per reply. (Regression test for #160)
"""
import pytest
import re
def extract_media_tags_fixed(result_messages, history_len):
    """
    Extract MEDIA tags from tool results, but ONLY from new messages
    (those added after history_len). This is the fixed behavior.

    Messages whose role is neither "tool" nor "function" are ignored, as are
    tool messages whose content is missing, None, or not a string.

    Args:
        result_messages: Full list of messages including history + new
        history_len: Length of history before this turn

    Returns:
        Tuple of (media_tags list, has_voice_directive bool)
    """
    media_tags = []
    has_voice_directive = False

    # Slicing past the end of the list yields [], so no length guard is needed.
    for msg in result_messages[history_len:]:
        if msg.get("role") not in ("tool", "function"):
            continue
        content = msg.get("content")
        if not isinstance(content, str) or "MEDIA:" not in content:
            # A None or structured payload has no text to scan for tags.
            continue
        for match in re.finditer(r'MEDIA:(\S+)', content):
            # Tags are often embedded in JSON, so drop trailing quote/brace noise.
            path = match.group(1).strip().rstrip('",}')
            if path:
                media_tags.append(f"MEDIA:{path}")
        if "[[audio_as_voice]]" in content:
            has_voice_directive = True

    return media_tags, has_voice_directive
def extract_media_tags_broken(result_messages):
    """
    The BROKEN behavior: extract MEDIA tags from ALL messages including history.
    This causes TTS voice messages to accumulate and be re-sent on every reply.

    Tool messages whose content is missing, None, or not a string are skipped
    instead of raising.

    Args:
        result_messages: Full list of messages (history and current turn alike)

    Returns:
        Tuple of (media_tags list, has_voice_directive bool)
    """
    media_tags = []
    has_voice_directive = False

    for msg in result_messages:
        if msg.get("role") not in ("tool", "function"):
            continue
        content = msg.get("content")
        if not isinstance(content, str) or "MEDIA:" not in content:
            # A None or structured payload has no text to scan for tags.
            continue
        for match in re.finditer(r'MEDIA:(\S+)', content):
            # Tags are often embedded in JSON, so drop trailing quote/brace noise.
            path = match.group(1).strip().rstrip('",}')
            if path:
                media_tags.append(f"MEDIA:{path}")
        if "[[audio_as_voice]]" in content:
            has_voice_directive = True

    return media_tags, has_voice_directive
class TestMediaExtraction:
    """Tests for MEDIA tag extraction from tool results."""

    def test_media_tags_not_extracted_from_history(self):
        """MEDIA tags from previous turns should NOT be extracted again."""
        # An earlier turn in which the TTS tool produced a voice message.
        earlier = [
            {"role": "user", "content": "Say hello as audio"},
            {"role": "assistant", "content": None, "tool_calls": [{"id": "1", "function": {"name": "text_to_speech"}}]},
            {"role": "tool", "tool_call_id": "1", "content": '{"success": true, "media_tag": "[[audio_as_voice]]\\nMEDIA:/path/to/audio1.ogg"}'},
            {"role": "assistant", "content": "I've said hello for you!"},
        ]
        # The current turn is a plain question with no tool call.
        current = [
            {"role": "user", "content": "What time is it?"},
            {"role": "assistant", "content": "It's 3:30 AM."},
        ]
        conversation = [*earlier, *current]

        # Fixed extraction looks only past the history boundary: nothing to find.
        tags, voice_directive = extract_media_tags_fixed(conversation, len(earlier))
        assert tags == [], "Fixed extraction should not find tags in history"
        assert voice_directive is False

        # Broken extraction scans everything and resurrects the old tag.
        broken_tags, broken_voice = extract_media_tags_broken(conversation)
        assert len(broken_tags) == 1, "Broken extraction finds tags in history"
        assert "audio1.ogg" in broken_tags[0]

    def test_media_tags_extracted_from_current_turn(self):
        """MEDIA tags from the current turn SHOULD be extracted."""
        # No TTS call in the history.
        earlier = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
        ]
        # The current turn contains the TTS call.
        current = [
            {"role": "user", "content": "Say goodbye as audio"},
            {"role": "assistant", "content": None, "tool_calls": [{"id": "2", "function": {"name": "text_to_speech"}}]},
            {"role": "tool", "tool_call_id": "2", "content": '{"success": true, "media_tag": "[[audio_as_voice]]\\nMEDIA:/path/to/audio2.ogg"}'},
            {"role": "assistant", "content": "I've said goodbye!"},
        ]
        conversation = [*earlier, *current]

        tags, voice_directive = extract_media_tags_fixed(conversation, len(earlier))
        assert len(tags) == 1, "Should extract media tag from current turn"
        assert "audio2.ogg" in tags[0]
        assert voice_directive is True

    def test_multiple_tts_calls_in_history_not_accumulated(self):
        """Multiple TTS calls in history should NOT accumulate in new responses."""
        # Three earlier turns, each with its own TTS result.
        earlier = [
            {"role": "user", "content": "Say hello"},
            {"role": "tool", "tool_call_id": "1", "content": 'MEDIA:/audio/hello.ogg'},
            {"role": "assistant", "content": "Done!"},
            {"role": "user", "content": "Say goodbye"},
            {"role": "tool", "tool_call_id": "2", "content": 'MEDIA:/audio/goodbye.ogg'},
            {"role": "assistant", "content": "Done!"},
            {"role": "user", "content": "Say thanks"},
            {"role": "tool", "tool_call_id": "3", "content": 'MEDIA:/audio/thanks.ogg'},
            {"role": "assistant", "content": "Done!"},
        ]
        # The current turn has no TTS call.
        current = [
            {"role": "user", "content": "What time is it?"},
            {"role": "assistant", "content": "3 PM"},
        ]
        conversation = [*earlier, *current]

        tags, _ = extract_media_tags_fixed(conversation, len(earlier))
        assert tags == [], "Should not accumulate tags from history"

        # The broken variant returns every tag from the three earlier turns.
        broken_tags, _ = extract_media_tags_broken(conversation)
        assert len(broken_tags) == 3, "Broken version accumulates all history tags"

    def test_deduplication_within_current_turn(self):
        """Multiple MEDIA tags in current turn should be deduplicated."""
        # A single turn (empty history) where two tool calls yield the same file.
        current = [
            {"role": "user", "content": "Multiple TTS"},
            {"role": "tool", "tool_call_id": "1", "content": 'MEDIA:/audio/same.ogg'},
            {"role": "tool", "tool_call_id": "2", "content": 'MEDIA:/audio/same.ogg'},  # duplicate
            {"role": "tool", "tool_call_id": "3", "content": 'MEDIA:/audio/different.ogg'},
            {"role": "assistant", "content": "Done!"},
        ]
        conversation = [] + current

        tags, _ = extract_media_tags_fixed(conversation, 0)
        # Extraction itself is raw: duplicates are still present at this point.
        assert len(tags) == 3  # Raw extraction gets all

        # Order-preserving dedup applied by the caller after extraction.
        seen = set()
        unique = [t for t in tags if t not in seen and not seen.add(t)]
        assert len(unique) == 2  # After dedup: same.ogg and different.ogg
# Allow running this test module directly as a script, in verbose mode.
if __name__ == "__main__":
    pytest.main(["-v", __file__])

View file

@ -0,0 +1,229 @@
"""Tests for gateway/mirror.py — session mirroring."""
import json
from pathlib import Path
from unittest.mock import patch, MagicMock
import gateway.mirror as mirror_mod
from gateway.mirror import (
mirror_to_session,
_find_session_id,
_append_to_jsonl,
)
def _setup_sessions(tmp_path, sessions_data):
"""Helper to write a fake sessions.json and patch module-level paths."""
sessions_dir = tmp_path / "sessions"
sessions_dir.mkdir(parents=True, exist_ok=True)
index_file = sessions_dir / "sessions.json"
index_file.write_text(json.dumps(sessions_data))
return sessions_dir, index_file
class TestFindSessionId:
    """_find_session_id looks sessions up by platform, chat id and thread id."""

    @staticmethod
    def _entry(session_id, platform, chat_id, updated_at, thread_id=None):
        """Build one sessions-index record; thread_id is included only when given."""
        origin = {"platform": platform, "chat_id": chat_id}
        if thread_id is not None:
            origin["thread_id"] = thread_id
        return {"session_id": session_id, "origin": origin, "updated_at": updated_at}

    def test_finds_matching_session(self, tmp_path):
        """A single entry with matching platform and chat id is returned."""
        sessions_dir, index_file = _setup_sessions(tmp_path, {
            "agent:main:telegram:dm": self._entry(
                "sess_abc", "telegram", "12345", "2026-01-01T00:00:00"
            ),
        })
        with patch.object(mirror_mod, "_SESSIONS_DIR", sessions_dir):
            with patch.object(mirror_mod, "_SESSIONS_INDEX", index_file):
                found = _find_session_id("telegram", "12345")
        assert found == "sess_abc"

    def test_returns_most_recent(self, tmp_path):
        """Of two matching entries, the one with the later updated_at wins."""
        sessions_dir, index_file = _setup_sessions(tmp_path, {
            "old": self._entry("sess_old", "telegram", "12345", "2026-01-01T00:00:00"),
            "new": self._entry("sess_new", "telegram", "12345", "2026-02-01T00:00:00"),
        })
        with patch.object(mirror_mod, "_SESSIONS_DIR", sessions_dir):
            with patch.object(mirror_mod, "_SESSIONS_INDEX", index_file):
                found = _find_session_id("telegram", "12345")
        assert found == "sess_new"

    def test_thread_id_disambiguates_same_chat(self, tmp_path):
        """With thread_id given, the matching topic is chosen even if it is older."""
        sessions_dir, index_file = _setup_sessions(tmp_path, {
            "topic_a": self._entry(
                "sess_topic_a", "telegram", "-1001", "2026-01-01T00:00:00", thread_id="10"
            ),
            "topic_b": self._entry(
                "sess_topic_b", "telegram", "-1001", "2026-02-01T00:00:00", thread_id="11"
            ),
        })
        with patch.object(mirror_mod, "_SESSIONS_DIR", sessions_dir):
            with patch.object(mirror_mod, "_SESSIONS_INDEX", index_file):
                found = _find_session_id("telegram", "-1001", thread_id="10")
        assert found == "sess_topic_a"

    def test_no_match_returns_none(self, tmp_path):
        """An index holding only another platform's session yields None."""
        _, index_file = _setup_sessions(tmp_path, {
            "sess": self._entry("sess_1", "discord", "999", "2026-01-01T00:00:00"),
        })
        with patch.object(mirror_mod, "_SESSIONS_INDEX", index_file):
            found = _find_session_id("telegram", "12345")
        assert found is None

    def test_missing_sessions_file(self, tmp_path):
        """A non-existent index file yields None."""
        absent = tmp_path / "nope.json"
        with patch.object(mirror_mod, "_SESSIONS_INDEX", absent):
            found = _find_session_id("telegram", "12345")
        assert found is None

    def test_platform_case_insensitive(self, tmp_path):
        """A stored platform of 'Telegram' matches a lookup for 'telegram'."""
        _, index_file = _setup_sessions(tmp_path, {
            "s1": self._entry("sess_1", "Telegram", "123", "2026-01-01T00:00:00"),
        })
        with patch.object(mirror_mod, "_SESSIONS_INDEX", index_file):
            found = _find_session_id("telegram", "123")
        assert found == "sess_1"
class TestAppendToJsonl:
    """_append_to_jsonl writes one JSON line per message to <session_id>.jsonl."""

    @staticmethod
    def _transcript_lines(sessions_dir, session_id):
        """Return the lines of the session's transcript file."""
        text = (sessions_dir / f"{session_id}.jsonl").read_text()
        return text.strip().splitlines()

    def test_appends_message(self, tmp_path):
        """A single append produces one line holding the message."""
        sessions_dir = tmp_path / "sessions"
        sessions_dir.mkdir()
        with patch.object(mirror_mod, "_SESSIONS_DIR", sessions_dir):
            _append_to_jsonl("sess_1", {"role": "assistant", "content": "Hello"})

        written = self._transcript_lines(sessions_dir, "sess_1")
        assert len(written) == 1
        record = json.loads(written[0])
        assert record["role"] == "assistant"
        assert record["content"] == "Hello"

    def test_appends_multiple_messages(self, tmp_path):
        """Two appends to the same session produce two lines."""
        sessions_dir = tmp_path / "sessions"
        sessions_dir.mkdir()
        with patch.object(mirror_mod, "_SESSIONS_DIR", sessions_dir):
            for body in ("msg1", "msg2"):
                _append_to_jsonl("sess_1", {"role": "assistant", "content": body})

        assert len(self._transcript_lines(sessions_dir, "sess_1")) == 2
class TestMirrorToSession:
    """mirror_to_session appends a mirrored message to the matching session."""

    def test_successful_mirror(self, tmp_path):
        """A matching session gets an assistant message flagged as a mirror."""
        index = {
            "s1": {
                "session_id": "sess_abc",
                "origin": {"platform": "telegram", "chat_id": "12345"},
                "updated_at": "2026-01-01T00:00:00",
            }
        }
        sessions_dir, index_file = _setup_sessions(tmp_path, index)
        # SQLite persistence is stubbed out; only the JSONL transcript is inspected.
        with patch.object(mirror_mod, "_SESSIONS_DIR", sessions_dir):
            with patch.object(mirror_mod, "_SESSIONS_INDEX", index_file):
                with patch("gateway.mirror._append_to_sqlite"):
                    mirrored = mirror_to_session("telegram", "12345", "Hello!", source_label="cli")
        assert mirrored is True

        transcript = sessions_dir / "sess_abc.jsonl"
        assert transcript.exists()
        record = json.loads(transcript.read_text().strip())
        assert record["content"] == "Hello!"
        assert record["role"] == "assistant"
        assert record["mirror"] is True
        assert record["mirror_source"] == "cli"

    def test_successful_mirror_uses_thread_id(self, tmp_path):
        """With thread_id given, only the matching topic's transcript is written."""
        index = {
            "topic_a": {
                "session_id": "sess_topic_a",
                "origin": {"platform": "telegram", "chat_id": "-1001", "thread_id": "10"},
                "updated_at": "2026-01-01T00:00:00",
            },
            "topic_b": {
                "session_id": "sess_topic_b",
                "origin": {"platform": "telegram", "chat_id": "-1001", "thread_id": "11"},
                "updated_at": "2026-02-01T00:00:00",
            },
        }
        sessions_dir, index_file = _setup_sessions(tmp_path, index)
        with patch.object(mirror_mod, "_SESSIONS_DIR", sessions_dir):
            with patch.object(mirror_mod, "_SESSIONS_INDEX", index_file):
                with patch("gateway.mirror._append_to_sqlite"):
                    mirrored = mirror_to_session(
                        "telegram", "-1001", "Hello topic!", source_label="cron", thread_id="10"
                    )
        assert mirrored is True
        assert (sessions_dir / "sess_topic_a.jsonl").exists()
        assert not (sessions_dir / "sess_topic_b.jsonl").exists()

    def test_no_matching_session(self, tmp_path):
        """An empty index means nothing to mirror into: the call returns False."""
        sessions_dir, index_file = _setup_sessions(tmp_path, {})
        with patch.object(mirror_mod, "_SESSIONS_DIR", sessions_dir):
            with patch.object(mirror_mod, "_SESSIONS_INDEX", index_file):
                mirrored = mirror_to_session("telegram", "99999", "Hello!")
        assert mirrored is False

    def test_error_returns_false(self, tmp_path):
        """An exception raised during session lookup is reported as False."""
        failing_lookup = patch("gateway.mirror._find_session_id", side_effect=Exception("boom"))
        with failing_lookup:
            mirrored = mirror_to_session("telegram", "123", "msg")
        assert mirrored is False
class TestAppendToSqlite:
    """``_append_to_sqlite`` must always close the SessionDB it opens."""

    def test_connection_is_closed_after_use(self, tmp_path):
        """Verify _append_to_sqlite closes the SessionDB connection."""
        from gateway.mirror import _append_to_sqlite

        db = MagicMock()
        message = {"role": "assistant", "content": "hello"}

        with patch("hermes_state.SessionDB", return_value=db):
            _append_to_sqlite("sess_1", message)

        db.append_message.assert_called_once()
        db.close.assert_called_once()

    def test_connection_closed_even_on_error(self, tmp_path):
        """Verify connection is closed even when append_message raises."""
        from gateway.mirror import _append_to_sqlite

        db = MagicMock()
        db.append_message.side_effect = Exception("db error")
        message = {"role": "assistant", "content": "hello"}

        with patch("hermes_state.SessionDB", return_value=db):
            _append_to_sqlite("sess_1", message)

        db.close.assert_called_once()

View file

@ -0,0 +1,356 @@
"""Tests for gateway/pairing.py — DM pairing security system."""
import json
import os
import time
from pathlib import Path
from unittest.mock import patch
from gateway.pairing import (
PairingStore,
ALPHABET,
CODE_LENGTH,
CODE_TTL_SECONDS,
RATE_LIMIT_SECONDS,
MAX_PENDING_PER_PLATFORM,
MAX_FAILED_ATTEMPTS,
LOCKOUT_SECONDS,
_secure_write,
)
def _make_store(tmp_path):
    """Build a PairingStore while ``gateway.pairing.PAIRING_DIR`` is patched to *tmp_path*.

    The patch is active only during construction; it is undone before the
    store is returned.
    """
    pairing_dir_patch = patch("gateway.pairing.PAIRING_DIR", tmp_path)
    with pairing_dir_patch:
        store = PairingStore()
    return store
# ---------------------------------------------------------------------------
# _secure_write
# ---------------------------------------------------------------------------
class TestSecureWrite:
    """Tests for the ``_secure_write`` helper."""

    def test_creates_parent_dirs(self, tmp_path):
        """Missing parent directories are created and the payload is written."""
        destination = tmp_path / "sub" / "dir" / "file.json"

        _secure_write(destination, '{"hello": "world"}')

        assert destination.exists()
        payload = json.loads(destination.read_text())
        assert payload == {"hello": "world"}

    def test_sets_file_permissions(self, tmp_path):
        """The written file carries permission bits 0o600."""
        destination = tmp_path / "secret.json"

        _secure_write(destination, "data")

        permission_bits = destination.stat().st_mode & 0o777
        assert oct(permission_bits) == "0o600"
# ---------------------------------------------------------------------------
# Code generation
# ---------------------------------------------------------------------------
class TestCodeGeneration:
    """Generated pairing codes have the expected shape and are tracked as pending."""

    def test_code_format(self, tmp_path):
        """A code is a string of CODE_LENGTH characters, all drawn from ALPHABET."""
        with patch("gateway.pairing.PAIRING_DIR", tmp_path):
            store = PairingStore()
            code = store.generate_code("telegram", "user1", "Alice")

        # One combined type/length check is enough; it also fails clearly on a
        # None result before the per-character ALPHABET check iterates the code.
        assert isinstance(code, str) and len(code) == CODE_LENGTH
        assert all(c in ALPHABET for c in code)

    def test_code_uniqueness(self, tmp_path):
        """Multiple codes for different users should be distinct."""
        with patch("gateway.pairing.PAIRING_DIR", tmp_path):
            store = PairingStore()
            codes = set()
            for i in range(3):
                code = store.generate_code("telegram", f"user{i}")
                assert isinstance(code, str) and len(code) == CODE_LENGTH
                codes.add(code)

        assert len(codes) == 3

    def test_stores_pending_entry(self, tmp_path):
        """Generating a code records code, user id and user name as pending."""
        with patch("gateway.pairing.PAIRING_DIR", tmp_path):
            store = PairingStore()
            code = store.generate_code("telegram", "user1", "Alice")
            pending = store.list_pending("telegram")

        assert len(pending) == 1
        assert pending[0]["code"] == code
        assert pending[0]["user_id"] == "user1"
        assert pending[0]["user_name"] == "Alice"
# ---------------------------------------------------------------------------
# Rate limiting
# ---------------------------------------------------------------------------
class TestRateLimiting:
    """Per-user rate limiting of ``generate_code``."""

    def test_same_user_rate_limited(self, tmp_path):
        """A second immediate request from the same user yields ``None``."""
        with patch("gateway.pairing.PAIRING_DIR", tmp_path):
            store = PairingStore()
            first = store.generate_code("telegram", "user1")
            second = store.generate_code("telegram", "user1")

        assert isinstance(first, str) and len(first) == CODE_LENGTH
        assert second is None  # rate limited

    def test_different_users_not_rate_limited(self, tmp_path):
        """Requests from two distinct users both receive a code."""
        with patch("gateway.pairing.PAIRING_DIR", tmp_path):
            store = PairingStore()
            first = store.generate_code("telegram", "user1")
            second = store.generate_code("telegram", "user2")

        for code in (first, second):
            assert isinstance(code, str) and len(code) == CODE_LENGTH

    def test_rate_limit_expires(self, tmp_path):
        """Once the stored timestamp is old enough, the same user gets a new code."""
        with patch("gateway.pairing.PAIRING_DIR", tmp_path):
            store = PairingStore()
            first = store.generate_code("telegram", "user1")
            assert isinstance(first, str) and len(first) == CODE_LENGTH

            # Simulate rate limit expiry
            limits = store._load_json(store._rate_limit_path())
            expired_at = time.time() - RATE_LIMIT_SECONDS - 1
            limits["telegram:user1"] = expired_at
            store._save_json(store._rate_limit_path(), limits)

            second = store.generate_code("telegram", "user1")

        assert isinstance(second, str) and len(second) == CODE_LENGTH
        assert second != first
# ---------------------------------------------------------------------------
# Max pending limit
# ---------------------------------------------------------------------------
class TestMaxPending:
    """The cap on simultaneously pending codes is enforced per platform."""

    def test_max_pending_per_platform(self, tmp_path):
        """The request after MAX_PENDING_PER_PLATFORM successes returns ``None``."""
        with patch("gateway.pairing.PAIRING_DIR", tmp_path):
            store = PairingStore()
            codes = [
                store.generate_code("telegram", f"user{i}")
                for i in range(MAX_PENDING_PER_PLATFORM + 1)
            ]

        # First MAX_PENDING_PER_PLATFORM should succeed
        accepted = codes[:MAX_PENDING_PER_PLATFORM]
        assert all(isinstance(c, str) and len(c) == CODE_LENGTH for c in accepted)
        # Next one should be blocked
        assert codes[MAX_PENDING_PER_PLATFORM] is None

    def test_different_platforms_independent(self, tmp_path):
        """Filling up one platform does not block another platform."""
        with patch("gateway.pairing.PAIRING_DIR", tmp_path):
            store = PairingStore()
            for i in range(MAX_PENDING_PER_PLATFORM):
                store.generate_code("telegram", f"user{i}")

            # Different platform should still work
            discord_code = store.generate_code("discord", "user0")

        assert isinstance(discord_code, str) and len(discord_code) == CODE_LENGTH
# ---------------------------------------------------------------------------
# Approval flow
# ---------------------------------------------------------------------------
class TestApprovalFlow:
    """Approving pairing codes and the resulting approved/pending state."""

    def test_approve_valid_code(self, tmp_path):
        """Approving a fresh code returns a dict naming the user."""
        with patch("gateway.pairing.PAIRING_DIR", tmp_path):
            store = PairingStore()
            code = store.generate_code("telegram", "user1", "Alice")
            approval = store.approve_code("telegram", code)

        assert isinstance(approval, dict)
        assert "user_id" in approval
        assert "user_name" in approval
        assert approval["user_id"] == "user1"
        assert approval["user_name"] == "Alice"

    def test_approved_user_is_approved(self, tmp_path):
        """After approval, ``is_approved`` reports ``True`` for that user."""
        with patch("gateway.pairing.PAIRING_DIR", tmp_path):
            store = PairingStore()
            code = store.generate_code("telegram", "user1", "Alice")
            store.approve_code("telegram", code)
            approved = store.is_approved("telegram", "user1")

        assert approved is True

    def test_unapproved_user_not_approved(self, tmp_path):
        """A user who never paired is not approved."""
        with patch("gateway.pairing.PAIRING_DIR", tmp_path):
            store = PairingStore()
            approved = store.is_approved("telegram", "nonexistent")

        assert approved is False

    def test_approve_removes_from_pending(self, tmp_path):
        """An approved code no longer shows up in the pending list."""
        with patch("gateway.pairing.PAIRING_DIR", tmp_path):
            store = PairingStore()
            code = store.generate_code("telegram", "user1")
            store.approve_code("telegram", code)
            still_pending = store.list_pending("telegram")

        assert len(still_pending) == 0

    def test_approve_case_insensitive(self, tmp_path):
        """A lower-cased copy of the code is accepted."""
        with patch("gateway.pairing.PAIRING_DIR", tmp_path):
            store = PairingStore()
            code = store.generate_code("telegram", "user1", "Alice")
            approval = store.approve_code("telegram", code.lower())

        assert isinstance(approval, dict)
        assert approval["user_id"] == "user1"
        assert approval["user_name"] == "Alice"

    def test_approve_strips_whitespace(self, tmp_path):
        """Leading and trailing spaces around the code are ignored."""
        with patch("gateway.pairing.PAIRING_DIR", tmp_path):
            store = PairingStore()
            code = store.generate_code("telegram", "user1", "Alice")
            padded = f" {code} "
            approval = store.approve_code("telegram", padded)

        assert isinstance(approval, dict)
        assert approval["user_id"] == "user1"
        assert approval["user_name"] == "Alice"

    def test_invalid_code_returns_none(self, tmp_path):
        """An unknown code yields ``None``."""
        with patch("gateway.pairing.PAIRING_DIR", tmp_path):
            store = PairingStore()
            approval = store.approve_code("telegram", "INVALIDCODE")

        assert approval is None
# ---------------------------------------------------------------------------
# Lockout after failed attempts
# ---------------------------------------------------------------------------
class TestLockout:
    """Platform lockout after MAX_FAILED_ATTEMPTS wrong approval attempts."""

    def test_lockout_after_max_failures(self, tmp_path):
        """Enough wrong codes flip ``_is_locked_out`` to ``True``."""
        with patch("gateway.pairing.PAIRING_DIR", tmp_path):
            store = PairingStore()
            # Generate a valid code so platform has data
            store.generate_code("telegram", "user1")

            # Exhaust failed attempts
            attempts = 0
            while attempts < MAX_FAILED_ATTEMPTS:
                store.approve_code("telegram", "WRONGCODE")
                attempts += 1

            # Platform should now be locked out — can't generate new codes
            locked = store._is_locked_out("telegram")

        assert locked is True

    def test_lockout_blocks_code_generation(self, tmp_path):
        """While locked out, ``generate_code`` returns ``None``."""
        with patch("gateway.pairing.PAIRING_DIR", tmp_path):
            store = PairingStore()
            for _attempt in range(MAX_FAILED_ATTEMPTS):
                store.approve_code("telegram", "WRONG")
            code = store.generate_code("telegram", "newuser")

        assert code is None

    def test_lockout_expires(self, tmp_path):
        """A lockout timestamp in the past no longer counts as locked out."""
        with patch("gateway.pairing.PAIRING_DIR", tmp_path):
            store = PairingStore()
            for _attempt in range(MAX_FAILED_ATTEMPTS):
                store.approve_code("telegram", "WRONG")

            # Simulate lockout expiry
            limits = store._load_json(store._rate_limit_path())
            limits["_lockout:telegram"] = time.time() - 1  # expired
            store._save_json(store._rate_limit_path(), limits)

            locked = store._is_locked_out("telegram")

        assert locked is False
# ---------------------------------------------------------------------------
# Code expiry
# ---------------------------------------------------------------------------
class TestCodeExpiry:
    """Codes older than CODE_TTL_SECONDS are dropped and cannot be approved."""

    def test_expired_codes_cleaned_up(self, tmp_path):
        """Listing pending codes omits an entry whose ``created_at`` is too old."""
        with patch("gateway.pairing.PAIRING_DIR", tmp_path):
            store = PairingStore()
            code = store.generate_code("telegram", "user1")

            # Manually expire the code
            entries = store._load_json(store._pending_path("telegram"))
            entries[code]["created_at"] = time.time() - CODE_TTL_SECONDS - 1
            store._save_json(store._pending_path("telegram"), entries)

            # Cleanup happens on next operation
            remaining = store.list_pending("telegram")

        assert len(remaining) == 0

    def test_expired_code_cannot_be_approved(self, tmp_path):
        """Approving an expired code returns ``None``."""
        with patch("gateway.pairing.PAIRING_DIR", tmp_path):
            store = PairingStore()
            code = store.generate_code("telegram", "user1")

            # Expire it
            entries = store._load_json(store._pending_path("telegram"))
            entries[code]["created_at"] = time.time() - CODE_TTL_SECONDS - 1
            store._save_json(store._pending_path("telegram"), entries)

            approval = store.approve_code("telegram", code)

        assert approval is None
# ---------------------------------------------------------------------------
# Revoke
# ---------------------------------------------------------------------------
class TestRevoke:
    """Revoking approved users."""

    def test_revoke_approved_user(self, tmp_path):
        """Revoking an approved user returns ``True`` and clears the approval."""
        with patch("gateway.pairing.PAIRING_DIR", tmp_path):
            store = PairingStore()
            code = store.generate_code("telegram", "user1", "Alice")
            store.approve_code("telegram", code)
            assert store.is_approved("telegram", "user1") is True

            was_revoked = store.revoke("telegram", "user1")
            assert was_revoked is True
            assert store.is_approved("telegram", "user1") is False

    def test_revoke_nonexistent_returns_false(self, tmp_path):
        """Revoking a user who was never approved returns ``False``."""
        with patch("gateway.pairing.PAIRING_DIR", tmp_path):
            store = PairingStore()
            was_revoked = store.revoke("telegram", "nobody")

        assert was_revoked is False
# ---------------------------------------------------------------------------
# List & clear
# ---------------------------------------------------------------------------
class TestListAndClear:
    """Listing approved users and clearing pending codes."""

    def test_list_approved(self, tmp_path):
        """A single approved user is listed with user id and platform."""
        with patch("gateway.pairing.PAIRING_DIR", tmp_path):
            store = PairingStore()
            code = store.generate_code("telegram", "user1", "Alice")
            store.approve_code("telegram", code)
            approved = store.list_approved("telegram")

        assert len(approved) == 1
        only = approved[0]
        assert only["user_id"] == "user1"
        assert only["platform"] == "telegram"

    def test_list_approved_all_platforms(self, tmp_path):
        """Without a platform argument, approvals from every platform are listed."""
        with patch("gateway.pairing.PAIRING_DIR", tmp_path):
            store = PairingStore()
            telegram_code = store.generate_code("telegram", "user1")
            store.approve_code("telegram", telegram_code)
            discord_code = store.generate_code("discord", "user2")
            store.approve_code("discord", discord_code)
            approved = store.list_approved()

        assert len(approved) == 2

    def test_clear_pending(self, tmp_path):
        """Clearing one platform reports the number removed and leaves none pending."""
        with patch("gateway.pairing.PAIRING_DIR", tmp_path):
            store = PairingStore()
            store.generate_code("telegram", "user1")
            store.generate_code("telegram", "user2")
            cleared = store.clear_pending("telegram")
            remaining = store.list_pending("telegram")

        assert cleared == 2
        assert len(remaining) == 0

    def test_clear_pending_all_platforms(self, tmp_path):
        """Clearing without a platform counts pending codes across platforms."""
        with patch("gateway.pairing.PAIRING_DIR", tmp_path):
            store = PairingStore()
            store.generate_code("telegram", "user1")
            store.generate_code("discord", "user2")
            cleared = store.clear_pending()

        assert cleared == 2

View file

@ -0,0 +1,156 @@
"""Tests for PII redaction in gateway session context prompts."""
from gateway.session import (
SessionContext,
SessionSource,
build_session_context_prompt,
_hash_id,
_hash_sender_id,
_hash_chat_id,
_looks_like_phone,
)
from gateway.config import Platform, HomeChannel
# ---------------------------------------------------------------------------
# Low-level helpers
# ---------------------------------------------------------------------------
class TestHashHelpers:
    """Unit tests for the hashing and phone-detection helpers."""

    def test_hash_id_deterministic(self):
        """Hashing the same input twice gives the same digest."""
        first = _hash_id("12345")
        second = _hash_id("12345")
        assert first == second

    def test_hash_id_12_hex_chars(self):
        """The digest is 12 lowercase hexadecimal characters."""
        digest = _hash_id("user-abc")
        assert len(digest) == 12
        assert all(ch in "0123456789abcdef" for ch in digest)

    def test_hash_sender_id_prefix(self):
        """Sender hashes are ``user_`` followed by the 12-character digest."""
        assert _hash_sender_id("12345").startswith("user_")
        assert len(_hash_sender_id("12345")) == 17  # "user_" + 12

    def test_hash_chat_id_preserves_prefix(self):
        """A ``platform:`` prefix survives while the raw id does not."""
        hashed = _hash_chat_id("telegram:12345")
        assert hashed.startswith("telegram:")
        assert "12345" not in hashed

    def test_hash_chat_id_no_prefix(self):
        """An unprefixed chat id becomes a bare 12-character digest."""
        hashed = _hash_chat_id("12345")
        assert len(hashed) == 12
        assert "12345" not in hashed

    def test_looks_like_phone(self):
        """Phone-shaped strings are recognised; names, ids and '' are not."""
        for candidate in ("+15551234567", "15551234567", "+1-555-123-4567"):
            assert _looks_like_phone(candidate)
        for candidate in ("alice", "user-123", ""):
            assert not _looks_like_phone(candidate)
# ---------------------------------------------------------------------------
# Integration: build_session_context_prompt
# ---------------------------------------------------------------------------
def _make_context(
    user_id="user-123",
    user_name=None,
    chat_id="telegram:99999",
    platform=Platform.TELEGRAM,
    home_channels=None,
):
    """Build a DM SessionContext for *platform* with the given identifiers.

    ``home_channels`` falls back to an empty dict when falsy.
    """
    channels = home_channels or {}
    dm_source = SessionSource(
        platform=platform,
        chat_id=chat_id,
        chat_type="dm",
        user_id=user_id,
        user_name=user_name,
    )
    return SessionContext(
        source=dm_source,
        connected_platforms=[platform],
        home_channels=channels,
    )
class TestBuildSessionContextPromptRedaction:
    """How ``build_session_context_prompt`` treats ids with and without ``redact_pii``."""

    def test_no_redaction_by_default(self):
        """Without the flag the raw user id appears in the prompt."""
        prompt = build_session_context_prompt(_make_context(user_id="user-123"))
        assert "user-123" in prompt

    def test_user_id_hashed_when_redact_pii(self):
        """With the flag the raw id is replaced by a ``user_`` hash."""
        context = _make_context(user_id="user-123")
        prompt = build_session_context_prompt(context, redact_pii=True)
        assert "user-123" not in prompt
        assert "user_" in prompt  # hashed ID present

    def test_user_name_not_redacted(self):
        """Display names are kept even when redaction is on."""
        context = _make_context(user_id="user-123", user_name="Alice")
        prompt = build_session_context_prompt(context, redact_pii=True)
        assert "Alice" in prompt
        # user_id should not appear when user_name is present (name takes priority)
        assert "user-123" not in prompt

    def test_home_channel_id_hashed(self):
        """Home-channel chat ids are hashed; prefix and channel name remain."""
        telegram_home = HomeChannel(
            platform=Platform.TELEGRAM,
            chat_id="telegram:99999",
            name="Home Chat",
        )
        context = _make_context(home_channels={Platform.TELEGRAM: telegram_home})
        prompt = build_session_context_prompt(context, redact_pii=True)
        assert "99999" not in prompt
        assert "telegram:" in prompt  # prefix preserved
        assert "Home Chat" in prompt  # name not redacted

    def test_home_channel_id_preserved_without_redaction(self):
        """With ``redact_pii=False`` the home-channel id is shown as-is."""
        telegram_home = HomeChannel(
            platform=Platform.TELEGRAM,
            chat_id="telegram:99999",
            name="Home Chat",
        )
        context = _make_context(home_channels={Platform.TELEGRAM: telegram_home})
        prompt = build_session_context_prompt(context, redact_pii=False)
        assert "99999" in prompt

    def test_redaction_is_deterministic(self):
        """Two redacted renderings of the same context are identical."""
        context = _make_context(user_id="+15551234567")
        first = build_session_context_prompt(context, redact_pii=True)
        second = build_session_context_prompt(context, redact_pii=True)
        assert first == second

    def test_different_ids_produce_different_hashes(self):
        """Distinct user ids lead to distinct redacted prompts."""
        prompt_a = build_session_context_prompt(_make_context(user_id="user-A"), redact_pii=True)
        prompt_b = build_session_context_prompt(_make_context(user_id="user-B"), redact_pii=True)
        assert prompt_a != prompt_b

    def test_discord_ids_not_redacted_even_with_flag(self):
        """Discord needs real IDs for <@user_id> mentions."""
        context = _make_context(user_id="123456789", platform=Platform.DISCORD)
        prompt = build_session_context_prompt(context, redact_pii=True)
        assert "123456789" in prompt

    def test_whatsapp_ids_redacted(self):
        """WhatsApp phone-number ids are hashed when redaction is on."""
        context = _make_context(user_id="+15551234567", platform=Platform.WHATSAPP)
        prompt = build_session_context_prompt(context, redact_pii=True)
        assert "+15551234567" not in prompt
        assert "user_" in prompt

    def test_signal_ids_redacted(self):
        """Signal phone-number ids are hashed when redaction is on."""
        context = _make_context(user_id="+15551234567", platform=Platform.SIGNAL)
        prompt = build_session_context_prompt(context, redact_pii=True)
        assert "+15551234567" not in prompt
        assert "user_" in prompt

    def test_slack_ids_not_redacted(self):
        """Slack may need IDs for mentions too."""
        context = _make_context(user_id="U12345ABC", platform=Platform.SLACK)
        prompt = build_session_context_prompt(context, redact_pii=True)
        assert "U12345ABC" in prompt

View file

@ -0,0 +1,129 @@
"""Tests for the /plan gateway slash command."""
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from agent.skill_commands import scan_skill_commands
from gateway.config import GatewayConfig, Platform, PlatformConfig
from gateway.platforms.base import MessageEvent
from gateway.session import SessionEntry, SessionSource
def _make_runner():
    """Build a GatewayRunner without calling ``__init__``.

    Every attribute these tests rely on is assigned by hand; the session
    store is a MagicMock and ``_run_agent`` is an AsyncMock that returns a
    canned result whose ``final_response`` is ``"planned"``.
    """
    from gateway.run import GatewayRunner

    runner = object.__new__(GatewayRunner)

    telegram_config = PlatformConfig(enabled=True, token="***")
    runner.config = GatewayConfig(platforms={Platform.TELEGRAM: telegram_config})
    runner.adapters = {}
    runner._voice_mode = {}
    runner.hooks = SimpleNamespace(emit=AsyncMock(), loaded_hooks=False)

    # Session store stub: one fixed DM session with an empty transcript
    store = MagicMock()
    store.get_or_create_session.return_value = SessionEntry(
        session_key="agent:main:telegram:dm:c1:u1",
        session_id="sess-1",
        created_at=datetime.now(),
        updated_at=datetime.now(),
        platform=Platform.TELEGRAM,
        chat_type="dm",
    )
    store.load_transcript.return_value = []
    store.has_any_sessions.return_value = True
    store.append_to_transcript = MagicMock()
    store.rewrite_transcript = MagicMock()
    runner.session_store = store

    # Per-session bookkeeping and runtime settings
    runner._running_agents = {}
    runner._pending_messages = {}
    runner._pending_approvals = {}
    runner._session_db = None
    runner._reasoning_config = None
    runner._provider_routing = {}
    runner._fallback_model = None
    runner._show_reasoning = False

    # Authorisation always passes; setting the session env is a no-op
    runner._is_user_authorized = lambda _source: True
    runner._set_session_env = lambda _context: None

    canned_result = {
        "final_response": "planned",
        "messages": [],
        "tools": [],
        "history_offset": 0,
        "last_prompt_tokens": 0,
    }
    runner._run_agent = AsyncMock(return_value=canned_result)
    return runner
def _make_event(text="/plan"):
    """Return a Telegram DM MessageEvent from user ``u1`` in chat ``c1`` carrying *text*."""
    dm_source = SessionSource(
        platform=Platform.TELEGRAM,
        user_id="u1",
        chat_id="c1",
        user_name="tester",
        chat_type="dm",
    )
    return MessageEvent(text=text, source=dm_source, message_id="m1")
def _make_plan_skill(skills_dir):
    """Write a minimal ``plan`` skill (``plan/SKILL.md``) under *skills_dir*."""
    plan_dir = skills_dir / "plan"
    plan_dir.mkdir(parents=True, exist_ok=True)
    skill_file = plan_dir / "SKILL.md"
    skill_file.write_text(
        """---
name: plan
description: Plan mode skill.
---
# Plan
Use the current conversation context when no explicit instruction is provided.
Save plans under the active workspace's .hermes/plans directory.
"""
    )
class TestGatewayPlanCommand:
    """The ``/plan`` slash command is routed through the plan skill."""

    @pytest.mark.asyncio
    async def test_plan_command_loads_skill_and_runs_agent(self, monkeypatch, tmp_path):
        """``/plan <text>`` forwards the skill body plus the user text to the agent."""
        import gateway.run as gateway_run

        runner = _make_runner()
        event = _make_event("/plan Add OAuth login")

        monkeypatch.setattr(
            gateway_run,
            "_resolve_runtime_agent_kwargs",
            lambda: {"api_key": "***"},
        )
        monkeypatch.setattr(
            "agent.model_metadata.get_model_context_length",
            lambda *_args, **_kwargs: 100_000,
        )

        with patch("tools.skills_tool.SKILLS_DIR", tmp_path):
            _make_plan_skill(tmp_path)
            scan_skill_commands()
            reply = await runner._handle_message(event)

        assert reply == "planned"

        # Inspect the message that was handed to the stubbed agent
        forwarded = runner._run_agent.call_args.kwargs["message"]
        for expected in ("Plan mode skill", "Add OAuth login", ".hermes/plans"):
            assert expected in forwarded
        assert str(tmp_path / "plans") not in forwarded
        assert "active workspace/backend cwd" in forwarded
        assert "Runtime note:" in forwarded

    @pytest.mark.asyncio
    async def test_plan_command_appears_in_help_output_via_skill_listing(self, tmp_path):
        """``/help`` output mentions ``/plan`` once the skill has been scanned."""
        runner = _make_runner()
        event = _make_event("/help")

        with patch("tools.skills_tool.SKILLS_DIR", tmp_path):
            _make_plan_skill(tmp_path)
            scan_skill_commands()
            help_text = await runner._handle_help_command(event)

        assert "/plan" in help_text

View file

@ -0,0 +1,412 @@
"""Tests for gateway/platforms/base.py — MessageEvent, media extraction, message truncation."""
import os
from unittest.mock import patch
from gateway.platforms.base import (
BasePlatformAdapter,
GATEWAY_SECRET_CAPTURE_UNSUPPORTED_MESSAGE,
MessageEvent,
MessageType,
)
class TestSecretCaptureGuidance:
    """Wording of the 'secret capture unsupported' gateway message."""

    def test_gateway_secret_capture_message_points_to_local_setup(self):
        """The message mentions the local CLI and the ``~/.hermes/.env`` file."""
        lowered = GATEWAY_SECRET_CAPTURE_UNSUPPORTED_MESSAGE.lower()
        assert "local cli" in lowered
        assert "~/.hermes/.env" in GATEWAY_SECRET_CAPTURE_UNSUPPORTED_MESSAGE
# ---------------------------------------------------------------------------
# MessageEvent — command parsing
# ---------------------------------------------------------------------------
class TestMessageEventIsCommand:
    """``MessageEvent.is_command`` for slash and non-slash text."""

    def test_slash_command(self):
        """Text starting with a slash is a command."""
        assert MessageEvent(text="/new").is_command() is True

    def test_regular_text(self):
        """Plain text is not a command."""
        assert MessageEvent(text="hello world").is_command() is False

    def test_empty_text(self):
        """Empty text is not a command."""
        assert MessageEvent(text="").is_command() is False

    def test_slash_only(self):
        """A lone slash still counts as a command."""
        assert MessageEvent(text="/").is_command() is True
class TestMessageEventGetCommand:
    """``MessageEvent.get_command`` extracts the lower-cased command name."""

    def test_simple_command(self):
        """``/new`` yields ``new``."""
        command = MessageEvent(text="/new").get_command()
        assert command == "new"

    def test_command_with_args(self):
        """Only the first word after the slash is the command."""
        command = MessageEvent(text="/reset session").get_command()
        assert command == "reset"

    def test_not_a_command(self):
        """Non-command text yields ``None``."""
        command = MessageEvent(text="hello").get_command()
        assert command is None

    def test_command_is_lowercased(self):
        """Upper-case commands are normalised to lower case."""
        command = MessageEvent(text="/HELP").get_command()
        assert command == "help"

    def test_slash_only_returns_empty(self):
        """A lone slash yields the empty string."""
        command = MessageEvent(text="/").get_command()
        assert command == ""
class TestMessageEventGetCommandArgs:
    """``MessageEvent.get_command_args`` returns the text after the command."""

    def test_command_with_args(self):
        """Everything after the command word is returned."""
        args = MessageEvent(text="/new session id 123").get_command_args()
        assert args == "session id 123"

    def test_command_without_args(self):
        """A bare command has empty args."""
        args = MessageEvent(text="/new").get_command_args()
        assert args == ""

    def test_not_a_command_returns_full_text(self):
        """For non-command text the whole text is returned."""
        args = MessageEvent(text="hello world").get_command_args()
        assert args == "hello world"
# ---------------------------------------------------------------------------
# extract_images
# ---------------------------------------------------------------------------
class TestExtractImages:
    """``BasePlatformAdapter.extract_images`` on markdown and HTML image references."""

    def test_no_images(self):
        """Text without images comes back unchanged with an empty image list."""
        text = "Just regular text."
        images, cleaned = BasePlatformAdapter.extract_images(text)
        assert images == []
        assert cleaned == "Just regular text."

    def test_markdown_image_with_image_ext(self):
        """A ``.png`` markdown image is extracted and removed from the text."""
        text = "Here is a photo: ![cat](https://example.com/cat.png)"
        images, cleaned = BasePlatformAdapter.extract_images(text)
        assert len(images) == 1
        found = images[0]
        assert found[0] == "https://example.com/cat.png"
        assert found[1] == "cat"
        assert "![cat]" not in cleaned

    def test_markdown_image_jpg(self):
        """``.jpg`` URLs are recognised."""
        images, _ = BasePlatformAdapter.extract_images("![photo](https://example.com/photo.jpg)")
        assert len(images) == 1
        found = images[0]
        assert found[0] == "https://example.com/photo.jpg"
        assert found[1] == "photo"

    def test_markdown_image_jpeg(self):
        """``.jpeg`` URLs are recognised; empty alt text stays empty."""
        images, _ = BasePlatformAdapter.extract_images("![](https://example.com/photo.jpeg)")
        assert len(images) == 1
        found = images[0]
        assert found[0] == "https://example.com/photo.jpeg"
        assert found[1] == ""

    def test_markdown_image_gif(self):
        """``.gif`` URLs are recognised."""
        images, _ = BasePlatformAdapter.extract_images("![anim](https://example.com/anim.gif)")
        assert len(images) == 1
        found = images[0]
        assert found[0] == "https://example.com/anim.gif"
        assert found[1] == "anim"

    def test_markdown_image_webp(self):
        """``.webp`` URLs are recognised."""
        images, _ = BasePlatformAdapter.extract_images("![](https://example.com/img.webp)")
        assert len(images) == 1
        found = images[0]
        assert found[0] == "https://example.com/img.webp"
        assert found[1] == ""

    def test_fal_media_cdn(self):
        """A fal.media URL is extracted."""
        text = "![gen](https://fal.media/files/abc123/output.png)"
        images, _ = BasePlatformAdapter.extract_images(text)
        assert len(images) == 1
        found = images[0]
        assert found[0] == "https://fal.media/files/abc123/output.png"
        assert found[1] == "gen"

    def test_fal_cdn_url(self):
        """A fal-cdn URL without a file extension is extracted."""
        images, _ = BasePlatformAdapter.extract_images("![](https://fal-cdn.example.com/result)")
        assert len(images) == 1
        found = images[0]
        assert found[0] == "https://fal-cdn.example.com/result"
        assert found[1] == ""

    def test_replicate_delivery(self):
        """A replicate.delivery URL without a file extension is extracted."""
        text = "![](https://replicate.delivery/pbxt/abc/output)"
        images, _ = BasePlatformAdapter.extract_images(text)
        assert len(images) == 1
        found = images[0]
        assert found[0] == "https://replicate.delivery/pbxt/abc/output"
        assert found[1] == ""

    def test_non_image_ext_not_extracted(self):
        """Markdown image with non-image extension should not be extracted."""
        text = "![doc](https://example.com/report.pdf)"
        images, cleaned = BasePlatformAdapter.extract_images(text)
        assert images == []
        assert "![doc]" in cleaned  # Should be preserved

    def test_html_img_tag(self):
        """An HTML ``<img>`` tag is extracted and stripped from the text."""
        text = 'Check this: <img src="https://example.com/photo.png">'
        images, cleaned = BasePlatformAdapter.extract_images(text)
        assert len(images) == 1
        found = images[0]
        assert found[0] == "https://example.com/photo.png"
        assert found[1] == ""  # HTML images have no alt text
        assert "<img" not in cleaned

    def test_html_img_self_closing(self):
        """Self-closing ``<img .../>`` is handled."""
        images, _ = BasePlatformAdapter.extract_images('<img src="https://example.com/photo.png"/>')
        assert len(images) == 1
        found = images[0]
        assert found[0] == "https://example.com/photo.png"
        assert found[1] == ""

    def test_html_img_with_closing_tag(self):
        """``<img ...></img>`` is handled."""
        text = '<img src="https://example.com/photo.png"></img>'
        images, _ = BasePlatformAdapter.extract_images(text)
        assert len(images) == 1
        found = images[0]
        assert found[0] == "https://example.com/photo.png"
        assert found[1] == ""

    def test_multiple_images(self):
        """Every markdown image in the text is extracted and removed."""
        text = "![a](https://example.com/a.png)\n![b](https://example.com/b.jpg)"
        images, cleaned = BasePlatformAdapter.extract_images(text)
        assert len(images) == 2
        for marker in ("![a]", "![b]"):
            assert marker not in cleaned

    def test_mixed_markdown_and_html(self):
        """Markdown and HTML images in one text are both found."""
        text = '![cat](https://example.com/cat.png)\n<img src="https://example.com/dog.jpg">'
        images, _ = BasePlatformAdapter.extract_images(text)
        assert len(images) == 2

    def test_cleaned_content_trims_excess_newlines(self):
        """Removing an image does not leave three or more consecutive newlines."""
        text = "Before\n\n![img](https://example.com/img.png)\n\n\n\nAfter"
        _, cleaned = BasePlatformAdapter.extract_images(text)
        assert "\n\n\n" not in cleaned

    def test_non_http_url_not_matched(self):
        """``file://`` URLs are not treated as images."""
        images, _ = BasePlatformAdapter.extract_images("![file](file:///local/path.png)")
        assert images == []

    def test_non_image_link_preserved_when_mixed_with_images(self):
        """Regression: non-image markdown links must not be silently removed
        when the response also contains real images."""
        text = (
            "Here is the image: ![photo](https://fal.media/cat.png)\n"
            "And a doc: ![report](https://example.com/report.pdf)"
        )
        images, cleaned = BasePlatformAdapter.extract_images(text)
        assert len(images) == 1
        assert images[0][0] == "https://fal.media/cat.png"
        # The PDF link must survive in cleaned content
        assert "![report](https://example.com/report.pdf)" in cleaned
# ---------------------------------------------------------------------------
# extract_media
# ---------------------------------------------------------------------------
class TestExtractMedia:
    """``BasePlatformAdapter.extract_media`` on ``MEDIA:`` tags and the voice directive."""

    def test_no_media(self):
        """Text without tags comes back unchanged with an empty media list."""
        media, cleaned = BasePlatformAdapter.extract_media("Just text.")
        assert media == []
        assert cleaned == "Just text."

    def test_single_media_tag(self):
        """One tag yields one entry whose voice flag is ``False``."""
        media, cleaned = BasePlatformAdapter.extract_media("MEDIA:/path/to/audio.ogg")
        assert len(media) == 1
        entry = media[0]
        assert entry[0] == "/path/to/audio.ogg"
        assert entry[1] is False  # no voice tag

    def test_media_with_voice_directive(self):
        """``[[audio_as_voice]]`` sets the voice flag to ``True``."""
        text = "[[audio_as_voice]]\nMEDIA:/path/to/voice.ogg"
        media, cleaned = BasePlatformAdapter.extract_media(text)
        assert len(media) == 1
        entry = media[0]
        assert entry[0] == "/path/to/voice.ogg"
        assert entry[1] is True  # voice tag present

    def test_multiple_media_tags(self):
        """Each tag produces its own entry."""
        media, _ = BasePlatformAdapter.extract_media("MEDIA:/a.ogg\nMEDIA:/b.ogg")
        assert len(media) == 2

    def test_voice_directive_removed_from_content(self):
        """Directive and tag are stripped; ordinary text is kept."""
        text = "[[audio_as_voice]]\nSome text\nMEDIA:/voice.ogg"
        _, cleaned = BasePlatformAdapter.extract_media(text)
        assert "[[audio_as_voice]]" not in cleaned
        assert "MEDIA:" not in cleaned
        assert "Some text" in cleaned

    def test_media_with_text_before(self):
        """Text preceding the tag survives in the cleaned output."""
        text = "Here is your audio:\nMEDIA:/output.ogg"
        media, cleaned = BasePlatformAdapter.extract_media(text)
        assert len(media) == 1
        assert "Here is your audio" in cleaned

    def test_cleaned_content_trims_excess_newlines(self):
        """Removing a tag does not leave three or more consecutive newlines."""
        text = "Before\n\nMEDIA:/audio.ogg\n\n\n\nAfter"
        _, cleaned = BasePlatformAdapter.extract_media(text)
        assert "\n\n\n" not in cleaned

    def test_media_tag_allows_optional_whitespace_after_colon(self):
        """A space after ``MEDIA:`` is accepted."""
        media, cleaned = BasePlatformAdapter.extract_media("MEDIA: /path/to/audio.ogg")
        assert media == [("/path/to/audio.ogg", False)]
        assert cleaned == ""

    def test_media_tag_strips_wrapping_quotes_and_backticks(self):
        """Backticks, double quotes and single quotes around a path are removed."""
        text = "MEDIA: `/path/to/file.png`\nMEDIA:\"/path/to/file2.png\"\nMEDIA:'/path/to/file3.png'"
        media, cleaned = BasePlatformAdapter.extract_media(text)
        expected = [
            ("/path/to/file.png", False),
            ("/path/to/file2.png", False),
            ("/path/to/file3.png", False),
        ]
        assert media == expected
        assert cleaned == ""

    def test_media_tag_supports_quoted_paths_with_spaces(self):
        """A quoted path may contain spaces; surrounding text is kept."""
        text = "Here\nMEDIA: '/tmp/my image.png'\nAfter"
        media, cleaned = BasePlatformAdapter.extract_media(text)
        assert media == [("/tmp/my image.png", False)]
        assert "Here" in cleaned
        assert "After" in cleaned
# ---------------------------------------------------------------------------
# truncate_message
# ---------------------------------------------------------------------------
class TestTruncateMessage:
    """Tests for ``truncate_message``: splitting, indicators and code fences."""

    def _adapter(self):
        """Create a minimal adapter instance for testing static/instance methods."""

        class StubAdapter(BasePlatformAdapter):
            async def connect(self):
                return True

            async def disconnect(self):
                pass

            async def send(self, *a, **kw):
                pass

            async def get_chat_info(self, *a):
                return {}

        from gateway.config import Platform, PlatformConfig

        config = PlatformConfig(enabled=True, token="test")
        return StubAdapter(config=config, platform=Platform.TELEGRAM)

    def test_short_message_single_chunk(self):
        """A message shorter than the limit comes back as one unchanged chunk."""
        adapter = self._adapter()
        chunks = adapter.truncate_message("Hello world", max_length=100)
        assert chunks == ["Hello world"]

    def test_exact_length_single_chunk(self):
        """A message exactly at the limit is not split."""
        adapter = self._adapter()
        msg = "x" * 100
        chunks = adapter.truncate_message(msg, max_length=100)
        assert chunks == [msg]

    def test_long_message_splits(self):
        """A message over the limit is split and its words remain present."""
        adapter = self._adapter()
        msg = "word " * 200  # ~1000 chars
        chunks = adapter.truncate_message(msg, max_length=200)
        assert len(chunks) > 1
        reassembled = "".join(chunks)
        # Chunk indicators such as (1/N) are left in place; a substring check
        # is unaffected by them. Each distinct word is checked once.
        for word in set(msg.split()):
            assert word in reassembled, f"Word '{word}' lost during truncation"

    def test_chunks_have_indicators(self):
        """The first and last chunks carry ``(i/N)`` position indicators."""
        adapter = self._adapter()
        msg = "word " * 200
        chunks = adapter.truncate_message(msg, max_length=200)
        total = len(chunks)
        assert "(1/" in chunks[0]
        assert f"({total}/{total})" in chunks[-1]

    def test_code_block_first_chunk_closed(self):
        """A code block cut by the split is closed at the end of the first chunk."""
        adapter = self._adapter()
        msg = "Before\n```python\n" + "x = 1\n" * 100 + "```\nAfter"
        chunks = adapter.truncate_message(msg, max_length=300)
        assert len(chunks) > 1
        # First chunk must have a closing fence appended (code block was split)
        first_fences = chunks[0].count("```")
        assert first_fences == 2, "First chunk should have opening + closing fence"

    def test_code_block_language_tag_carried(self):
        """A continuation chunk reopens the code block with its language tag."""
        adapter = self._adapter()
        msg = "Start\n```javascript\n" + "console.log('x');\n" * 80 + "```\nEnd"
        chunks = adapter.truncate_message(msg, max_length=300)
        # The message is far longer than max_length, so it must split. Assert
        # that instead of guarding with ``if``, which let the test pass without
        # checking anything when no split happened.
        assert len(chunks) > 1
        # At least one continuation chunk should reopen with ```javascript
        reopened_with_lang = any("```javascript" in chunk for chunk in chunks[1:])
        assert reopened_with_lang, (
            "No continuation chunk reopened with language tag"
        )

    def test_continuation_chunks_have_balanced_fences(self):
        """Regression: continuation chunks must close reopened code blocks."""
        adapter = self._adapter()
        msg = "Before\n```python\n" + "x = 1\n" * 100 + "```\nAfter"
        chunks = adapter.truncate_message(msg, max_length=300)
        assert len(chunks) > 1
        for i, chunk in enumerate(chunks):
            fence_count = chunk.count("```")
            assert fence_count % 2 == 0, (
                f"Chunk {i} has unbalanced fences ({fence_count})"
            )

    def test_each_chunk_under_max_length(self):
        """No chunk exceeds ``max_length`` by more than a fixed 20-character allowance."""
        adapter = self._adapter()
        msg = "word " * 500
        max_len = 200
        # NOTE(review): the 20-character slack presumably covers indicator
        # overhead added after splitting - confirm against truncate_message.
        limit = max_len + 20
        chunks = adapter.truncate_message(msg, max_length=max_len)
        for i, chunk in enumerate(chunks):
            # Report the limit that is actually enforced.
            assert len(chunk) <= limit, (
                f"Chunk {i} too long: {len(chunk)} > {limit}"
            )
# ---------------------------------------------------------------------------
# _get_human_delay
# ---------------------------------------------------------------------------
class TestGetHumanDelay:
    """``_get_human_delay`` under different HERMES_HUMAN_DELAY_* environment settings."""

    def test_off_mode(self):
        """Mode ``off`` gives a zero delay."""
        overrides = {"HERMES_HUMAN_DELAY_MODE": "off"}
        with patch.dict(os.environ, overrides):
            assert BasePlatformAdapter._get_human_delay() == 0.0

    def test_default_is_off(self):
        """With the mode variable unset, the delay is zero."""
        # patch.dict restores the environment on exit, so the pop is temporary.
        with patch.dict(os.environ, {}, clear=False):
            os.environ.pop("HERMES_HUMAN_DELAY_MODE", None)
            assert BasePlatformAdapter._get_human_delay() == 0.0

    def test_natural_mode_range(self):
        """Mode ``natural`` yields a delay between 0.8 and 2.5."""
        overrides = {"HERMES_HUMAN_DELAY_MODE": "natural"}
        with patch.dict(os.environ, overrides):
            observed = BasePlatformAdapter._get_human_delay()
        assert 0.8 <= observed <= 2.5

    def test_custom_mode_uses_env_vars(self):
        """Mode ``custom`` takes its bounds from the MIN_MS / MAX_MS variables."""
        overrides = {
            "HERMES_HUMAN_DELAY_MODE": "custom",
            "HERMES_HUMAN_DELAY_MIN_MS": "100",
            "HERMES_HUMAN_DELAY_MAX_MS": "200",
        }
        with patch.dict(os.environ, overrides):
            observed = BasePlatformAdapter._get_human_delay()
        assert 0.1 <= observed <= 0.2

View file

@ -0,0 +1,401 @@
"""Tests for the gateway platform reconnection watcher."""
import asyncio
import time
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from gateway.config import GatewayConfig, Platform, PlatformConfig
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, SendResult
from gateway.run import GatewayRunner
class StubAdapter(BasePlatformAdapter):
    """Test adapter whose ``connect()`` outcome is chosen at construction time."""

    def __init__(self, *, succeed=True, fatal_error=None, fatal_retryable=True):
        stub_config = PlatformConfig(enabled=True, token="test")
        super().__init__(stub_config, Platform.TELEGRAM)
        self._succeed = succeed
        self._fatal_error = fatal_error
        self._fatal_retryable = fatal_retryable

    async def connect(self):
        """Return ``succeed``, or record the configured fatal error and return False."""
        if not self._fatal_error:
            return self._succeed
        self._set_fatal_error(
            "test_error", self._fatal_error, retryable=self._fatal_retryable
        )
        return False

    async def disconnect(self):
        return None

    async def send(self, chat_id, content, reply_to=None, metadata=None):
        """Always report a successful send with message id ``"1"``."""
        return SendResult(success=True, message_id="1")

    async def send_typing(self, chat_id, metadata=None):
        return None

    async def get_chat_info(self, chat_id):
        """Echo the chat id back in a one-key dict."""
        return {"id": chat_id}
def _make_runner():
    """Build a GatewayRunner via ``object.__new__`` (skipping ``__init__``).

    Only the attributes assigned here are set on the returned runner.
    """
    runner = object.__new__(GatewayRunner)
    telegram_config = PlatformConfig(enabled=True, token="test")
    runner.config = GatewayConfig(platforms={Platform.TELEGRAM: telegram_config})

    # Lifecycle state.
    runner._running = True
    runner._shutdown_event = asyncio.Event()
    runner._exit_reason = None
    for flag_name in ("_exit_with_failure", "_exit_cleanly"):
        setattr(runner, flag_name, False)

    # Registries all start empty; each gets its own fresh dict.
    for registry_name in (
        "_failed_platforms",
        "adapters",
        "_running_agents",
        "_pending_messages",
        "_pending_approvals",
        "_honcho_managers",
        "_honcho_configs",
    ):
        setattr(runner, registry_name, {})

    runner.delivery_router = MagicMock()
    runner._shutdown_all_gateway_honcho = lambda: None
    return runner
# --- Startup queueing ---
class TestStartupFailureQueuing:
    """Verify that failed platforms are queued during startup."""

    def test_failed_platform_queued_on_connect_failure(self):
        """When adapter.connect() returns False without fatal error, queue for retry."""
        runner = _make_runner()
        queued_entry = {
            "config": PlatformConfig(enabled=True, token="test"),
            "attempts": 1,
            "next_retry": time.monotonic() + 30,
        }
        runner._failed_platforms[Platform.TELEGRAM] = queued_entry

        assert Platform.TELEGRAM in runner._failed_platforms
        assert runner._failed_platforms[Platform.TELEGRAM]["attempts"] == 1

    def test_failed_platform_not_queued_for_nonretryable(self):
        """Non-retryable errors should not be in the retry queue."""
        runner = _make_runner()
        # Simulate: adapter had a non-retryable error, so nothing was queued.
        assert Platform.TELEGRAM not in runner._failed_platforms
# --- Reconnect watcher ---
class TestPlatformReconnectWatcher:
    """Test the _platform_reconnect_watcher background task."""

    @staticmethod
    def _queue_failed_telegram(runner, *, attempts=1, retry_offset=-1):
        """Put TELEGRAM into ``runner._failed_platforms``.

        ``retry_offset`` is added to ``time.monotonic()`` to form ``next_retry``;
        a negative value means the retry time has already passed.
        """
        runner._failed_platforms[Platform.TELEGRAM] = {
            "config": PlatformConfig(enabled=True, token="test"),
            "attempts": attempts,
            "next_retry": time.monotonic() + retry_offset,
        }

    @staticmethod
    async def _run_watcher(runner, *, max_sleeps=1):
        """Run the reconnect watcher with ``asyncio.sleep`` replaced by a stub.

        The stub counts its calls, clears ``runner._running`` once it has been
        called more than ``max_sleeps`` times, and yields to the event loop via
        the real ``sleep(0)`` so no wall-clock time is spent.
        """
        real_sleep = asyncio.sleep  # captured before patching
        sleep_calls = 0

        async def fake_sleep(_seconds):
            nonlocal sleep_calls
            sleep_calls += 1
            if sleep_calls > max_sleeps:
                runner._running = False
            await real_sleep(0)

        runner._running = True
        with patch("asyncio.sleep", side_effect=fake_sleep):
            await runner._platform_reconnect_watcher()

    @pytest.mark.asyncio
    async def test_reconnect_succeeds_on_retry(self):
        """Watcher should reconnect a failed platform when connect() succeeds."""
        runner = _make_runner()
        runner._sync_voice_mode_state_to_adapter = MagicMock()
        self._queue_failed_telegram(runner)  # already past retry time
        succeed_adapter = StubAdapter(succeed=True)

        with patch.object(runner, "_create_adapter", return_value=succeed_adapter):
            with patch("gateway.run.build_channel_directory", create=True):
                await self._run_watcher(runner)

        assert Platform.TELEGRAM not in runner._failed_platforms
        assert Platform.TELEGRAM in runner.adapters

    @pytest.mark.asyncio
    async def test_reconnect_nonretryable_removed_from_queue(self):
        """Non-retryable errors should remove the platform from the retry queue."""
        runner = _make_runner()
        self._queue_failed_telegram(runner)
        fail_adapter = StubAdapter(
            succeed=False, fatal_error="bad token", fatal_retryable=False
        )

        with patch.object(runner, "_create_adapter", return_value=fail_adapter):
            await self._run_watcher(runner)

        assert Platform.TELEGRAM not in runner._failed_platforms
        assert Platform.TELEGRAM not in runner.adapters

    @pytest.mark.asyncio
    async def test_reconnect_retryable_stays_in_queue(self):
        """Retryable failures should remain in the queue with incremented attempts."""
        runner = _make_runner()
        self._queue_failed_telegram(runner)
        fail_adapter = StubAdapter(
            succeed=False, fatal_error="DNS failure", fatal_retryable=True
        )

        with patch.object(runner, "_create_adapter", return_value=fail_adapter):
            await self._run_watcher(runner)

        assert Platform.TELEGRAM in runner._failed_platforms
        assert runner._failed_platforms[Platform.TELEGRAM]["attempts"] == 2

    @pytest.mark.asyncio
    async def test_reconnect_gives_up_after_max_attempts(self):
        """After max attempts, platform should be removed from retry queue."""
        runner = _make_runner()
        self._queue_failed_telegram(runner, attempts=20)  # At max

        with patch.object(runner, "_create_adapter") as mock_create:
            await self._run_watcher(runner)

        assert Platform.TELEGRAM not in runner._failed_platforms
        mock_create.assert_not_called()  # Should give up without trying

    @pytest.mark.asyncio
    async def test_reconnect_skips_when_not_time_yet(self):
        """Watcher should skip platforms whose next_retry is in the future."""
        runner = _make_runner()
        self._queue_failed_telegram(runner, retry_offset=9999)  # Far in the future

        with patch.object(runner, "_create_adapter") as mock_create:
            await self._run_watcher(runner)

        assert Platform.TELEGRAM in runner._failed_platforms
        mock_create.assert_not_called()

    @pytest.mark.asyncio
    async def test_no_failed_platforms_watcher_idles(self):
        """When no platforms are failed, watcher should just idle."""
        runner = _make_runner()
        # No failed platforms are queued; allow one extra sleep before stopping.
        with patch.object(runner, "_create_adapter") as mock_create:
            await self._run_watcher(runner, max_sleeps=2)

        mock_create.assert_not_called()

    @pytest.mark.asyncio
    async def test_adapter_create_returns_none(self):
        """If _create_adapter returns None, remove from queue (missing deps)."""
        runner = _make_runner()
        self._queue_failed_telegram(runner)

        with patch.object(runner, "_create_adapter", return_value=None):
            await self._run_watcher(runner)

        assert Platform.TELEGRAM not in runner._failed_platforms
# --- Runtime disconnection queueing ---
class TestRuntimeDisconnectQueuing:
    """Test that _handle_adapter_fatal_error queues retryable disconnections."""

    @staticmethod
    def _attach_failed_adapter(runner, code, message, *, retryable):
        """Register a TELEGRAM StubAdapter on ``runner`` with a fatal error already set."""
        broken = StubAdapter(succeed=True)
        broken._set_fatal_error(code, message, retryable=retryable)
        runner.adapters[Platform.TELEGRAM] = broken
        return broken

    @pytest.mark.asyncio
    async def test_retryable_runtime_error_queued_for_reconnect(self):
        """Retryable runtime errors should add the platform to _failed_platforms."""
        runner = _make_runner()
        broken = self._attach_failed_adapter(
            runner, "network_error", "DNS failure", retryable=True
        )

        await runner._handle_adapter_fatal_error(broken)

        queue = runner._failed_platforms
        assert Platform.TELEGRAM in queue
        assert queue[Platform.TELEGRAM]["attempts"] == 0

    @pytest.mark.asyncio
    async def test_nonretryable_runtime_error_not_queued(self):
        """Non-retryable runtime errors should not be queued for reconnection."""
        runner = _make_runner()
        broken = self._attach_failed_adapter(
            runner, "auth_error", "bad token", retryable=False
        )
        # Replace stop() so it does not run fully.
        runner.stop = AsyncMock()

        await runner._handle_adapter_fatal_error(broken)

        assert Platform.TELEGRAM not in runner._failed_platforms

    @pytest.mark.asyncio
    async def test_retryable_error_prevents_shutdown_when_queued(self):
        """Gateway should not shut down if failed platforms are queued for reconnection."""
        runner = _make_runner()
        runner.stop = AsyncMock()
        broken = self._attach_failed_adapter(
            runner, "network_error", "DNS failure", retryable=True
        )

        await runner._handle_adapter_fatal_error(broken)

        # stop() must not have been called because a platform is queued.
        runner.stop.assert_not_called()
        assert Platform.TELEGRAM in runner._failed_platforms

    @pytest.mark.asyncio
    async def test_nonretryable_error_triggers_shutdown(self):
        """Gateway should shut down when no adapters remain and nothing is queued."""
        runner = _make_runner()
        runner.stop = AsyncMock()
        broken = self._attach_failed_adapter(
            runner, "auth_error", "bad token", retryable=False
        )

        await runner._handle_adapter_fatal_error(broken)

        runner.stop.assert_called_once()

View file

@ -0,0 +1,165 @@
"""Tests for /queue message consumption after normal agent completion.
Verifies that messages queued via /queue (which store in
adapter._pending_messages WITHOUT triggering an interrupt) are consumed
after the agent finishes its current task, rather than being silently dropped.
"""
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from gateway.platforms.base import (
BasePlatformAdapter,
MessageEvent,
MessageType,
PlatformConfig,
Platform,
)
# ---------------------------------------------------------------------------
# Minimal adapter for testing pending message storage
# ---------------------------------------------------------------------------
class _StubAdapter(BasePlatformAdapter):
    """Bare TELEGRAM adapter used to exercise pending-message storage."""

    def __init__(self):
        stub_config = PlatformConfig(enabled=True, token="test")
        super().__init__(stub_config, Platform.TELEGRAM)

    async def connect(self) -> bool:
        return True

    async def disconnect(self) -> None:
        self._mark_disconnected()

    async def send(self, chat_id, content, reply_to=None, metadata=None):
        """Report a successful send with a fixed message id."""
        from gateway.platforms.base import SendResult

        outcome = SendResult(success=True, message_id="msg-1")
        return outcome

    async def get_chat_info(self, chat_id):
        """Describe every chat as a DM with the given id."""
        return {"id": chat_id, "type": "dm"}
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestQueueMessageStorage:
    """Verify /queue stores messages correctly in adapter._pending_messages."""

    def test_queue_stores_message_in_pending(self):
        """A stored event is retrievable under its session key."""
        adapter = _StubAdapter()
        key = "telegram:user:123"
        origin = MagicMock(chat_id="123", platform=Platform.TELEGRAM)
        adapter._pending_messages[key] = MessageEvent(
            text="do this next",
            message_type=MessageType.TEXT,
            source=origin,
            message_id="q1",
        )

        assert key in adapter._pending_messages
        assert adapter._pending_messages[key].text == "do this next"

    def test_get_pending_message_consumes_and_clears(self):
        """``get_pending_message`` returns the event once, then ``None``."""
        adapter = _StubAdapter()
        key = "telegram:user:123"
        origin = MagicMock(chat_id="123", platform=Platform.TELEGRAM)
        adapter._pending_messages[key] = MessageEvent(
            text="queued prompt",
            message_type=MessageType.TEXT,
            source=origin,
            message_id="q2",
        )

        first_read = adapter.get_pending_message(key)
        assert first_read is not None
        assert first_read.text == "queued prompt"
        # A second read finds nothing: the message was consumed.
        assert adapter.get_pending_message(key) is None

    def test_queue_does_not_set_interrupt_event(self):
        """The whole point of /queue — no interrupt signal."""
        adapter = _StubAdapter()
        key = "telegram:user:123"

        # An active session stands in for a running agent.
        adapter._active_sessions[key] = asyncio.Event()

        # Store a queued message the way /queue does.
        adapter._pending_messages[key] = MessageEvent(
            text="queued",
            message_type=MessageType.TEXT,
            source=MagicMock(),
            message_id="q3",
        )

        # Nothing set the interrupt event.
        session_event = adapter._active_sessions[key]
        assert not session_event.is_set()
        assert not adapter.has_pending_interrupt(key)

    def test_regular_message_sets_interrupt_event(self):
        """Contrast: regular messages DO trigger interrupt."""
        adapter = _StubAdapter()
        key = "telegram:user:123"
        adapter._active_sessions[key] = asyncio.Event()

        # Mimic a regular message arrival: store it, then set the event
        # (this is what handle_message does).
        adapter._pending_messages[key] = MessageEvent(
            text="new message",
            message_type=MessageType.TEXT,
            source=MagicMock(),
            message_id="m1",
        )
        adapter._active_sessions[key].set()

        assert adapter.has_pending_interrupt(key)
class TestQueueConsumptionAfterCompletion:
    """Verify that pending messages are consumed after normal completion."""

    def test_pending_message_available_after_normal_completion(self):
        """After agent finishes without interrupt, pending message should
        still be retrievable from adapter._pending_messages."""
        adapter = _StubAdapter()
        key = "telegram:user:123"

        # Sequence: agent starts, /queue stores a message, agent finishes.
        adapter._active_sessions[key] = asyncio.Event()
        adapter._pending_messages[key] = MessageEvent(
            text="process this after",
            message_type=MessageType.TEXT,
            source=MagicMock(),
            message_id="q4",
        )
        del adapter._active_sessions[key]  # agent finished, no interrupt

        # The queued message is still there to be picked up.
        picked_up = adapter.get_pending_message(key)
        assert picked_up is not None
        assert picked_up.text == "process this after"

    def test_multiple_queues_last_one_wins(self):
        """If user /queue's multiple times, last message overwrites."""
        adapter = _StubAdapter()
        key = "telegram:user:123"

        for label in ("first", "second", "third"):
            adapter._pending_messages[key] = MessageEvent(
                text=label,
                message_type=MessageType.TEXT,
                source=MagicMock(),
                message_id=f"q-{label}",
            )

        picked_up = adapter.get_pending_message(key)
        assert picked_up.text == "third"

View file

@ -0,0 +1,220 @@
"""Tests for gateway /reasoning command and hot reload behavior."""
import asyncio
import inspect
import sys
import types
from unittest.mock import AsyncMock, MagicMock
import pytest
import yaml
import gateway.run as gateway_run
from gateway.config import Platform
from gateway.platforms.base import MessageEvent
from gateway.session import SessionSource
def _make_event(text="/reasoning", platform=Platform.TELEGRAM, user_id="12345", chat_id="67890"):
    """Build a MessageEvent for testing."""
    source_fields = {
        "platform": platform,
        "user_id": user_id,
        "chat_id": chat_id,
        "user_name": "testuser",
    }
    return MessageEvent(text=text, source=SessionSource(**source_fields))
def _make_runner():
    """Create a bare GatewayRunner without calling __init__."""
    runner = object.__new__(gateway_run.GatewayRunner)

    # Hooks: a MagicMock with an awaitable ``emit`` and no loaded hooks.
    hooks = MagicMock()
    hooks.emit = AsyncMock()
    hooks.loaded_hooks = []
    runner.hooks = hooks

    # Empty registries, each with its own dict.
    for registry_name in ("adapters", "_provider_routing", "_running_agents"):
        setattr(runner, registry_name, {})

    # Attributes that start out as None.
    for unset_name in ("_reasoning_config", "_fallback_model", "_session_db"):
        setattr(runner, unset_name, None)

    runner._ephemeral_system_prompt = ""
    runner._prefill_messages = []
    runner._show_reasoning = False
    runner._get_or_create_gateway_honcho = lambda session_key: (None, None)
    return runner
class _CapturingAgent:
"""Fake agent that records init kwargs for assertions."""
last_init = None
def __init__(self, *args, **kwargs):
type(self).last_init = dict(kwargs)
self.tools = []
def run_conversation(self, user_message: str, conversation_history=None, task_id=None):
return {
"final_response": "ok",
"messages": [],
"api_calls": 1,
}
class TestReasoningCommand:
    """Gateway /reasoning command handling and per-message config reload."""

    @staticmethod
    def _write_config(tmp_path, monkeypatch, config_text):
        """Create a temp hermes home with ``config.yaml`` and point gateway_run at it."""
        home = tmp_path / "hermes"
        home.mkdir()
        config_file = home / "config.yaml"
        config_file.write_text(config_text, encoding="utf-8")
        monkeypatch.setattr(gateway_run, "_hermes_home", home)
        return home, config_file

    @staticmethod
    def _stub_agent_runtime(monkeypatch, home):
        """Patch env loading, runtime kwargs and the ``run_agent`` module with fakes."""
        monkeypatch.setattr(gateway_run, "_env_path", home / ".env")
        monkeypatch.setattr(gateway_run, "load_dotenv", lambda *args, **kwargs: None)
        monkeypatch.setattr(
            gateway_run,
            "_resolve_runtime_agent_kwargs",
            lambda: {
                "provider": "openrouter",
                "api_mode": "chat_completions",
                "base_url": "https://openrouter.ai/api/v1",
                "api_key": "test-key",
            },
        )
        stand_in = types.ModuleType("run_agent")
        stand_in.AIAgent = _CapturingAgent
        monkeypatch.setitem(sys.modules, "run_agent", stand_in)
        _CapturingAgent.last_init = None

    @staticmethod
    def _ping(runner):
        """Run ``runner._run_agent`` once with a fixed local CLI source and message."""
        cli_source = SessionSource(
            platform=Platform.LOCAL,
            chat_id="cli",
            chat_name="CLI",
            chat_type="dm",
            user_id="user-1",
        )
        return asyncio.run(
            runner._run_agent(
                message="ping",
                context_prompt="",
                history=[],
                source=cli_source,
                session_id="session-1",
                session_key="agent:main:local:dm",
            )
        )

    @pytest.mark.asyncio
    async def test_reasoning_in_help_output(self):
        """The /help output lists the /reasoning command with its options."""
        runner = _make_runner()
        help_text = await runner._handle_help_command(_make_event(text="/help"))
        assert "/reasoning [level|show|hide]" in help_text

    def test_reasoning_is_known_command(self):
        """The source of ``_handle_message`` contains the quoted command name."""
        handler_source = inspect.getsource(gateway_run.GatewayRunner._handle_message)
        assert '"reasoning"' in handler_source

    @pytest.mark.asyncio
    async def test_reasoning_command_reloads_current_state_from_config(self, tmp_path, monkeypatch):
        """A bare /reasoning replaces stale in-memory state with the config file's values."""
        self._write_config(
            tmp_path,
            monkeypatch,
            "agent:\n reasoning_effort: none\ndisplay:\n show_reasoning: true\n",
        )
        monkeypatch.delenv("HERMES_REASONING_EFFORT", raising=False)

        runner = _make_runner()
        # Start from values that disagree with the file.
        runner._reasoning_config = {"enabled": True, "effort": "xhigh"}
        runner._show_reasoning = False

        reply = await runner._handle_reasoning_command(_make_event("/reasoning"))

        assert "**Effort:** `none (disabled)`" in reply
        assert "**Display:** on ✓" in reply
        assert runner._reasoning_config == {"enabled": False}
        assert runner._show_reasoning is True

    @pytest.mark.asyncio
    async def test_handle_reasoning_command_updates_config_and_cache(self, tmp_path, monkeypatch):
        """``/reasoning low`` writes the level to config.yaml and updates the runner."""
        _home, config_file = self._write_config(
            tmp_path, monkeypatch, "agent:\n reasoning_effort: medium\n"
        )
        monkeypatch.delenv("HERMES_REASONING_EFFORT", raising=False)

        runner = _make_runner()
        runner._reasoning_config = {"enabled": True, "effort": "medium"}

        reply = await runner._handle_reasoning_command(_make_event("/reasoning low"))

        on_disk = yaml.safe_load(config_file.read_text(encoding="utf-8"))
        assert on_disk["agent"]["reasoning_effort"] == "low"
        assert runner._reasoning_config == {"enabled": True, "effort": "low"}
        assert "takes effect on next message" in reply

    def test_run_agent_reloads_reasoning_config_per_message(self, tmp_path, monkeypatch):
        """The agent is built with the file's reasoning config, not the runner's stale one."""
        home, _config_file = self._write_config(
            tmp_path, monkeypatch, "agent:\n reasoning_effort: low\n"
        )
        self._stub_agent_runtime(monkeypatch, home)
        monkeypatch.delenv("HERMES_REASONING_EFFORT", raising=False)

        runner = _make_runner()
        runner._reasoning_config = {"enabled": True, "effort": "xhigh"}

        outcome = self._ping(runner)

        assert outcome["final_response"] == "ok"
        captured = _CapturingAgent.last_init
        assert captured is not None
        assert captured["reasoning_config"] == {"enabled": True, "effort": "low"}

    def test_run_agent_prefers_config_over_stale_reasoning_env(self, tmp_path, monkeypatch):
        """The config file's ``none`` wins over HERMES_REASONING_EFFORT=low in the environment."""
        home, _config_file = self._write_config(
            tmp_path, monkeypatch, "agent:\n reasoning_effort: none\n"
        )
        self._stub_agent_runtime(monkeypatch, home)
        monkeypatch.setenv("HERMES_REASONING_EFFORT", "low")

        runner = _make_runner()

        outcome = self._ping(runner)

        assert outcome["final_response"] == "ok"
        captured = _CapturingAgent.last_init
        assert captured is not None
        assert captured["reasoning_config"] == {"enabled": False}

View file

@ -0,0 +1,226 @@
"""Tests for /resume gateway slash command.
Tests the _handle_resume_command handler (switch to a previously-named session)
across gateway messenger platforms.
"""
from unittest.mock import MagicMock, AsyncMock
import pytest
from gateway.config import Platform
from gateway.platforms.base import MessageEvent
from gateway.session import SessionSource, build_session_key
def _make_event(text="/resume", platform=Platform.TELEGRAM,
                user_id="12345", chat_id="67890"):
    """Build a MessageEvent for testing."""
    source_fields = {
        "platform": platform,
        "user_id": user_id,
        "chat_id": chat_id,
        "user_name": "testuser",
    }
    return MessageEvent(text=text, source=SessionSource(**source_fields))
def _session_key_for_event(event):
    """Return ``build_session_key`` applied to the event's source."""
    origin = event.source
    return build_session_key(origin)
def _make_runner(session_db=None, current_session_id="current_session_001",
                 event=None):
    """Create a bare GatewayRunner with a mock session_store and optional session_db."""
    from gateway.run import GatewayRunner

    # Use the real session key when an event is supplied, else a fixed fallback.
    if event:
        session_key = build_session_key(event.source)
    else:
        session_key = "agent:main:telegram:dm"

    # Session entry that every mocked store call hands back.
    entry = MagicMock()
    entry.session_id = current_session_id
    entry.session_key = session_key

    store = MagicMock()
    store.get_or_create_session.return_value = entry
    store.switch_session.return_value = entry
    store.load_transcript.return_value = []

    runner = object.__new__(GatewayRunner)
    runner.adapters = {}
    runner._voice_mode = {}
    runner._running_agents = {}
    runner._session_db = session_db
    runner.session_store = store
    # Memory flushing is stubbed with an awaitable mock.
    runner._async_flush_memories = AsyncMock()
    return runner
# ---------------------------------------------------------------------------
# _handle_resume_command
# ---------------------------------------------------------------------------
class TestHandleResumeCommand:
"""Tests for GatewayRunner._handle_resume_command."""
@pytest.mark.asyncio
async def test_no_session_db(self):
    """Returns error when session database is unavailable."""
    runner = _make_runner(session_db=None)
    reply = await runner._handle_resume_command(
        _make_event(text="/resume My Project")
    )
    assert "not available" in reply.lower()
@pytest.mark.asyncio
async def test_list_named_sessions_when_no_arg(self, tmp_path):
    """With no argument, lists recently titled sessions."""
    from hermes_state import SessionDB

    store = SessionDB(db_path=tmp_path / "state.db")
    for session_id in ("sess_001", "sess_002"):
        store.create_session(session_id, "telegram")
    store.set_session_title("sess_001", "Research")
    store.set_session_title("sess_002", "Coding")

    event = _make_event(text="/resume")
    runner = _make_runner(session_db=store, event=event)
    reply = await runner._handle_resume_command(event)

    for expected in ("Research", "Coding", "Named Sessions"):
        assert expected in reply
    store.close()
@pytest.mark.asyncio
async def test_list_shows_usage_when_no_titled(self, tmp_path):
    """With no arg and no titled sessions, shows instructions."""
    from hermes_state import SessionDB

    store = SessionDB(db_path=tmp_path / "state.db")
    store.create_session("sess_001", "telegram")  # deliberately left untitled

    event = _make_event(text="/resume")
    runner = _make_runner(session_db=store, event=event)
    reply = await runner._handle_resume_command(event)

    assert "No named sessions" in reply
    assert "/title" in reply
    store.close()
@pytest.mark.asyncio
async def test_resume_by_name(self, tmp_path):
    """Resolves a title and switches to that session."""
    from hermes_state import SessionDB

    store = SessionDB(db_path=tmp_path / "state.db")
    store.create_session("old_session_abc", "telegram")
    store.set_session_title("old_session_abc", "My Project")
    store.create_session("current_session_001", "telegram")

    event = _make_event(text="/resume My Project")
    runner = _make_runner(
        session_db=store,
        current_session_id="current_session_001",
        event=event,
    )
    reply = await runner._handle_resume_command(event)

    assert "Resumed" in reply
    assert "My Project" in reply
    # switch_session was called once, with the old session ID as second positional.
    switch = runner.session_store.switch_session
    switch.assert_called_once()
    assert switch.call_args.args[1] == "old_session_abc"
    store.close()
@pytest.mark.asyncio
async def test_resume_nonexistent_name(self, tmp_path):
    """Returns error for unknown session name."""
    from hermes_state import SessionDB

    store = SessionDB(db_path=tmp_path / "state.db")
    store.create_session("current_session_001", "telegram")

    event = _make_event(text="/resume Nonexistent Session")
    runner = _make_runner(session_db=store, event=event)
    reply = await runner._handle_resume_command(event)

    assert "No session found" in reply
    store.close()
@pytest.mark.asyncio
async def test_resume_already_on_session(self, tmp_path):
    """Returns friendly message when already on the requested session."""
    from hermes_state import SessionDB

    active_id = "current_session_001"
    store = SessionDB(db_path=tmp_path / "state.db")
    store.create_session(active_id, "telegram")
    store.set_session_title(active_id, "Active Project")

    event = _make_event(text="/resume Active Project")
    runner = _make_runner(
        session_db=store, current_session_id=active_id, event=event
    )
    reply = await runner._handle_resume_command(event)

    assert "Already on session" in reply
    store.close()
@pytest.mark.asyncio
async def test_resume_auto_lineage(self, tmp_path):
    """Asking for 'My Project' when 'My Project #2' exists gets the latest.

    Two sessions are titled ``My Project`` and ``My Project #2``; the
    resume request uses the bare title and the switch is expected to
    target the ``#2`` session.
    """
    from hermes_state import SessionDB

    db = SessionDB(db_path=tmp_path / "state.db")
    try:
        db.create_session("sess_v1", "telegram")
        db.set_session_title("sess_v1", "My Project")
        db.create_session("sess_v2", "telegram")
        db.set_session_title("sess_v2", "My Project #2")
        db.create_session("current_session_001", "telegram")

        event = _make_event(text="/resume My Project")
        runner = _make_runner(
            session_db=db,
            current_session_id="current_session_001",
            event=event,
        )
        result = await runner._handle_resume_command(event)

        assert "Resumed" in result
        # Should resolve to #2 (latest in lineage)
        call_args = runner.session_store.switch_session.call_args
        assert call_args[0][1] == "sess_v2"
    finally:
        # Close the DB even when an assertion above fails.
        db.close()
@pytest.mark.asyncio
async def test_resume_clears_running_agent(self, tmp_path):
    """Switching sessions removes the running-agent entry for the session key."""
    from hermes_state import SessionDB

    db = SessionDB(db_path=tmp_path / "state.db")
    try:
        db.create_session("old_session", "telegram")
        db.set_session_title("old_session", "Old Work")
        db.create_session("current_session_001", "telegram")

        event = _make_event(text="/resume Old Work")
        runner = _make_runner(
            session_db=db,
            current_session_id="current_session_001",
            event=event,
        )
        # Simulate a running agent using the real session key
        real_key = _session_key_for_event(event)
        runner._running_agents[real_key] = MagicMock()

        await runner._handle_resume_command(event)

        assert real_key not in runner._running_agents
    finally:
        # Close the DB even when the assertion fails.
        db.close()
@pytest.mark.asyncio
async def test_resume_flushes_memories_with_gateway_session_key(self, tmp_path):
    """Resume should preserve the gateway session key for Honcho flushes.

    Checks that ``_async_flush_memories`` is called exactly once with the
    ID of the session being left and the session key derived from the
    event.
    """
    from hermes_state import SessionDB

    db = SessionDB(db_path=tmp_path / "state.db")
    try:
        db.create_session("old_session", "telegram")
        db.set_session_title("old_session", "Old Work")
        db.create_session("current_session_001", "telegram")

        event = _make_event(text="/resume Old Work")
        runner = _make_runner(
            session_db=db,
            current_session_id="current_session_001",
            event=event,
        )

        await runner._handle_resume_command(event)

        runner._async_flush_memories.assert_called_once_with(
            "current_session_001",
            _session_key_for_event(event),
        )
    finally:
        # Close the DB even when the assertion fails.
        db.close()

View file

@ -0,0 +1,97 @@
"""Regression tests for /retry replacement semantics."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from gateway.config import GatewayConfig
from gateway.platforms.base import MessageEvent, MessageType
from gateway.run import GatewayRunner
from gateway.session import SessionStore
@pytest.mark.asyncio
async def test_gateway_retry_replaces_last_user_turn_in_transcript(tmp_path):
    """/retry replays the last user turn in place of the old exchange.

    The transcript is seeded with two exchanges. While the replacement
    handler runs, only the first user turn may remain; afterwards the
    transcript holds the first exchange plus the re-run one.
    """
    cfg = GatewayConfig()
    with patch("gateway.session.SessionStore._ensure_loaded"):
        store = SessionStore(sessions_dir=tmp_path, config=cfg)
    store._db = None
    store._loaded = True

    session_id = "retry_session"
    seed_messages = [
        {"role": "session_meta", "tools": []},
        {"role": "user", "content": "first question"},
        {"role": "assistant", "content": "first answer"},
        {"role": "user", "content": "retry me"},
        {"role": "assistant", "content": "old answer"},
    ]
    for entry in seed_messages:
        store.append_to_transcript(session_id, entry)

    gw = GatewayRunner.__new__(GatewayRunner)
    gw.config = cfg
    gw.session_store = store

    session_entry = MagicMock(session_id=session_id)
    session_entry.last_prompt_tokens = 111
    gw.session_store.get_or_create_session = MagicMock(return_value=session_entry)

    def contents_for(transcript, role):
        # Collect the "content" of every message with the given role.
        return [m.get("content") for m in transcript if m.get("role") == role]

    async def fake_handle_message(event):
        # The retried text must be the original user message.
        assert event.text == "retry me"
        # By now the old "retry me" exchange must already be gone.
        snapshot = store.load_transcript(session_id)
        assert contents_for(snapshot, "user") == [
            "first question"
        ]
        store.append_to_transcript(session_id, {"role": "user", "content": event.text})
        store.append_to_transcript(session_id, {"role": "assistant", "content": "new answer"})
        return "new answer"

    gw._handle_message = AsyncMock(side_effect=fake_handle_message)

    retry_event = MessageEvent(text="/retry", message_type=MessageType.TEXT, source=MagicMock())
    result = await gw._handle_retry_command(retry_event)

    assert result == "new answer"

    final_transcript = store.load_transcript(session_id)
    assert contents_for(final_transcript, "user") == [
        "first question",
        "retry me",
    ]
    assert contents_for(final_transcript, "assistant") == [
        "first answer",
        "new answer",
    ]
@pytest.mark.asyncio
async def test_gateway_retry_replays_original_text_not_retry_command(tmp_path):
    """The replayed event carries the last user message, not the "/retry" text."""
    cfg = MagicMock()
    cfg.sessions_dir = tmp_path
    cfg.max_context_messages = 20

    gw = GatewayRunner.__new__(GatewayRunner)
    gw.config = cfg
    gw.session_store = MagicMock()

    session_entry = MagicMock(session_id="test-session")
    session_entry.last_prompt_tokens = 55
    gw.session_store.get_or_create_session.return_value = session_entry
    gw.session_store.load_transcript.return_value = [
        {"role": "user", "content": "real message"},
        {"role": "assistant", "content": "answer"},
    ]
    gw.session_store.rewrite_transcript = MagicMock()

    # Records the text of the event the retry handler forwards.
    captured = {}

    async def fake_handle_message(event):
        captured["text"] = event.text
        return "ok"

    gw._handle_message = AsyncMock(side_effect=fake_handle_message)

    retry_event = MessageEvent(text="/retry", message_type=MessageType.TEXT, source=MagicMock())
    await gw._handle_retry_command(retry_event)

    assert captured["text"] == "real message"

View file

@ -0,0 +1,60 @@
"""Regression test: /retry must return the agent response, not None.
Before the fix in PR #441, _handle_retry_command() called
_handle_message(retry_event) but discarded its return value with `return None`,
so users never received the final response.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock
from gateway.run import GatewayRunner
from gateway.platforms.base import MessageEvent, MessageType
@pytest.fixture
def gateway(tmp_path):
    """Build a GatewayRunner without running ``__init__``.

    Only ``config`` and ``session_store`` are set, both as mocks.
    """
    mock_config = MagicMock()
    mock_config.sessions_dir = tmp_path
    mock_config.max_context_messages = 20

    runner = GatewayRunner.__new__(GatewayRunner)
    runner.config = mock_config
    runner.session_store = MagicMock()
    return runner
@pytest.mark.asyncio
async def test_retry_returns_response_not_none(gateway):
    """_handle_retry_command must return the inner handler response, not None."""
    store = gateway.session_store
    store.get_or_create_session.return_value = MagicMock(
        session_id="test-session"
    )
    store.load_transcript.return_value = [
        {"role": "user", "content": "Hello Hermes"},
        {"role": "assistant", "content": "Hi there!"},
    ]
    store.rewrite_transcript = MagicMock()

    expected_response = "Hi there! (retried)"
    gateway._handle_message = AsyncMock(return_value=expected_response)

    retry_event = MessageEvent(
        text="/retry",
        message_type=MessageType.TEXT,
        source=MagicMock(),
    )
    outcome = await gateway._handle_retry_command(retry_event)

    assert outcome is not None, "/retry must not return None"
    assert outcome == expected_response
@pytest.mark.asyncio
async def test_retry_no_previous_message(gateway):
    """With an empty transcript, /retry returns a fixed explanatory message."""
    store = gateway.session_store
    store.get_or_create_session.return_value = MagicMock(
        session_id="test-session"
    )
    store.load_transcript.return_value = []

    retry_event = MessageEvent(
        text="/retry",
        message_type=MessageType.TEXT,
        source=MagicMock(),
    )
    outcome = await gateway._handle_retry_command(retry_event)

    assert outcome == "No previous message to retry."

View file

@ -0,0 +1,135 @@
"""Tests for topic-aware gateway progress updates."""
import importlib
import sys
import time
import types
from types import SimpleNamespace
import pytest
from gateway.config import Platform, PlatformConfig
from gateway.platforms.base import BasePlatformAdapter, SendResult
from gateway.session import SessionSource
class ProgressCaptureAdapter(BasePlatformAdapter):
    """Telegram-flavoured adapter that records outgoing calls.

    Sends, edits and typing notifications are appended to ``sent``,
    ``edits`` and ``typing`` so tests can inspect them afterwards.
    """

    def __init__(self):
        super().__init__(PlatformConfig(enabled=True, token="fake-token"), Platform.TELEGRAM)
        # One list per kind of outbound call, in call order.
        self.sent = []
        self.edits = []
        self.typing = []

    async def connect(self) -> bool:
        """Report a successful connection without doing anything."""
        return True

    async def disconnect(self) -> None:
        """No-op disconnect."""
        return None

    async def send(self, chat_id, content, reply_to=None, metadata=None) -> SendResult:
        """Record the send and return a successful result with a fixed message ID."""
        record = dict(
            chat_id=chat_id,
            content=content,
            reply_to=reply_to,
            metadata=metadata,
        )
        self.sent.append(record)
        return SendResult(success=True, message_id="progress-1")

    async def edit_message(self, chat_id, message_id, content) -> SendResult:
        """Record the edit and echo back the edited message's ID."""
        record = dict(
            chat_id=chat_id,
            message_id=message_id,
            content=content,
        )
        self.edits.append(record)
        return SendResult(success=True, message_id=message_id)

    async def send_typing(self, chat_id, metadata=None) -> None:
        """Record a typing notification along with its metadata."""
        self.typing.append(dict(chat_id=chat_id, metadata=metadata))

    async def get_chat_info(self, chat_id: str):
        """Return a minimal chat-info dict holding only the ID."""
        return dict(id=chat_id)
class FakeAgent:
    """Stand-in agent that emits two tool-progress events and then finishes."""

    def __init__(self, **kwargs):
        self.tools = []
        # Taken from kwargs; None when the caller supplies no callback.
        self.tool_progress_callback = kwargs.get("tool_progress_callback")

    def run_conversation(self, message, conversation_history=None, task_id=None):
        """Report two tool calls, pausing after each, and return a canned result."""
        scripted_calls = (
            ("terminal", "pwd"),
            ("browser_navigate", "https://example.com"),
        )
        for tool_name, preview in scripted_calls:
            self.tool_progress_callback(tool_name, preview)
            # Pause after each event before continuing.
            time.sleep(0.35)
        return {
            "final_response": "done",
            "messages": [],
            "api_calls": 1,
        }
def _make_runner(adapter):
    """Create a GatewayRunner without calling ``__init__`` and set the attributes used here.

    ``gateway.run`` is imported lazily via importlib, and *adapter* is
    registered as the Telegram adapter.
    """
    runner_cls = importlib.import_module("gateway.run").GatewayRunner
    runner = object.__new__(runner_cls)

    attributes = {
        "adapters": {Platform.TELEGRAM: adapter},
        "_voice_mode": {},
        "_prefill_messages": [],
        "_ephemeral_system_prompt": "",
        "_reasoning_config": None,
        "_provider_routing": {},
        "_fallback_model": None,
        "_session_db": None,
        "_running_agents": {},
        "hooks": SimpleNamespace(loaded_hooks=False),
    }
    for attr_name, attr_value in attributes.items():
        setattr(runner, attr_name, attr_value)
    return runner
@pytest.mark.asyncio
async def test_run_agent_progress_stays_in_originating_topic(monkeypatch, tmp_path):
    """Progress messages and typing calls carry the originating thread_id.

    ``dotenv`` and ``run_agent`` are replaced with stub modules so that
    ``_run_agent`` uses ``FakeAgent``; the capture adapter then shows
    where progress updates were sent.
    """
    monkeypatch.setenv("HERMES_TOOL_PROGRESS_MODE", "all")

    # Stub modules injected into sys.modules for the duration of the test.
    dotenv_stub = types.ModuleType("dotenv")
    dotenv_stub.load_dotenv = lambda *args, **kwargs: None
    run_agent_stub = types.ModuleType("run_agent")
    run_agent_stub.AIAgent = FakeAgent
    for module_name, stub in (("dotenv", dotenv_stub), ("run_agent", run_agent_stub)):
        monkeypatch.setitem(sys.modules, module_name, stub)

    adapter = ProgressCaptureAdapter()
    runner = _make_runner(adapter)

    gateway_run = importlib.import_module("gateway.run")
    monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
    monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "fake"})

    topic_metadata = {"thread_id": "17585"}
    source = SessionSource(
        platform=Platform.TELEGRAM,
        chat_id="-1001",
        chat_type="group",
        thread_id="17585",
    )

    result = await runner._run_agent(
        message="hello",
        context_prompt="",
        history=[],
        source=source,
        session_id="sess-1",
        session_key="agent:main:telegram:group:-1001:17585",
    )

    assert result["final_response"] == "done"
    # Exactly one progress message is sent; later progress arrives as edits.
    expected_sent = [
        {
            "chat_id": "-1001",
            "content": '💻 terminal: "pwd"',
            "reply_to": None,
            "metadata": {"thread_id": "17585"},
        }
    ]
    assert adapter.sent == expected_sent
    assert adapter.edits
    assert all(call["metadata"] == topic_metadata for call in adapter.typing)

View file

@ -0,0 +1,95 @@
from unittest.mock import AsyncMock
import pytest
from gateway.config import GatewayConfig, Platform, PlatformConfig
from gateway.platforms.base import BasePlatformAdapter
from gateway.run import GatewayRunner
class _FatalAdapter(BasePlatformAdapter):
    """Telegram adapter whose connect() records a non-retryable fatal error."""

    def __init__(self):
        platform_config = PlatformConfig(enabled=True, token="token")
        super().__init__(platform_config, Platform.TELEGRAM)

    async def connect(self) -> bool:
        """Set the token-lock fatal error and report connection failure."""
        self._set_fatal_error(
            "telegram_token_lock",
            "Another local Hermes gateway is already using this Telegram bot token.",
            retryable=False,
        )
        return False

    async def disconnect(self) -> None:
        """Mark the adapter as disconnected."""
        self._mark_disconnected()

    async def send(self, chat_id, content, reply_to=None, metadata=None):
        """Not implemented by this stub."""
        raise NotImplementedError

    async def get_chat_info(self, chat_id):
        """Return a minimal chat-info dict holding only the ID."""
        return dict(id=chat_id)
class _RuntimeRetryableAdapter(BasePlatformAdapter):
    """WhatsApp adapter stub that connects successfully.

    The test sets a fatal error on it explicitly after construction.
    """

    def __init__(self):
        platform_config = PlatformConfig(enabled=True, token="token")
        super().__init__(platform_config, Platform.WHATSAPP)

    async def connect(self) -> bool:
        """Report a successful connection."""
        return True

    async def disconnect(self) -> None:
        """Mark the adapter as disconnected."""
        self._mark_disconnected()

    async def send(self, chat_id, content, reply_to=None, metadata=None):
        """Not implemented by this stub."""
        raise NotImplementedError

    async def get_chat_info(self, chat_id):
        """Return a minimal chat-info dict holding only the ID."""
        return dict(id=chat_id)
@pytest.mark.asyncio
async def test_runner_requests_clean_exit_for_nonretryable_startup_conflict(monkeypatch, tmp_path):
    """A non-retryable startup error makes start() succeed and request a clean exit."""
    telegram_config = PlatformConfig(enabled=True, token="token")
    config = GatewayConfig(
        platforms={Platform.TELEGRAM: telegram_config},
        sessions_dir=tmp_path / "sessions",
    )
    runner = GatewayRunner(config)

    def make_fatal_adapter(platform, platform_config):
        # Every platform gets the adapter that fails on connect.
        return _FatalAdapter()

    monkeypatch.setattr(runner, "_create_adapter", make_fatal_adapter)

    started = await runner.start()

    assert started is True
    assert runner.should_exit_cleanly is True
    assert "already using this Telegram bot token" in runner.exit_reason
@pytest.mark.asyncio
async def test_runner_queues_retryable_runtime_fatal_for_reconnection(monkeypatch, tmp_path):
    """Retryable runtime fatal errors queue the platform for reconnection
    instead of shutting down the gateway."""
    whatsapp_config = PlatformConfig(enabled=True, token="token")
    config = GatewayConfig(
        platforms={Platform.WHATSAPP: whatsapp_config},
        sessions_dir=tmp_path / "sessions",
    )
    runner = GatewayRunner(config)

    failing_adapter = _RuntimeRetryableAdapter()
    failing_adapter._set_fatal_error(
        "whatsapp_bridge_exited",
        "WhatsApp bridge process exited unexpectedly (code 1).",
        retryable=True,
    )
    runner.adapters = {Platform.WHATSAPP: failing_adapter}
    runner.delivery_router.adapters = runner.adapters
    # Replace stop() so a shutdown attempt can be detected.
    runner.stop = AsyncMock()

    await runner._handle_adapter_fatal_error(failing_adapter)

    # Should NOT shut down — platform is queued for reconnection
    runner.stop.assert_not_awaited()
    failed = runner._failed_platforms
    assert Platform.WHATSAPP in failed
    assert failed[Platform.WHATSAPP]["attempts"] == 0

View file

@ -0,0 +1,89 @@
import pytest
from gateway.config import GatewayConfig, Platform, PlatformConfig
from gateway.platforms.base import BasePlatformAdapter
from gateway.run import GatewayRunner
from gateway.status import read_runtime_status
class _RetryableFailureAdapter(BasePlatformAdapter):
    """Telegram adapter whose connect() records a retryable fatal error."""

    def __init__(self):
        platform_config = PlatformConfig(enabled=True, token="***")
        super().__init__(platform_config, Platform.TELEGRAM)

    async def connect(self) -> bool:
        """Set the connect-error fatal state and report connection failure."""
        self._set_fatal_error(
            "telegram_connect_error",
            "Telegram startup failed: temporary DNS resolution failure.",
            retryable=True,
        )
        return False

    async def disconnect(self) -> None:
        """Mark the adapter as disconnected."""
        self._mark_disconnected()

    async def send(self, chat_id, content, reply_to=None, metadata=None):
        """Not implemented by this stub."""
        raise NotImplementedError

    async def get_chat_info(self, chat_id):
        """Return a minimal chat-info dict holding only the ID."""
        return dict(id=chat_id)
class _DisabledAdapter(BasePlatformAdapter):
    """Adapter built from a disabled platform config; connect() must never run."""

    def __init__(self):
        platform_config = PlatformConfig(enabled=False, token="***")
        super().__init__(platform_config, Platform.TELEGRAM)

    async def connect(self) -> bool:
        """Fail loudly if anything tries to connect this adapter."""
        raise AssertionError("connect should not be called for disabled platforms")

    async def disconnect(self) -> None:
        """Mark the adapter as disconnected."""
        self._mark_disconnected()

    async def send(self, chat_id, content, reply_to=None, metadata=None):
        """Not implemented by this stub."""
        raise NotImplementedError

    async def get_chat_info(self, chat_id):
        """Return a minimal chat-info dict holding only the ID."""
        return dict(id=chat_id)
@pytest.mark.asyncio
async def test_runner_returns_failure_for_retryable_startup_errors(monkeypatch, tmp_path):
    """A retryable startup error makes start() fail and is written to runtime status."""
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))

    telegram_config = PlatformConfig(enabled=True, token="***")
    config = GatewayConfig(
        platforms={Platform.TELEGRAM: telegram_config},
        sessions_dir=tmp_path / "sessions",
    )
    runner = GatewayRunner(config)

    def make_failing_adapter(platform, platform_config):
        # Every platform gets the adapter that fails on connect.
        return _RetryableFailureAdapter()

    monkeypatch.setattr(runner, "_create_adapter", make_failing_adapter)

    started = await runner.start()

    assert started is False
    assert runner.should_exit_cleanly is False

    status = read_runtime_status()
    assert status["gateway_state"] == "startup_failed"
    assert "temporary DNS resolution failure" in status["exit_reason"]
    telegram_status = status["platforms"]["telegram"]
    assert telegram_status["state"] == "fatal"
    assert telegram_status["error_code"] == "telegram_connect_error"
@pytest.mark.asyncio
async def test_runner_allows_cron_only_mode_when_no_platforms_are_enabled(monkeypatch, tmp_path):
    """With only a disabled platform configured, start() succeeds with no adapters."""
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))

    disabled_telegram = PlatformConfig(enabled=False, token="***")
    config = GatewayConfig(
        platforms={Platform.TELEGRAM: disabled_telegram},
        sessions_dir=tmp_path / "sessions",
    )
    runner = GatewayRunner(config)

    started = await runner.start()

    assert started is True
    assert runner.should_exit_cleanly is False
    assert runner.adapters == {}

    status = read_runtime_status()
    assert status["gateway_state"] == "running"

View file

@ -0,0 +1,437 @@
"""
Tests for send_image_file() on Telegram, Discord, and Slack platforms,
and MEDIA: .png extraction/routing in the base platform adapter.
Covers: local image file sending, file-not-found handling, fallback on error,
MEDIA: tag extraction for image extensions, and routing to send_image_file.
"""
import asyncio
import os
import sys
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from gateway.config import PlatformConfig
from gateway.platforms.base import BasePlatformAdapter, SendResult
def _run(coro):
    """Drive *coro* to completion with ``asyncio.run`` and return its result.

    Lets synchronous tests call async adapter methods.
    """
    outcome = asyncio.run(coro)
    return outcome
# ---------------------------------------------------------------------------
# MEDIA: extraction tests for image files
# ---------------------------------------------------------------------------
class TestExtractMediaImages:
    """MEDIA: tags pointing at image files are pulled out of message content."""

    def test_png_image_extracted(self):
        """A .png tag is extracted and removed from the surrounding text."""
        content = "Here is the screenshot:\nMEDIA:/home/user/.hermes/browser_screenshots/shot.png"
        extracted, remaining_text = BasePlatformAdapter.extract_media(content)
        assert len(extracted) == 1
        first_path = extracted[0][0]
        assert first_path == "/home/user/.hermes/browser_screenshots/shot.png"
        assert "MEDIA:" not in remaining_text
        assert "Here is the screenshot" in remaining_text

    def test_jpg_image_extracted(self):
        """A .jpg tag on its own yields a single media entry."""
        content = "MEDIA:/tmp/photo.jpg"
        extracted, remaining_text = BasePlatformAdapter.extract_media(content)
        assert len(extracted) == 1
        assert extracted[0][0] == "/tmp/photo.jpg"

    def test_webp_image_extracted(self):
        """A .webp tag yields a single media entry."""
        content = "MEDIA:/tmp/image.webp"
        extracted, _ = BasePlatformAdapter.extract_media(content)
        assert len(extracted) == 1

    def test_mixed_audio_and_image(self):
        """Audio and image tags in one message are both extracted."""
        content = "MEDIA:/audio.ogg\nMEDIA:/screenshot.png"
        extracted, _ = BasePlatformAdapter.extract_media(content)
        assert len(extracted) == 2
        found_paths = [entry[0] for entry in extracted]
        assert "/audio.ogg" in found_paths
        assert "/screenshot.png" in found_paths
# ---------------------------------------------------------------------------
# Telegram send_image_file tests
# ---------------------------------------------------------------------------
def _ensure_telegram_mock():
    """Install mock telegram modules so TelegramAdapter can be imported.

    Does nothing when ``sys.modules["telegram"]`` already exists and has a
    ``__file__`` attribute. Otherwise one MagicMock is registered under
    the ``telegram``, ``telegram.ext`` and ``telegram.constants`` names
    that are not yet present.
    """
    already_loaded = sys.modules.get("telegram")
    if already_loaded is not None and hasattr(already_loaded, "__file__"):
        return

    stub = MagicMock()
    stub.ext.ContextTypes.DEFAULT_TYPE = type(None)
    stub.constants.ParseMode.MARKDOWN_V2 = "MarkdownV2"
    # Give each chat-type constant a plain string value.
    chat_types = (
        ("GROUP", "group"),
        ("SUPERGROUP", "supergroup"),
        ("CHANNEL", "channel"),
        ("PRIVATE", "private"),
    )
    for attr_name, attr_value in chat_types:
        setattr(stub.constants.ChatType, attr_name, attr_value)

    # setdefault keeps any entry that is already registered.
    for module_name in ("telegram", "telegram.ext", "telegram.constants"):
        sys.modules.setdefault(module_name, stub)
_ensure_telegram_mock()
from gateway.platforms.telegram import TelegramAdapter # noqa: E402
class TestTelegramSendImageFile:
    """TelegramAdapter.send_image_file against a mocked bot."""

    @pytest.fixture
    def adapter(self):
        """TelegramAdapter with a MagicMock standing in for the bot."""
        telegram_adapter = TelegramAdapter(PlatformConfig(enabled=True, token="fake-token"))
        telegram_adapter._bot = MagicMock()
        return telegram_adapter

    @staticmethod
    def _stub_send_photo(adapter, message_id):
        """Make bot.send_photo an AsyncMock that returns a message with *message_id*."""
        sent_message = MagicMock()
        sent_message.message_id = message_id
        adapter._bot.send_photo = AsyncMock(return_value=sent_message)

    def test_sends_local_image_as_photo(self, adapter, tmp_path):
        """send_image_file should call bot.send_photo with the opened file."""
        image_file = tmp_path / "screenshot.png"
        image_file.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 100)  # Minimal PNG-like
        self._stub_send_photo(adapter, 42)

        result = _run(
            adapter.send_image_file(chat_id="12345", image_path=str(image_file))
        )

        assert result.success
        assert result.message_id == "42"
        adapter._bot.send_photo.assert_awaited_once()
        # The string chat_id is expected to reach the bot as an int.
        photo_call = adapter._bot.send_photo.call_args
        assert photo_call.kwargs["chat_id"] == 12345

    def test_returns_error_when_file_missing(self, adapter):
        """send_image_file should return error for nonexistent file."""
        result = _run(
            adapter.send_image_file(chat_id="12345", image_path="/nonexistent/image.png")
        )
        assert not result.success
        assert "not found" in result.error

    def test_returns_error_when_not_connected(self, adapter):
        """send_image_file should return error when bot is None."""
        adapter._bot = None
        result = _run(
            adapter.send_image_file(chat_id="12345", image_path="/tmp/img.png")
        )
        assert not result.success
        assert "Not connected" in result.error

    def test_caption_truncated_to_1024(self, adapter, tmp_path):
        """Telegram captions have a 1024 char limit."""
        image_file = tmp_path / "shot.png"
        image_file.write_bytes(b"\x89PNG" + b"\x00" * 50)
        self._stub_send_photo(adapter, 1)

        long_caption = "A" * 2000
        _run(
            adapter.send_image_file(chat_id="12345", image_path=str(image_file), caption=long_caption)
        )

        sent_kwargs = adapter._bot.send_photo.call_args.kwargs
        assert len(sent_kwargs["caption"]) == 1024

    def test_thread_id_forwarded(self, adapter, tmp_path):
        """metadata thread_id is forwarded as message_thread_id (required for Telegram forum groups)."""
        image_file = tmp_path / "shot.png"
        image_file.write_bytes(b"\x89PNG" + b"\x00" * 50)
        self._stub_send_photo(adapter, 43)

        _run(
            adapter.send_image_file(
                chat_id="12345",
                image_path=str(image_file),
                metadata={"thread_id": "789"},
            )
        )

        sent_kwargs = adapter._bot.send_photo.call_args.kwargs
        assert sent_kwargs["message_thread_id"] == 789
# ---------------------------------------------------------------------------
# Discord send_image_file tests
# ---------------------------------------------------------------------------
def _ensure_discord_mock():
    """Install mock discord module so DiscordAdapter can be imported.

    Does nothing when ``sys.modules["discord"]`` already exists and has a
    ``__file__`` attribute. Otherwise one MagicMock is registered under
    the ``discord``, ``discord.ext`` and ``discord.ext.commands`` names
    that are not yet present.
    """
    already_loaded = sys.modules.get("discord")
    if already_loaded is not None and hasattr(already_loaded, "__file__"):
        return

    stub = MagicMock()
    stub.Intents.default.return_value = MagicMock()
    # Client and File are the MagicMock class itself, not instances.
    stub.Client = stub.File = MagicMock

    # setdefault keeps any entry that is already registered.
    for module_name in ("discord", "discord.ext", "discord.ext.commands"):
        sys.modules.setdefault(module_name, stub)
_ensure_discord_mock()
import discord as discord_mod_ref # noqa: E402
from gateway.platforms.discord import DiscordAdapter # noqa: E402
class TestDiscordSendImageFile:
    """DiscordAdapter file-sending methods against a mocked client."""

    @pytest.fixture
    def adapter(self):
        """DiscordAdapter with a MagicMock standing in for the client."""
        discord_adapter = DiscordAdapter(PlatformConfig(enabled=True, token="fake-token"))
        discord_adapter._client = MagicMock()
        return discord_adapter

    @staticmethod
    def _install_channel(adapter, message_id):
        """Make get_channel return a channel whose send() yields a message with *message_id*."""
        sent_message = MagicMock()
        sent_message.id = message_id
        channel = MagicMock()
        channel.send = AsyncMock(return_value=sent_message)
        adapter._client.get_channel = MagicMock(return_value=channel)
        return channel

    def test_sends_local_image_as_attachment(self, adapter, tmp_path):
        """send_image_file should create discord.File and send to channel."""
        image_file = tmp_path / "screenshot.png"
        image_file.write_bytes(b"\x89PNG" + b"\x00" * 50)
        channel = self._install_channel(adapter, 99)

        result = _run(
            adapter.send_image_file(chat_id="67890", image_path=str(image_file))
        )

        assert result.success
        assert result.message_id == "99"
        channel.send.assert_awaited_once()

    def test_send_document_uploads_file_attachment(self, adapter, tmp_path):
        """send_document should upload a native Discord attachment."""
        document = tmp_path / "sample.pdf"
        document.write_bytes(b"%PDF-1.4\n%\xe2\xe3\xcf\xd3\n")
        channel = self._install_channel(adapter, 100)

        # Patch discord.File so the filename passed to it can be inspected.
        with patch.object(discord_mod_ref, "File", MagicMock()) as file_cls:
            result = _run(
                adapter.send_document(
                    chat_id="67890",
                    file_path=str(document),
                    file_name="renamed.pdf",
                    metadata={"thread_id": "123"},
                )
            )

        assert result.success
        assert result.message_id == "100"
        assert "file" in channel.send.call_args.kwargs
        assert file_cls.call_args.kwargs["filename"] == "renamed.pdf"

    def test_send_video_uploads_file_attachment(self, adapter, tmp_path):
        """send_video should upload a native Discord attachment."""
        clip = tmp_path / "clip.mp4"
        clip.write_bytes(b"\x00\x00\x00\x18ftypmp42" + b"\x00" * 50)
        channel = self._install_channel(adapter, 101)

        # Patch discord.File so the filename passed to it can be inspected.
        with patch.object(discord_mod_ref, "File", MagicMock()) as file_cls:
            result = _run(
                adapter.send_video(
                    chat_id="67890",
                    video_path=str(clip),
                    metadata={"thread_id": "123"},
                )
            )

        assert result.success
        assert result.message_id == "101"
        assert "file" in channel.send.call_args.kwargs
        assert file_cls.call_args.kwargs["filename"] == "clip.mp4"

    def test_returns_error_when_file_missing(self, adapter):
        """A nonexistent image path yields a "not found" error result."""
        result = _run(
            adapter.send_image_file(chat_id="67890", image_path="/nonexistent.png")
        )
        assert not result.success
        assert "not found" in result.error

    def test_returns_error_when_not_connected(self, adapter):
        """With no client set, the result reports "Not connected"."""
        adapter._client = None
        result = _run(
            adapter.send_image_file(chat_id="67890", image_path="/tmp/img.png")
        )
        assert not result.success
        assert "Not connected" in result.error

    def test_handles_missing_channel(self, adapter):
        """When both channel lookups return None, the result reports "not found"."""
        adapter._client.get_channel = MagicMock(return_value=None)
        adapter._client.fetch_channel = AsyncMock(return_value=None)
        result = _run(
            adapter.send_image_file(chat_id="99999", image_path="/tmp/img.png")
        )
        assert not result.success
        assert "not found" in result.error
# ---------------------------------------------------------------------------
# Slack send_image_file tests
# ---------------------------------------------------------------------------
def _ensure_slack_mock():
    """Install mock slack_bolt module so SlackAdapter can be imported.

    Does nothing when ``sys.modules["slack_bolt"]`` already exists and has
    a ``__file__`` attribute. Otherwise one MagicMock is registered under
    each listed Slack module name that is not yet present.
    """
    already_loaded = sys.modules.get("slack_bolt")
    if already_loaded is not None and hasattr(already_loaded, "__file__"):
        return

    stub = MagicMock()
    module_names = (
        "slack_bolt",
        "slack_bolt.async_app",
        "slack_sdk",
        "slack_sdk.web.async_client",
    )
    # setdefault keeps any entry that is already registered.
    for module_name in module_names:
        sys.modules.setdefault(module_name, stub)
_ensure_slack_mock()
from gateway.platforms.slack import SlackAdapter # noqa: E402
class TestSlackSendImageFile:
    """SlackAdapter.send_image_file against a mocked app."""

    @pytest.fixture
    def adapter(self):
        """SlackAdapter with a MagicMock standing in for the app."""
        slack_adapter = SlackAdapter(PlatformConfig(enabled=True, token="xoxb-fake"))
        slack_adapter._app = MagicMock()
        return slack_adapter

    def test_sends_local_image_via_upload(self, adapter, tmp_path):
        """send_image_file should call files_upload_v2 with the local path."""
        image_file = tmp_path / "screenshot.png"
        image_file.write_bytes(b"\x89PNG" + b"\x00" * 50)

        upload = AsyncMock(return_value=MagicMock())
        adapter._app.client.files_upload_v2 = upload

        result = _run(
            adapter.send_image_file(chat_id="C12345", image_path=str(image_file))
        )

        assert result.success
        upload.assert_awaited_once()

        upload_kwargs = upload.call_args.kwargs
        assert upload_kwargs["file"] == str(image_file)
        assert upload_kwargs["filename"] == "screenshot.png"
        assert upload_kwargs["channel"] == "C12345"

    def test_returns_error_when_file_missing(self, adapter):
        """A nonexistent image path yields a "not found" error result."""
        result = _run(
            adapter.send_image_file(chat_id="C12345", image_path="/nonexistent.png")
        )
        assert not result.success
        assert "not found" in result.error

    def test_returns_error_when_not_connected(self, adapter):
        """With no app set, the result reports "Not connected"."""
        adapter._app = None
        result = _run(
            adapter.send_image_file(chat_id="C12345", image_path="/tmp/img.png")
        )
        assert not result.success
        assert "Not connected" in result.error
# ---------------------------------------------------------------------------
# browser_vision screenshot cleanup tests
# ---------------------------------------------------------------------------
class TestScreenshotCleanup:
    """Behaviour of tools.browser_tool._cleanup_old_screenshots.

    Every test clears ``_last_screenshot_cleanup_by_dir`` first so state
    left by an earlier test cannot affect the call under test.
    """

    def test_cleanup_removes_old_screenshots(self, tmp_path):
        """_cleanup_old_screenshots should remove files older than max_age_hours."""
        import time
        from tools.browser_tool import _cleanup_old_screenshots, _last_screenshot_cleanup_by_dir

        _last_screenshot_cleanup_by_dir.clear()

        # Create a "fresh" file
        fresh_shot = tmp_path / "browser_screenshot_fresh.png"
        fresh_shot.write_bytes(b"new")

        # Create an "old" file and backdate its mtime
        stale_shot = tmp_path / "browser_screenshot_old.png"
        stale_shot.write_bytes(b"old")
        backdated = time.time() - (25 * 3600)  # 25 hours ago
        os.utime(str(stale_shot), (backdated, backdated))

        _cleanup_old_screenshots(tmp_path, max_age_hours=24)

        assert fresh_shot.exists(), "Fresh screenshot should not be removed"
        assert not stale_shot.exists(), "Old screenshot should be removed"

    def test_cleanup_is_throttled_per_directory(self, tmp_path):
        """A second cleanup of the same directory right away leaves old files alone."""
        import time
        from tools.browser_tool import _cleanup_old_screenshots, _last_screenshot_cleanup_by_dir

        _last_screenshot_cleanup_by_dir.clear()

        stale_shot = tmp_path / "browser_screenshot_old.png"
        stale_shot.write_bytes(b"old")
        backdated = time.time() - (25 * 3600)
        os.utime(str(stale_shot), (backdated, backdated))

        # First pass removes the backdated file.
        _cleanup_old_screenshots(tmp_path, max_age_hours=24)
        assert not stale_shot.exists()

        # Recreate it with the same old mtime; the second pass must skip it.
        stale_shot.write_bytes(b"old-again")
        os.utime(str(stale_shot), (backdated, backdated))
        _cleanup_old_screenshots(tmp_path, max_age_hours=24)

        assert stale_shot.exists(), "Repeated cleanup should be skipped while throttled"

    def test_cleanup_ignores_non_screenshot_files(self, tmp_path):
        """Only files matching browser_screenshot_*.png should be cleaned."""
        import time
        from tools.browser_tool import _cleanup_old_screenshots, _last_screenshot_cleanup_by_dir

        _last_screenshot_cleanup_by_dir.clear()

        unrelated = tmp_path / "important_data.txt"
        unrelated.write_bytes(b"keep me")
        backdated = time.time() - (48 * 3600)
        os.utime(str(unrelated), (backdated, backdated))

        _cleanup_old_screenshots(tmp_path, max_age_hours=24)

        assert unrelated.exists(), "Non-screenshot files should not be touched"

    def test_cleanup_handles_empty_dir(self, tmp_path):
        """Cleanup should not fail on empty directory."""
        from tools.browser_tool import _cleanup_old_screenshots, _last_screenshot_cleanup_by_dir

        _last_screenshot_cleanup_by_dir.clear()
        _cleanup_old_screenshots(tmp_path, max_age_hours=24)  # Should not raise

    def test_cleanup_handles_nonexistent_dir(self):
        """Cleanup should not fail if directory doesn't exist."""
        from pathlib import Path
        from tools.browser_tool import _cleanup_old_screenshots, _last_screenshot_cleanup_by_dir

        _last_screenshot_cleanup_by_dir.clear()
        missing_dir = Path("/nonexistent/dir")
        _cleanup_old_screenshots(missing_dir, max_age_hours=24)  # Should not raise

View file

@ -0,0 +1,767 @@
"""Tests for gateway session management."""
import json
import pytest
from pathlib import Path
from unittest.mock import patch, MagicMock
from gateway.config import Platform, HomeChannel, GatewayConfig, PlatformConfig
from gateway.session import (
SessionSource,
SessionStore,
build_session_context,
build_session_context_prompt,
build_session_key,
)
class TestSessionSourceRoundtrip:
    """SessionSource.to_dict / from_dict round-trips and from_dict defaults."""

    def test_full_roundtrip(self):
        """Every populated field survives a to_dict/from_dict round trip."""
        original = SessionSource(
            platform=Platform.TELEGRAM,
            chat_id="12345",
            chat_name="My Group",
            chat_type="group",
            user_id="99",
            user_name="alice",
            thread_id="t1",
        )
        restored = SessionSource.from_dict(original.to_dict())

        assert restored.platform == Platform.TELEGRAM
        expected_fields = {
            "chat_id": "12345",
            "chat_name": "My Group",
            "chat_type": "group",
            "user_id": "99",
            "user_name": "alice",
            "thread_id": "t1",
        }
        for field_name, expected in expected_fields.items():
            assert getattr(restored, field_name) == expected

    def test_full_roundtrip_with_chat_topic(self):
        """chat_topic should survive to_dict/from_dict roundtrip."""
        original = SessionSource(
            platform=Platform.DISCORD,
            chat_id="789",
            chat_name="Server / #project-planning",
            chat_type="group",
            user_id="42",
            user_name="bob",
            chat_topic="Planning and coordination for Project X",
        )
        as_dict = original.to_dict()
        assert as_dict["chat_topic"] == "Planning and coordination for Project X"

        restored = SessionSource.from_dict(as_dict)
        assert restored.chat_topic == "Planning and coordination for Project X"
        assert restored.chat_name == "Server / #project-planning"

    def test_minimal_roundtrip(self):
        """A source with only platform and chat_id round-trips with default chat_type."""
        original = SessionSource(platform=Platform.LOCAL, chat_id="cli")
        restored = SessionSource.from_dict(original.to_dict())
        assert restored.platform == Platform.LOCAL
        assert restored.chat_id == "cli"
        assert restored.chat_type == "dm"  # default value preserved

    def test_chat_id_coerced_to_string(self):
        """from_dict should handle numeric chat_id (common from Telegram)."""
        payload = {
            "platform": "telegram",
            "chat_id": 12345,
        }
        restored = SessionSource.from_dict(payload)
        assert restored.chat_id == "12345"
        assert isinstance(restored.chat_id, str)

    def test_missing_optional_fields(self):
        """Optional fields absent from the dict come back as None; chat_type is "dm"."""
        payload = {
            "platform": "discord",
            "chat_id": "abc",
        }
        restored = SessionSource.from_dict(payload)
        optional_fields = ("chat_name", "user_id", "user_name", "thread_id", "chat_topic")
        for field_name in optional_fields:
            assert getattr(restored, field_name) is None
        assert restored.chat_type == "dm"

    def test_invalid_platform_raises(self):
        """An unknown platform string raises ValueError or KeyError."""
        payload = {"platform": "nonexistent", "chat_id": "1"}
        with pytest.raises((ValueError, KeyError)):
            SessionSource.from_dict(payload)
class TestSessionSourceDescription:
    """Checks on the human-readable SessionSource.description text."""

    def test_local_cli(self):
        """The local CLI factory describes itself as the CLI terminal."""
        assert SessionSource.local_cli().description == "CLI terminal"

    def test_dm_with_username(self):
        """A DM description mentions both 'DM' and the user's name."""
        text = SessionSource(
            platform=Platform.TELEGRAM,
            chat_id="123",
            chat_type="dm",
            user_name="bob",
        ).description

        assert "DM" in text
        assert "bob" in text

    def test_dm_without_username_falls_back_to_user_id(self):
        """With no user_name, the user_id appears in the description."""
        text = SessionSource(
            platform=Platform.TELEGRAM,
            chat_id="123",
            chat_type="dm",
            user_id="456",
        ).description

        assert "456" in text

    def test_group_shows_chat_name(self):
        """Group descriptions include the word 'group' and the chat name."""
        text = SessionSource(
            platform=Platform.DISCORD,
            chat_id="789",
            chat_type="group",
            chat_name="Dev Chat",
        ).description

        assert "group" in text
        assert "Dev Chat" in text

    def test_channel_type(self):
        """Channel descriptions include the word 'channel' and the chat name."""
        text = SessionSource(
            platform=Platform.TELEGRAM,
            chat_id="100",
            chat_type="channel",
            chat_name="Announcements",
        ).description

        assert "channel" in text
        assert "Announcements" in text

    def test_thread_id_appended(self):
        """A thread_id is surfaced in the description alongside 'thread'."""
        text = SessionSource(
            platform=Platform.DISCORD,
            chat_id="789",
            chat_type="group",
            chat_name="General",
            thread_id="thread-42",
        ).description

        assert "thread" in text
        assert "thread-42" in text

    def test_unknown_chat_type_uses_name(self):
        """An unrecognised chat_type still yields the chat name."""
        text = SessionSource(
            platform=Platform.SLACK,
            chat_id="C01",
            chat_type="forum",
            chat_name="Questions",
        ).description

        assert "Questions" in text
class TestLocalCliFactory:
    """SessionSource.local_cli() returns the expected fixed source."""

    def test_local_cli_defaults(self):
        """The factory fills in platform, chat id, chat type and chat name."""
        cli = SessionSource.local_cli()

        observed = (cli.platform, cli.chat_id, cli.chat_type, cli.chat_name)
        assert observed == (Platform.LOCAL, "cli", "dm", "CLI terminal")
class TestBuildSessionContextPrompt:
    """The rendered session-context prompt reflects platform and chat details."""

    @staticmethod
    def _render(source, config):
        """Build the session context for *source* and render it to prompt text."""
        context = build_session_context(source, config)
        return build_session_context_prompt(context)

    def test_telegram_prompt_contains_platform_and_chat(self):
        """Telegram prompt names the platform and the chat."""
        home = HomeChannel(
            platform=Platform.TELEGRAM,
            chat_id="111",
            name="Home Chat",
        )
        telegram_cfg = PlatformConfig(
            enabled=True,
            token="fake-token",
            home_channel=home,
        )
        config = GatewayConfig(platforms={Platform.TELEGRAM: telegram_cfg})
        source = SessionSource(
            platform=Platform.TELEGRAM,
            chat_id="111",
            chat_name="Home Chat",
            chat_type="dm",
        )

        rendered = self._render(source, config)

        assert "Telegram" in rendered
        assert "Home Chat" in rendered

    def test_discord_prompt(self):
        """Discord prompt names the platform and states the search limitation."""
        discord_cfg = PlatformConfig(enabled=True, token="fake-d...oken")
        config = GatewayConfig(platforms={Platform.DISCORD: discord_cfg})
        source = SessionSource(
            platform=Platform.DISCORD,
            chat_id="guild-123",
            chat_name="Server",
            chat_type="group",
            user_name="alice",
        )

        rendered = self._render(source, config)

        assert "Discord" in rendered
        lowered = rendered.lower()
        assert "cannot search" in lowered or "do not have access" in lowered

    def test_slack_prompt_includes_platform_notes(self):
        """Slack prompt carries the platform notes (search limitation, pins)."""
        slack_cfg = PlatformConfig(enabled=True, token="fake")
        config = GatewayConfig(platforms={Platform.SLACK: slack_cfg})
        source = SessionSource(
            platform=Platform.SLACK,
            chat_id="C123",
            chat_name="general",
            chat_type="group",
            user_name="bob",
        )

        rendered = self._render(source, config)

        assert "Slack" in rendered
        lowered = rendered.lower()
        assert "cannot search" in lowered
        assert "pin" in lowered

    def test_discord_prompt_with_channel_topic(self):
        """Channel topic should appear in the session context prompt."""
        discord_cfg = PlatformConfig(enabled=True, token="fake-discord-token")
        config = GatewayConfig(platforms={Platform.DISCORD: discord_cfg})
        source = SessionSource(
            platform=Platform.DISCORD,
            chat_id="guild-123",
            chat_name="Server / #project-planning",
            chat_type="group",
            user_name="alice",
            chat_topic="Planning and coordination for Project X",
        )

        rendered = self._render(source, config)

        assert "Discord" in rendered
        assert "**Channel Topic:** Planning and coordination for Project X" in rendered

    def test_prompt_omits_channel_topic_when_none(self):
        """Channel Topic line should NOT appear when chat_topic is None."""
        discord_cfg = PlatformConfig(enabled=True, token="fake-discord-token")
        config = GatewayConfig(platforms={Platform.DISCORD: discord_cfg})
        source = SessionSource(
            platform=Platform.DISCORD,
            chat_id="guild-123",
            chat_name="Server / #general",
            chat_type="group",
            user_name="alice",
        )

        rendered = self._render(source, config)

        assert "Channel Topic" not in rendered

    def test_local_prompt_mentions_machine(self):
        """The local CLI prompt says it runs on the agent's own machine."""
        rendered = self._render(SessionSource.local_cli(), GatewayConfig())

        assert "Local" in rendered
        assert "machine running this agent" in rendered

    def test_whatsapp_prompt(self):
        """WhatsApp prompt names the platform (case-insensitively)."""
        whatsapp_cfg = PlatformConfig(enabled=True, token="")
        config = GatewayConfig(platforms={Platform.WHATSAPP: whatsapp_cfg})
        source = SessionSource(
            platform=Platform.WHATSAPP,
            chat_id="15551234567@s.whatsapp.net",
            chat_type="dm",
            user_name="Phone User",
        )

        rendered = self._render(source, config)

        assert "WhatsApp" in rendered or "whatsapp" in rendered.lower()
class TestSessionStoreRewriteTranscript:
    """Regression: /retry and /undo must persist truncated history to disk."""

    @pytest.fixture()
    def store(self, tmp_path):
        """SessionStore rooted at tmp_path with loading stubbed and no database."""
        gateway_config = GatewayConfig()
        # _ensure_loaded is patched only while the store is being constructed.
        with patch("gateway.session.SessionStore._ensure_loaded"):
            session_store = SessionStore(sessions_dir=tmp_path, config=gateway_config)
        session_store._db = None  # no SQLite for these tests
        session_store._loaded = True
        return session_store

    def test_rewrite_replaces_jsonl(self, store, tmp_path):
        """rewrite_transcript replaces the stored transcript with the new list."""
        session_id = "test_session_1"

        # Write the initial four-message transcript one message at a time.
        initial_messages = [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "hi"},
            {"role": "user", "content": "undo this"},
            {"role": "assistant", "content": "ok"},
        ]
        for message in initial_messages:
            store.append_to_transcript(session_id, message)

        # Rewrite with the truncated history (last exchange dropped).
        truncated_messages = [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "hi"},
        ]
        store.rewrite_transcript(session_id, truncated_messages)

        persisted = store.load_transcript(session_id)
        assert len(persisted) == 2
        assert persisted[0]["content"] == "hello"
        assert persisted[1]["content"] == "hi"

    def test_rewrite_with_empty_list(self, store):
        """Rewriting with an empty list leaves an empty transcript."""
        session_id = "test_session_2"
        store.append_to_transcript(session_id, {"role": "user", "content": "hi"})

        store.rewrite_transcript(session_id, [])

        assert store.load_transcript(session_id) == []
class TestLoadTranscriptCorruptLines:
    """Regression (GH-1193): corrupt JSONL lines, e.g. from a mid-write crash,
    must be skipped rather than crashing the whole transcript load."""

    @pytest.fixture()
    def store(self, tmp_path):
        """SessionStore rooted at tmp_path with loading stubbed and no database."""
        gateway_config = GatewayConfig()
        # _ensure_loaded is patched only while the store is being constructed.
        with patch("gateway.session.SessionStore._ensure_loaded"):
            session_store = SessionStore(sessions_dir=tmp_path, config=gateway_config)
        session_store._db = None
        session_store._loaded = True
        return session_store

    def test_corrupt_line_skipped(self, store, tmp_path):
        """A truncated line between two valid lines is dropped; the rest load."""
        target = store.get_transcript_path("corrupt_test")
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_text(
            '{"role": "user", "content": "hello"}\n'
            '{"role": "assistant", "content": "hi th'  # truncated
            "\n"
            '{"role": "user", "content": "goodbye"}\n'
        )

        loaded = store.load_transcript("corrupt_test")

        assert len(loaded) == 2
        assert loaded[0]["content"] == "hello"
        assert loaded[1]["content"] == "goodbye"

    def test_all_lines_corrupt_returns_empty(self, store, tmp_path):
        """When no line parses as JSON, the loaded transcript is empty."""
        target = store.get_transcript_path("all_corrupt")
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_text("not json at all\n" "{truncated\n")

        assert store.load_transcript("all_corrupt") == []

    def test_valid_transcript_unaffected(self, store, tmp_path):
        """Well-formed transcripts load completely and in order."""
        session_id = "valid_test"
        for role, content in (("user", "a"), ("assistant", "b")):
            store.append_to_transcript(session_id, {"role": role, "content": content})

        loaded = store.load_transcript(session_id)

        assert len(loaded) == 2
        assert loaded[0]["content"] == "a"
        assert loaded[1]["content"] == "b"
class TestWhatsAppDMSessionKeyConsistency:
    """Regression: all session-key construction must go through build_session_key
    so DMs are isolated by chat_id across platforms."""

    @pytest.fixture()
    def store(self, tmp_path):
        """SessionStore rooted at tmp_path with loading stubbed and no database."""
        gateway_config = GatewayConfig()
        # _ensure_loaded is patched only while the store is being constructed.
        with patch("gateway.session.SessionStore._ensure_loaded"):
            session_store = SessionStore(sessions_dir=tmp_path, config=gateway_config)
        session_store._db = None
        session_store._loaded = True
        return session_store

    @staticmethod
    def _whatsapp_dm():
        """WhatsApp DM source shared by the WhatsApp key tests."""
        return SessionSource(
            platform=Platform.WHATSAPP,
            chat_id="15551234567@s.whatsapp.net",
            chat_type="dm",
            user_name="Phone User",
        )

    @staticmethod
    def _discord_group(**extra):
        """Discord group source for chat 'guild-123' plus any extra fields."""
        return SessionSource(
            platform=Platform.DISCORD,
            chat_id="guild-123",
            chat_type="group",
            **extra,
        )

    @staticmethod
    def _telegram_thread(**extra):
        """Telegram group-thread source (thread 17585) plus any extra fields."""
        return SessionSource(
            platform=Platform.TELEGRAM,
            chat_id="-1002285219667",
            chat_type="group",
            thread_id="17585",
            **extra,
        )

    def test_whatsapp_dm_includes_chat_id(self):
        """WhatsApp DM keys end with the full chat_id."""
        session_key = build_session_key(self._whatsapp_dm())
        assert session_key == "agent:main:whatsapp:dm:15551234567@s.whatsapp.net"

    def test_store_delegates_to_build_session_key(self, store):
        """SessionStore._generate_session_key must produce the same result."""
        whatsapp = self._whatsapp_dm()
        assert store._generate_session_key(whatsapp) == build_session_key(whatsapp)

    def test_store_creates_distinct_group_sessions_per_user(self, store):
        """Two users in one group get separate keys and separate session ids."""
        alice_entry = store.get_or_create_session(
            self._discord_group(user_id="alice", user_name="Alice")
        )
        bob_entry = store.get_or_create_session(
            self._discord_group(user_id="bob", user_name="Bob")
        )

        assert alice_entry.session_key == "agent:main:discord:group:guild-123:alice"
        assert bob_entry.session_key == "agent:main:discord:group:guild-123:bob"
        assert alice_entry.session_id != bob_entry.session_id

    def test_store_shares_group_sessions_when_disabled_in_config(self, store):
        """With group_sessions_per_user off, both users share one session."""
        store.config.group_sessions_per_user = False

        alice_entry = store.get_or_create_session(
            self._discord_group(user_id="alice", user_name="Alice")
        )
        bob_entry = store.get_or_create_session(
            self._discord_group(user_id="bob", user_name="Bob")
        )

        shared_key = "agent:main:discord:group:guild-123"
        assert alice_entry.session_key == shared_key
        assert bob_entry.session_key == shared_key
        assert alice_entry.session_id == bob_entry.session_id

    def test_telegram_dm_includes_chat_id(self):
        """Non-WhatsApp DMs should also include chat_id to separate users."""
        telegram_dm = SessionSource(
            platform=Platform.TELEGRAM,
            chat_id="99",
            chat_type="dm",
        )
        assert build_session_key(telegram_dm) == "agent:main:telegram:dm:99"

    def test_distinct_dm_chat_ids_get_distinct_session_keys(self):
        """Different DM chats must not collapse into one shared session."""
        key_99 = build_session_key(
            SessionSource(platform=Platform.TELEGRAM, chat_id="99", chat_type="dm")
        )
        key_100 = build_session_key(
            SessionSource(platform=Platform.TELEGRAM, chat_id="100", chat_type="dm")
        )

        assert key_99 == "agent:main:telegram:dm:99"
        assert key_100 == "agent:main:telegram:dm:100"
        assert key_99 != key_100

    def test_discord_group_includes_chat_id(self):
        """Group/channel keys include chat_type and chat_id."""
        session_key = build_session_key(self._discord_group())
        assert session_key == "agent:main:discord:group:guild-123"

    def test_group_sessions_are_isolated_per_user_when_user_id_present(self):
        """By default a group key is suffixed with the sender's user_id."""
        alice_key = build_session_key(self._discord_group(user_id="alice"))
        bob_key = build_session_key(self._discord_group(user_id="bob"))

        assert alice_key == "agent:main:discord:group:guild-123:alice"
        assert bob_key == "agent:main:discord:group:guild-123:bob"
        assert alice_key != bob_key

    def test_group_sessions_can_be_shared_when_isolation_disabled(self):
        """group_sessions_per_user=False drops the user suffix from the key."""
        shared_key = "agent:main:discord:group:guild-123"
        for user in ("alice", "bob"):
            source = self._discord_group(user_id=user)
            assert build_session_key(source, group_sessions_per_user=False) == shared_key

    def test_group_thread_includes_thread_id(self):
        """Forum-style threads need a distinct session key within one group."""
        session_key = build_session_key(self._telegram_thread())
        assert session_key == "agent:main:telegram:group:-1002285219667:17585"

    def test_group_thread_sessions_are_isolated_per_user(self):
        """Thread keys carry the user_id after the thread_id."""
        session_key = build_session_key(self._telegram_thread(user_id="42"))
        assert session_key == "agent:main:telegram:group:-1002285219667:17585:42"
class TestSessionStoreEntriesAttribute:
    """Regression: /reset must access _entries, not _sessions."""

    def test_entries_attribute_exists(self, tmp_path):
        """SessionStore exposes `_entries` and has no `_sessions` attribute.

        The store is rooted at pytest's per-test ``tmp_path`` rather than a
        hard-coded ``/tmp``: that path does not exist on every platform, and
        pointing the store at a shared directory risks interference between
        test runs. The other SessionStore tests in this module already use
        ``tmp_path``.
        """
        config = GatewayConfig()
        # _ensure_loaded is patched only while the store is being constructed.
        with patch("gateway.session.SessionStore._ensure_loaded"):
            store = SessionStore(sessions_dir=tmp_path, config=config)
        store._loaded = True

        assert hasattr(store, "_entries")
        assert not hasattr(store, "_sessions")
class TestHasAnySessions:
    """Tests for has_any_sessions() fix (issue #351)."""

    @pytest.fixture
    def store_with_mock_db(self, tmp_path):
        """SessionStore with a mocked database."""
        gateway_config = GatewayConfig()
        # _ensure_loaded is patched only while the store is being constructed.
        with patch("gateway.session.SessionStore._ensure_loaded"):
            session_store = SessionStore(sessions_dir=tmp_path, config=gateway_config)
        session_store._loaded = True
        session_store._entries = {}
        session_store._db = MagicMock()
        return session_store

    def test_uses_database_count_when_available(self, store_with_mock_db):
        """has_any_sessions should use database session_count, not len(_entries)."""
        # Single-platform user: only one entry in memory ...
        store_with_mock_db._entries = {"telegram:12345": MagicMock()}
        # ... but the database holds 3 sessions (current + 2 previous resets).
        store_with_mock_db._db.session_count.return_value = 3

        result = store_with_mock_db.has_any_sessions()

        assert result is True
        store_with_mock_db._db.session_count.assert_called_once()

    def test_first_session_ever_returns_false(self, store_with_mock_db):
        """First session ever should return False (only current session in DB)."""
        store_with_mock_db._entries = {"telegram:12345": MagicMock()}
        # The database holds exactly 1 session: the one just created.
        store_with_mock_db._db.session_count.return_value = 1

        assert store_with_mock_db.has_any_sessions() is False

    def test_fallback_without_database(self, tmp_path):
        """Should fall back to len(_entries) when DB is not available."""
        gateway_config = GatewayConfig()
        with patch("gateway.session.SessionStore._ensure_loaded"):
            session_store = SessionStore(sessions_dir=tmp_path, config=gateway_config)
        session_store._loaded = True
        session_store._db = None

        # More than one in-memory entry counts as having sessions.
        session_store._entries = {"key1": MagicMock(), "key2": MagicMock()}
        assert session_store.has_any_sessions() is True

        # A single entry is just the current session.
        session_store._entries = {"key1": MagicMock()}
        assert session_store.has_any_sessions() is False
class TestLastPromptTokens:
    """Tests for the last_prompt_tokens field — actual API token tracking."""

    @staticmethod
    def _entry(session_key, **extra):
        """Create a SessionEntry with session_id 's1' and current timestamps."""
        from datetime import datetime
        from gateway.session import SessionEntry

        return SessionEntry(
            session_key=session_key,
            session_id="s1",
            created_at=datetime.now(),
            updated_at=datetime.now(),
            **extra,
        )

    @staticmethod
    def _store_with(tmp_path, entry, db=None):
        """Create a store holding *entry* under 'k1', with _save mocked out."""
        gateway_config = GatewayConfig()
        # _ensure_loaded is patched only while the store is being constructed.
        with patch("gateway.session.SessionStore._ensure_loaded"):
            session_store = SessionStore(sessions_dir=tmp_path, config=gateway_config)
        session_store._loaded = True
        session_store._db = db
        session_store._save = MagicMock()
        session_store._entries = {"k1": entry}
        return session_store

    def test_session_entry_default(self):
        """New sessions should have last_prompt_tokens=0."""
        assert self._entry("test").last_prompt_tokens == 0

    def test_session_entry_roundtrip(self):
        """last_prompt_tokens should survive serialization/deserialization."""
        from gateway.session import SessionEntry

        serialized = self._entry("test", last_prompt_tokens=42000).to_dict()
        assert serialized["last_prompt_tokens"] == 42000

        assert SessionEntry.from_dict(serialized).last_prompt_tokens == 42000

    def test_session_entry_from_old_data(self):
        """Old session data without last_prompt_tokens should default to 0."""
        from gateway.session import SessionEntry

        # Old on-disk format: no last_prompt_tokens key at all.
        legacy = {
            "session_key": "test",
            "session_id": "s1",
            "created_at": "2025-01-01T00:00:00",
            "updated_at": "2025-01-01T00:00:00",
            "input_tokens": 100,
            "output_tokens": 50,
            "total_tokens": 150,
        }

        assert SessionEntry.from_dict(legacy).last_prompt_tokens == 0

    def test_update_session_sets_last_prompt_tokens(self, tmp_path):
        """update_session should store the actual prompt token count."""
        tracked = self._entry("k1")
        session_store = self._store_with(tmp_path, tracked)

        session_store.update_session("k1", last_prompt_tokens=85000)

        assert tracked.last_prompt_tokens == 85000

    def test_update_session_none_does_not_change(self, tmp_path):
        """update_session with default (None) should not change last_prompt_tokens."""
        tracked = self._entry("k1", last_prompt_tokens=50000)
        session_store = self._store_with(tmp_path, tracked)

        session_store.update_session("k1")  # no last_prompt_tokens argument

        assert tracked.last_prompt_tokens == 50000  # unchanged

    def test_update_session_zero_resets(self, tmp_path):
        """update_session with last_prompt_tokens=0 should reset the field."""
        tracked = self._entry("k1", last_prompt_tokens=85000)
        session_store = self._store_with(tmp_path, tracked)

        session_store.update_session("k1", last_prompt_tokens=0)

        assert tracked.last_prompt_tokens == 0

    def test_update_session_passes_model_to_db(self, tmp_path):
        """Gateway session updates should forward the resolved model to SQLite."""
        session_store = self._store_with(tmp_path, self._entry("k1"), db=MagicMock())

        session_store.update_session("k1", model="openai/gpt-5.4")

        expected_kwargs = {
            "input_tokens": 0,
            "output_tokens": 0,
            "cache_read_tokens": 0,
            "cache_write_tokens": 0,
            "estimated_cost_usd": None,
            "cost_status": None,
            "cost_source": None,
            "billing_provider": None,
            "billing_base_url": None,
            "model": "openai/gpt-5.4",
        }
        session_store._db.update_token_counts.assert_called_once_with(
            "s1", **expected_kwargs
        )

View file

@ -0,0 +1,45 @@
import os
from gateway.config import Platform
from gateway.run import GatewayRunner
from gateway.session import SessionContext, SessionSource
def test_set_session_env_includes_thread_id(monkeypatch):
    """_set_session_env exports platform, chat id, chat name and thread id.

    The four HERMES_SESSION_* variables are cleared first so the assertions
    can only pass if ``_set_session_env`` populated them.
    """
    session_vars = (
        "HERMES_SESSION_PLATFORM",
        "HERMES_SESSION_CHAT_ID",
        "HERMES_SESSION_CHAT_NAME",
        "HERMES_SESSION_THREAD_ID",
    )
    for name in session_vars:
        # monkeypatch.delenv(..., raising=False) records nothing to undo when
        # the variable is already unset, so whatever _set_session_env writes
        # would outlive this test. Setting the variable first registers its
        # original state with monkeypatch; deleting it then restores the
        # "unset" precondition while still guaranteeing cleanup at teardown.
        monkeypatch.setenv(name, "stale")
        monkeypatch.delenv(name)

    # Bypass __init__: only _set_session_env is exercised here.
    runner = object.__new__(GatewayRunner)
    source = SessionSource(
        platform=Platform.TELEGRAM,
        chat_id="-1001",
        chat_name="Group",
        chat_type="group",
        thread_id="17585",
    )
    context = SessionContext(source=source, connected_platforms=[], home_channels={})

    runner._set_session_env(context)

    assert os.getenv("HERMES_SESSION_PLATFORM") == "telegram"
    assert os.getenv("HERMES_SESSION_CHAT_ID") == "-1001"
    assert os.getenv("HERMES_SESSION_CHAT_NAME") == "Group"
    assert os.getenv("HERMES_SESSION_THREAD_ID") == "17585"
def test_clear_session_env_removes_thread_id(monkeypatch):
    """_clear_session_env unsets all four HERMES_SESSION_* variables."""
    seeded = {
        "HERMES_SESSION_PLATFORM": "telegram",
        "HERMES_SESSION_CHAT_ID": "-1001",
        "HERMES_SESSION_CHAT_NAME": "Group",
        "HERMES_SESSION_THREAD_ID": "17585",
    }
    # Bypass __init__: only _clear_session_env is exercised here.
    gateway_runner = object.__new__(GatewayRunner)
    for var_name, var_value in seeded.items():
        monkeypatch.setenv(var_name, var_value)

    gateway_runner._clear_session_env()

    for var_name in seeded:
        assert os.getenv(var_name) is None

View file

@ -0,0 +1,383 @@
"""Tests for gateway session hygiene — auto-compression of large sessions.
Verifies that the gateway detects pathologically large transcripts and
triggers auto-compression before running the agent. (#628)
The hygiene system uses the SAME compression config as the agent:
compression.threshold × model context length
so CLI and messaging platforms behave identically.
"""
import importlib
import sys
import types
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import patch, MagicMock, AsyncMock
import pytest
from agent.model_metadata import estimate_messages_tokens_rough
from gateway.config import GatewayConfig, Platform, PlatformConfig
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, SendResult
from gateway.session import SessionEntry, SessionSource
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_history(n_messages: int, content_size: int = 100) -> list:
"""Build a fake transcript with n_messages user/assistant pairs."""
history = []
content = "x" * content_size
for i in range(n_messages):
role = "user" if i % 2 == 0 else "assistant"
history.append({"role": role, "content": content, "timestamp": f"t{i}"})
return history
def _make_large_history_tokens(target_tokens: int) -> list:
    """Build a history that estimates to roughly target_tokens tokens.

    Sizing assumes the rough estimator counts about 4 characters per token
    over the stringified messages, and that each message dict contributes
    about 60 characters of overhead on top of its content.
    """
    chars_per_token = 4
    per_message_overhead = 60
    message_count = 50  # fixed count; only the content size scales

    total_chars = target_tokens * chars_per_token
    per_message_content = (total_chars // message_count) - per_message_overhead
    # Never go below 10 characters of content per message.
    return _make_history(message_count, content_size=max(10, per_message_content))
class HygieneCaptureAdapter(BasePlatformAdapter):
    """Telegram-flavoured adapter stub that records every send() call."""

    def __init__(self):
        platform_config = PlatformConfig(enabled=True, token="fake-token")
        super().__init__(platform_config, Platform.TELEGRAM)
        # One dict per send() call, in call order.
        self.sent = []

    async def connect(self) -> bool:
        """Pretend the connection always succeeds."""
        return True

    async def disconnect(self) -> None:
        """Nothing to tear down."""
        return None

    async def send(self, chat_id, content, reply_to=None, metadata=None) -> SendResult:
        """Record the outgoing message and report success."""
        captured = dict(
            chat_id=chat_id,
            content=content,
            reply_to=reply_to,
            metadata=metadata,
        )
        self.sent.append(captured)
        return SendResult(success=True, message_id="hygiene-1")

    async def get_chat_info(self, chat_id: str):
        """Return a minimal chat-info dict containing only the id."""
        return {"id": chat_id}
# ---------------------------------------------------------------------------
# Detection threshold tests (model-aware, unified with compression config)
# ---------------------------------------------------------------------------
class TestSessionHygieneThresholds:
    """Test that the threshold logic correctly identifies large sessions.

    Thresholds are derived from model context length × compression threshold,
    matching what the agent's ContextCompressor uses.
    """

    @staticmethod
    def _token_limit(context_length, threshold_pct):
        """Token count at which compression triggers for the given model size."""
        return int(context_length * threshold_pct)

    def test_small_session_below_thresholds(self):
        """A 10-message session should not trigger compression."""
        estimated = estimate_messages_tokens_rough(_make_history(10))

        # 200k-context model at 85% -> 170k token limit.
        limit = self._token_limit(200_000, 0.85)

        assert not (estimated >= limit)

    def test_large_token_count_triggers(self):
        """High token count should trigger compression when exceeding model threshold."""
        # Sized to exceed 85% of a 200k model (170k tokens).
        oversized = _make_large_history_tokens(180_000)
        estimated = estimate_messages_tokens_rough(oversized)

        limit = self._token_limit(200_000, 0.85)

        assert estimated >= limit

    def test_under_threshold_no_trigger(self):
        """Session under threshold should not trigger, even with many messages."""
        # 250 short messages: a long transcript, but few tokens overall.
        approx_tokens = estimate_messages_tokens_rough(_make_history(250, content_size=10))

        # 200k model at 85% = 170k token threshold.
        compress_token_threshold = self._token_limit(200_000, 0.85)

        needs_compress = approx_tokens >= compress_token_threshold
        assert not needs_compress, (
            f"250 short messages (~{approx_tokens} tokens) should NOT trigger "
            f"compression at {compress_token_threshold} token threshold"
        )

    def test_message_count_alone_does_not_trigger(self):
        """Message count alone should NOT trigger — only token count matters.

        The old system used an OR of token-count and message-count thresholds,
        which caused premature compression in tool-heavy sessions with 200+
        messages but low total tokens.
        """
        # 300 very short messages: the old rule would compress, the new must not.
        estimated = estimate_messages_tokens_rough(_make_history(300, content_size=10))

        limit = self._token_limit(200_000, 0.85)

        # Token-based check only.
        assert not (estimated >= limit)

    def test_threshold_scales_with_model(self):
        """Different models should have different compression thresholds."""
        limit_128k = self._token_limit(128_000, 0.85)  # 108,800 tokens
        limit_200k = self._token_limit(200_000, 0.85)  # 170,000 tokens
        limit_1m = self._token_limit(1_000_000, 0.85)  # 850,000 tokens

        # A session at roughly 120k tokens.
        estimated = estimate_messages_tokens_rough(_make_large_history_tokens(120_000))

        assert estimated >= limit_128k  # triggers for the 128k model
        assert estimated < limit_200k  # not for the 200k model
        assert estimated < limit_1m  # nor for the 1M model

    def test_custom_threshold_percentage(self):
        """Custom threshold percentage from config should be respected."""
        context_length = 200_000
        limit_at_50pct = self._token_limit(context_length, 0.50)  # 100k
        limit_at_90pct = self._token_limit(context_length, 0.90)  # 180k

        estimated = estimate_messages_tokens_rough(_make_large_history_tokens(150_000))

        # Triggers at 50% but not at 90%.
        assert estimated >= limit_at_50pct
        assert estimated < limit_at_90pct

    def test_minimum_message_guard(self):
        """Sessions with fewer than 4 messages should never trigger."""
        tiny_but_heavy = _make_history(3, content_size=100_000)
        # Even with enormous content, fewer than 4 messages should be skipped
        # (the gateway code checks `len(history) >= 4` before evaluating).
        assert len(tiny_but_heavy) < 4
class TestSessionHygieneWarnThreshold:
    """Test the post-compression warning threshold (95% of context)."""

    def test_warn_when_still_large(self):
        """If compressed result is still above 95% of context, should warn."""
        warn_at = int(200_000 * 0.95)  # 190k
        tokens_after_compression = 195_000

        assert tokens_after_compression >= warn_at

    def test_no_warn_when_under(self):
        """If compressed result is under 95% of context, no warning."""
        warn_at = int(200_000 * 0.95)  # 190k
        tokens_after_compression = 150_000

        assert tokens_after_compression < warn_at
class TestEstimatedTokenThreshold:
    """Verify that hygiene thresholds are always below the model's context
    limit for both actual and estimated token counts.

    Regression: a previous 1.4x multiplier on rough estimates pushed the
    threshold to 85% * 1.4 = 119% of context, which exceeded the model's
    limit and prevented hygiene from ever firing for ~200K models (GLM-5).
    The fix removed the multiplier entirely — the 85% threshold already
    provides ample headroom over the agent's 50% compressor.
    """

    def test_threshold_below_context_for_200k_model(self):
        """Hygiene threshold must always be below model context."""
        model_context = 200_000
        assert int(model_context * 0.85) < model_context

    def test_threshold_below_context_for_128k_model(self):
        """Same guarantee for a 128k-context model."""
        model_context = 128_000
        assert int(model_context * 0.85) < model_context

    def test_no_multiplier_means_same_threshold_for_estimated_and_actual(self):
        """Without the 1.4x, estimated and actual token paths use the same threshold."""
        model_context, fraction = 200_000, 0.85

        # Both paths should use 170K, with no inflation.
        assert int(model_context * fraction) == 170_000

    def test_warn_threshold_below_context(self):
        """Warn threshold (95%) must be below context length."""
        context_sizes = (128_000, 200_000, 1_000_000)
        for model_context in context_sizes:
            assert int(model_context * 0.95) < model_context

    def test_overestimate_fires_early_but_safely(self):
        """If rough estimate is 50% inflated, hygiene fires at ~57% actual usage.

        That's between the agent's 50% threshold and the model's limit —
        safe and harmless.
        """
        model_context = 200_000
        hygiene_limit = int(model_context * 0.85)  # 170K

        # Actual tokens of ~113K inflate to 113K * 1.5 = 170K in the rough
        # estimate, so hygiene fires at ~57% of the context.
        actual_at_fire = hygiene_limit / 1.5

        assert actual_at_fire > model_context * 0.50, (
            "Early fire should still be above agent's 50% threshold"
        )
        assert actual_at_fire < model_context, (
            "Early fire must be well below model limit"
        )
class TestTokenEstimation:
    """Verify rough token estimation works as expected for hygiene checks."""

    def test_empty_history(self):
        """An empty transcript estimates to zero tokens."""
        assert estimate_messages_tokens_rough([]) == 0

    def test_proportional_to_content(self):
        """Larger message bodies give a larger estimate at equal message count."""
        light = estimate_messages_tokens_rough(_make_history(10, content_size=100))
        heavy = estimate_messages_tokens_rough(_make_history(10, content_size=10_000))
        assert heavy > light

    def test_proportional_to_count(self):
        """More messages give a larger estimate at equal content size."""
        short = estimate_messages_tokens_rough(_make_history(10, content_size=1000))
        long = estimate_messages_tokens_rough(_make_history(100, content_size=1000))
        assert long > short

    def test_pathological_session_detected(self):
        """The reported pathological case: 648 messages, ~299K tokens.

        With a 200k model at 85% threshold (170k), this should trigger.
        """
        estimated = estimate_messages_tokens_rough(_make_history(648, content_size=1800))

        # Should be well above the 170K limit for a 200k model.
        limit = int(200_000 * 0.85)
        assert estimated > limit
@pytest.mark.asyncio
async def test_session_hygiene_messages_stay_in_originating_topic(monkeypatch, tmp_path):
    """Hygiene notices go to the chat and thread the message arrived from.

    A Telegram group message with ``thread_id="17585"`` is handled while
    the model context length is patched to 100. Both messages captured by
    the adapter must target chat ``-1001`` with that thread id in their
    metadata, even though ``TELEGRAM_HOME_CHANNEL`` names another chat.
    """
    # Put stub ``dotenv`` and ``run_agent`` modules into sys.modules before
    # gateway.run is imported below.
    dotenv_stub = types.ModuleType("dotenv")
    dotenv_stub.load_dotenv = lambda *args, **kwargs: None
    monkeypatch.setitem(sys.modules, "dotenv", dotenv_stub)

    class FakeCompressAgent:
        def __init__(self, **kwargs):
            self.model = kwargs.get("model")

        def _compress_context(self, messages, *_args, **_kwargs):
            return ([{"role": "assistant", "content": "compressed"}], None)

    run_agent_stub = types.ModuleType("run_agent")
    run_agent_stub.AIAgent = FakeCompressAgent
    monkeypatch.setitem(sys.modules, "run_agent", run_agent_stub)

    gateway_run = importlib.import_module("gateway.run")
    GatewayRunner = gateway_run.GatewayRunner

    # Build the runner without __init__ and attach the attributes by hand.
    adapter = HygieneCaptureAdapter()
    runner = object.__new__(GatewayRunner)
    runner.config = GatewayConfig(
        platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="fake-token")}
    )
    runner.adapters = {Platform.TELEGRAM: adapter}
    runner._voice_mode = {}
    runner.hooks = SimpleNamespace(emit=AsyncMock(), loaded_hooks=False)

    store = MagicMock()
    store.get_or_create_session.return_value = SessionEntry(
        session_key="agent:main:telegram:group:-1001:17585",
        session_id="sess-1",
        created_at=datetime.now(),
        updated_at=datetime.now(),
        platform=Platform.TELEGRAM,
        chat_type="group",
    )
    store.load_transcript.return_value = _make_history(6, content_size=400)
    store.has_any_sessions.return_value = True
    store.rewrite_transcript = MagicMock()
    store.append_to_transcript = MagicMock()
    runner.session_store = store

    runner._running_agents = {}
    runner._pending_messages = {}
    runner._pending_approvals = {}
    runner._session_db = None
    runner._is_user_authorized = lambda _source: True
    runner._set_session_env = lambda _context: None
    agent_result = {
        "final_response": "ok",
        "messages": [],
        "tools": [],
        "history_offset": 0,
        "last_prompt_tokens": 0,
    }
    runner._run_agent = AsyncMock(return_value=agent_result)

    monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
    monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "fake"})
    monkeypatch.setattr(
        "agent.model_metadata.get_model_context_length",
        lambda *_args, **_kwargs: 100,
    )
    monkeypatch.setenv("TELEGRAM_HOME_CHANNEL", "795544298")

    origin = SessionSource(
        platform=Platform.TELEGRAM,
        chat_id="-1001",
        chat_type="group",
        thread_id="17585",
    )
    event = MessageEvent(text="hello", source=origin, message_id="1")

    result = await runner._handle_message(event)

    assert result == "ok"
    assert len(adapter.sent) == 2
    warning, summary = adapter.sent[0], adapter.sent[1]
    assert warning["chat_id"] == "-1001"
    assert "Session is large" in warning["content"]
    assert warning["metadata"] == {"thread_id": "17585"}
    assert summary["chat_id"] == "-1001"
    assert "Compressed:" in summary["content"]
    assert summary["metadata"] == {"thread_id": "17585"}

View file

@ -0,0 +1,267 @@
"""Tests for the session race guard that prevents concurrent agent runs.
The sentinel-based guard ensures that when _handle_message passes the
"is an agent already running?" check and proceeds to the slow async
setup path (vision enrichment, STT, hooks, session hygiene), a second
message for the same session is correctly recognized as "already running"
and routed through the interrupt/queue path instead of spawning a
duplicate agent.
"""
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from gateway.config import GatewayConfig, Platform, PlatformConfig
from gateway.platforms.base import MessageEvent, MessageType
from gateway.run import GatewayRunner, _AGENT_PENDING_SENTINEL
from gateway.session import SessionSource, build_session_key
class _FakeAdapter:
    """Bare-bones adapter double: a pending-message map plus a no-op send."""

    def __init__(self):
        # Starts empty; the tests inspect it to see what was queued.
        self._pending_messages = dict()

    async def send(self, chat_id, text, **kwargs):
        """Accept any outgoing message and discard it."""
        return None
def _make_runner():
    """Return a GatewayRunner built without ``__init__``.

    Only the attributes assigned here exist on the instance: a
    Telegram-only config, one ``_FakeAdapter``, empty bookkeeping dicts
    and an authorizer that accepts every source.
    """
    telegram = Platform.TELEGRAM
    runner = object.__new__(GatewayRunner)
    runner.config = GatewayConfig(
        platforms={telegram: PlatformConfig(enabled=True, token="***")}
    )
    runner.adapters = {telegram: _FakeAdapter()}
    # Separate dict objects on purpose: the runner mutates them independently.
    runner._running_agents = {}
    runner._pending_messages = {}
    runner._pending_approvals = {}
    runner._voice_mode = {}
    runner._is_user_authorized = lambda _source: True
    return runner
def _make_event(text="hello", chat_id="12345"):
    """Build a Telegram DM text ``MessageEvent`` for *chat_id* carrying *text*."""
    return MessageEvent(
        text=text,
        message_type=MessageType.TEXT,
        source=SessionSource(
            platform=Platform.TELEGRAM,
            chat_id=chat_id,
            chat_type="dm",
        ),
    )
# ------------------------------------------------------------------
# Test 1: Sentinel is placed before _handle_message_with_agent runs
# ------------------------------------------------------------------
@pytest.mark.asyncio
async def test_sentinel_placed_before_agent_setup():
    """After passing the 'not running' guard, the sentinel must be
    written into _running_agents *before* any await, so that a
    concurrent message sees the session as occupied."""
    runner = _make_runner()
    event = _make_event()
    session_key = build_session_key(event.source)

    # Replacement for _handle_message_with_agent: record, at entry, whether
    # the slot for the key it was given holds the sentinel.
    seen = {"sentinel": False}

    async def mock_inner(self_inner, ev, src, qk):
        seen["sentinel"] = runner._running_agents.get(qk) is _AGENT_PENDING_SENTINEL
        return "ok"

    with patch.object(GatewayRunner, "_handle_message_with_agent", mock_inner):
        await runner._handle_message(event)

    assert seen["sentinel"], (
        "Sentinel must be in _running_agents when _handle_message_with_agent starts"
    )
# ------------------------------------------------------------------
# Test 2: Sentinel is cleaned up after _handle_message_with_agent
# ------------------------------------------------------------------
@pytest.mark.asyncio
async def test_sentinel_cleaned_up_after_handler_returns():
    """If _handle_message_with_agent returns normally, the sentinel
    must be removed so the session is not permanently locked."""
    runner = _make_runner()
    event = _make_event()
    session_key = build_session_key(event.source)

    async def mock_inner(self_inner, ev, src, qk):
        return "ok"

    handler_patch = patch.object(
        GatewayRunner, "_handle_message_with_agent", mock_inner
    )
    with handler_patch:
        await runner._handle_message(event)

    still_locked = session_key in runner._running_agents
    assert not still_locked, (
        "Sentinel must be removed after handler completes"
    )
# ------------------------------------------------------------------
# Test 3: Sentinel cleaned up on exception
# ------------------------------------------------------------------
@pytest.mark.asyncio
async def test_sentinel_cleaned_up_on_exception():
    """If _handle_message_with_agent raises, the sentinel must still
    be cleaned up so the session is not permanently locked."""
    runner = _make_runner()
    event = _make_event()
    session_key = build_session_key(event.source)

    async def mock_inner(self_inner, ev, src, qk):
        raise RuntimeError("boom")

    handler_patch = patch.object(
        GatewayRunner, "_handle_message_with_agent", mock_inner
    )
    with handler_patch:
        # The error is expected to propagate out of _handle_message.
        with pytest.raises(RuntimeError, match="boom"):
            await runner._handle_message(event)

    still_locked = session_key in runner._running_agents
    assert not still_locked, (
        "Sentinel must be removed even if handler raises"
    )
# ------------------------------------------------------------------
# Test 4: Second message during sentinel sees "already running"
# ------------------------------------------------------------------
@pytest.mark.asyncio
async def test_second_message_during_sentinel_queued_not_duplicate():
    """While the sentinel is set (agent setup in progress), a second
    message for the same session must hit the 'already running' branch
    and be queued, not start a second agent.

    The first handler is parked on an ``asyncio.Event``. The ``finally``
    clause always releases it and awaits the task, so a failing assertion
    cannot leave a pending task behind when the test exits.
    """
    runner = _make_runner()
    event1 = _make_event(text="first message")
    event2 = _make_event(text="second message")
    session_key = build_session_key(event1.source)
    barrier = asyncio.Event()

    async def slow_inner(self_inner, ev, src, qk):
        # Simulate slow setup: wait until the test tells us to proceed.
        await barrier.wait()
        return "ok"

    with patch.object(GatewayRunner, "_handle_message_with_agent", slow_inner):
        # Start first message (will block at barrier).
        task1 = asyncio.create_task(runner._handle_message(event1))
        try:
            # Yield so task1 enters slow_inner and the sentinel is set.
            await asyncio.sleep(0)

            # Verify sentinel is set.
            assert runner._running_agents.get(session_key) is _AGENT_PENDING_SENTINEL

            # Second message should see "already running" and be queued.
            result2 = await runner._handle_message(event2)
            assert result2 is None, "Second message should return None (queued)"

            # The second message should have been queued in adapter pending.
            adapter = runner.adapters[Platform.TELEGRAM]
            assert session_key in adapter._pending_messages, (
                "Second message should be queued as pending"
            )
            assert adapter._pending_messages[session_key] is event2
        finally:
            # Let the first message complete, whether or not the checks passed.
            barrier.set()
            await task1
# ------------------------------------------------------------------
# Test 5: Sentinel not placed for command messages
# ------------------------------------------------------------------
@pytest.mark.asyncio
async def test_command_messages_do_not_leave_sentinel():
    """Slash commands (/help, /status, etc.) return early from
    _handle_message. They must NOT leave a sentinel behind."""
    runner = _make_runner()
    dm_source = SessionSource(
        platform=Platform.TELEGRAM, chat_id="12345", chat_type="dm"
    )
    help_event = MessageEvent(
        text="/help", message_type=MessageType.TEXT, source=dm_source
    )
    session_key = build_session_key(dm_source)

    # Stub the help handler so the bare runner is enough.
    runner._handle_help_command = AsyncMock(return_value="Help text")
    # Command handling emits through hooks, so give the runner an awaitable emit.
    hooks = MagicMock()
    hooks.emit = AsyncMock()
    runner.hooks = hooks

    await runner._handle_message(help_event)

    leftover = session_key in runner._running_agents
    assert not leftover, (
        "Command handlers must not leave sentinel in _running_agents"
    )
# ------------------------------------------------------------------
# Test 6: /stop during sentinel returns helpful message
# ------------------------------------------------------------------
@pytest.mark.asyncio
async def test_stop_during_sentinel_returns_message():
    """If /stop arrives while the sentinel is set (agent still starting),
    it should return a helpful message instead of crashing or queuing.

    The blocked first handler is always released and awaited in a
    ``finally`` clause, so a failing assertion cannot leave a pending
    task behind when the test exits.
    """
    runner = _make_runner()
    event1 = _make_event(text="hello")
    session_key = build_session_key(event1.source)
    barrier = asyncio.Event()

    async def slow_inner(self_inner, ev, src, qk):
        # Hold the first message inside "setup" until the test releases it.
        await barrier.wait()
        return "ok"

    with patch.object(GatewayRunner, "_handle_message_with_agent", slow_inner):
        task1 = asyncio.create_task(runner._handle_message(event1))
        try:
            # Yield once so task1 reaches slow_inner.
            await asyncio.sleep(0)

            # Sentinel should be set.
            assert runner._running_agents.get(session_key) is _AGENT_PENDING_SENTINEL

            # Send /stop: should get a message, not crash.
            stop_event = _make_event(text="/stop")
            result = await runner._handle_message(stop_event)
            assert result is not None, "/stop during sentinel should return a message"
            assert "starting up" in result.lower()

            # Should NOT be queued as pending.
            adapter = runner.adapters[Platform.TELEGRAM]
            assert session_key not in adapter._pending_messages
        finally:
            # Release the first message, whether or not the checks passed.
            barrier.set()
            await task1
# ------------------------------------------------------------------
# Test 7: Shutdown skips sentinel entries
# ------------------------------------------------------------------
@pytest.mark.asyncio
async def test_shutdown_skips_sentinel():
    """During gateway shutdown, sentinel entries in _running_agents
    should be skipped without raising AttributeError."""
    runner = _make_runner()

    # One slot still holds the sentinel; the other holds a mock agent so we
    # can check that real agents are still interrupted.
    sentinel_key = "telegram:dm:99999"
    runner._running_agents[sentinel_key] = _AGENT_PENDING_SENTINEL
    live_agent = MagicMock()
    runner._running_agents["telegram:dm:88888"] = live_agent

    runner.adapters = {}  # No adapters to disconnect
    runner._running = True
    runner._shutdown_event = asyncio.Event()
    runner._exit_reason = None
    runner._shutdown_all_gateway_honcho = lambda: None

    with patch("gateway.status.remove_pid_file"):
        with patch("gateway.status.write_runtime_status"):
            await runner.stop()

    # Reaching this line means stop() did not raise on the sentinel entry.
    # The real agent must have been interrupted exactly once.
    live_agent.interrupt.assert_called_once()

View file

@ -0,0 +1,207 @@
"""Tests for session auto-reset notifications.
Verifies that:
- _should_reset() returns a reason string ("idle" or "daily") instead of bool
- SessionEntry captures auto_reset_reason
- SessionResetPolicy.notify controls whether notifications are sent
- notify_exclude_platforms skips notifications for excluded platforms
"""
from datetime import datetime, timedelta
from unittest.mock import MagicMock
import pytest
from gateway.config import (
GatewayConfig,
Platform,
PlatformConfig,
SessionResetPolicy,
)
from gateway.session import SessionEntry, SessionSource, SessionStore
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_source(platform=Platform.TELEGRAM, chat_id="123", user_id="u1"):
    """Build a ``SessionSource`` from the given platform, chat id and user id."""
    fields = {
        "platform": platform,
        "chat_id": chat_id,
        "user_id": user_id,
    }
    return SessionSource(**fields)
def _make_store(policy=None, tmp_path=None):
    """Create a ``SessionStore`` backed by a fresh ``GatewayConfig``.

    Args:
        policy: Optional reset policy installed as the config's
            ``default_reset_policy``. ``None`` leaves the config as built.
        tmp_path: Directory passed as ``sessions_dir``. A falsy value falls
            back to ``/tmp/test-sessions``.

    Returns:
        The constructed ``SessionStore``.
    """
    config = GatewayConfig()
    # Test for None explicitly: "no policy supplied" is the only case that
    # should skip the assignment, not any policy that evaluates as falsy.
    if policy is not None:
        config.default_reset_policy = policy
    return SessionStore(sessions_dir=tmp_path or "/tmp/test-sessions", config=config)
# ---------------------------------------------------------------------------
# _should_reset returns reason string
# ---------------------------------------------------------------------------
class TestShouldResetReason:
    """``SessionStore._should_reset`` yields a reason string or ``None``."""

    def test_returns_none_when_not_expired(self, tmp_path):
        """A session touched just now is not due for reset."""
        policy = SessionResetPolicy(mode="both", idle_minutes=60, at_hour=4)
        store = _make_store(policy, tmp_path)
        fresh = SessionEntry(
            session_key="test",
            session_id="s1",
            created_at=datetime.now(),
            updated_at=datetime.now(),  # just updated
        )
        verdict = store._should_reset(fresh, _make_source())
        assert verdict is None

    def test_returns_idle_when_idle_expired(self, tmp_path):
        """Last activity 60 minutes ago with a 30 minute limit gives "idle"."""
        policy = SessionResetPolicy(mode="idle", idle_minutes=30)
        store = _make_store(policy, tmp_path)
        created = datetime.now() - timedelta(hours=2)
        last_seen = datetime.now() - timedelta(hours=1)  # 60min ago > 30min threshold
        stale = SessionEntry(
            session_key="test",
            session_id="s1",
            created_at=created,
            updated_at=last_seen,
        )
        verdict = store._should_reset(stale, _make_source())
        assert verdict == "idle"

    def test_returns_daily_when_daily_boundary_crossed(self, tmp_path):
        """Last activity a day ago with ``at_hour`` at the current hour gives "daily"."""
        now = datetime.now()
        policy = SessionResetPolicy(mode="daily", at_hour=now.hour)
        store = _make_store(policy, tmp_path)
        yesterday_entry = SessionEntry(
            session_key="test",
            session_id="s1",
            created_at=now - timedelta(days=2),
            updated_at=now - timedelta(days=1),  # last active yesterday
        )
        verdict = store._should_reset(yesterday_entry, _make_source())
        assert verdict == "daily"

    def test_returns_none_when_mode_is_none(self, tmp_path):
        """With mode "none", even a 30-day-old session is left alone."""
        policy = SessionResetPolicy(mode="none")
        store = _make_store(policy, tmp_path)
        created = datetime.now() - timedelta(days=30)
        last_seen = datetime.now() - timedelta(days=30)
        ancient = SessionEntry(
            session_key="test",
            session_id="s1",
            created_at=created,
            updated_at=last_seen,
        )
        verdict = store._should_reset(ancient, _make_source())
        assert verdict is None
# ---------------------------------------------------------------------------
# SessionEntry captures reason
# ---------------------------------------------------------------------------
class TestSessionEntryReason:
    """The entry created after an auto-reset records why it happened."""

    def test_auto_reset_reason_stored(self, tmp_path):
        """An idle-expired session is replaced by one flagged with reason "idle"."""
        policy = SessionResetPolicy(mode="idle", idle_minutes=1)
        store = _make_store(policy, tmp_path)
        source = _make_source()

        # First call creates the session; it is not an auto-reset.
        original = store.get_or_create_session(source)
        assert not original.was_auto_reset

        # Backdate it past the 1-minute idle limit and persist.
        original.updated_at = datetime.now() - timedelta(minutes=5)
        store._save()

        # The next lookup must hand back a new session carrying the reason.
        replacement = store.get_or_create_session(source)
        assert replacement.was_auto_reset is True
        assert replacement.auto_reset_reason == "idle"
        assert replacement.session_id != original.session_id

    def test_reset_had_activity_false_when_no_tokens(self, tmp_path):
        """Expired session with no tokens → reset_had_activity=False."""
        policy = SessionResetPolicy(mode="idle", idle_minutes=1)
        store = _make_store(policy, tmp_path)
        source = _make_source()

        original = store.get_or_create_session(source)
        # No tokens recorded: the session sat idle without any conversation.
        original.updated_at = datetime.now() - timedelta(minutes=5)
        store._save()

        replacement = store.get_or_create_session(source)
        assert replacement.was_auto_reset is True
        assert replacement.reset_had_activity is False

    def test_reset_had_activity_true_when_tokens_used(self, tmp_path):
        """Expired session with tokens → reset_had_activity=True."""
        policy = SessionResetPolicy(mode="idle", idle_minutes=1)
        store = _make_store(policy, tmp_path)
        source = _make_source()

        original = store.get_or_create_session(source)
        # Pretend a conversation took place before the session went idle.
        original.total_tokens = 5000
        original.updated_at = datetime.now() - timedelta(minutes=5)
        store._save()

        replacement = store.get_or_create_session(source)
        assert replacement.was_auto_reset is True
        assert replacement.reset_had_activity is True
# ---------------------------------------------------------------------------
# SessionResetPolicy notify config
# ---------------------------------------------------------------------------
class TestResetPolicyNotify:
    """Defaults and (de)serialisation of the notify settings on the policy."""

    def test_notify_defaults_true(self):
        default_policy = SessionResetPolicy()
        assert default_policy.notify is True

    def test_notify_exclude_defaults(self):
        excluded = SessionResetPolicy().notify_exclude_platforms
        assert "api_server" in excluded
        assert "webhook" in excluded

    def test_from_dict_with_notify_false(self):
        loaded = SessionResetPolicy.from_dict({"notify": False})
        assert loaded.notify is False

    def test_from_dict_with_custom_excludes(self):
        raw = {
            "notify_exclude_platforms": ["api_server", "webhook", "homeassistant"],
        }
        loaded = SessionResetPolicy.from_dict(raw)
        assert "homeassistant" in loaded.notify_exclude_platforms

    def test_from_dict_preserves_defaults_on_missing_keys(self):
        loaded = SessionResetPolicy.from_dict({})
        assert loaded.notify is True
        assert "api_server" in loaded.notify_exclude_platforms

    def test_to_dict_roundtrip(self):
        """``to_dict`` followed by ``from_dict`` keeps notify, excludes and mode."""
        original = SessionResetPolicy(
            mode="idle",
            notify=False,
            notify_exclude_platforms=("api_server",),
        )
        as_dict = original.to_dict()
        restored = SessionResetPolicy.from_dict(as_dict)
        assert restored.notify == original.notify
        assert restored.notify_exclude_platforms == original.notify_exclude_platforms
        assert restored.mode == original.mode

View file

@ -0,0 +1,298 @@
"""Tests for Signal messenger platform adapter."""
import json
import pytest
from unittest.mock import MagicMock, patch, AsyncMock
from gateway.config import Platform, PlatformConfig
# ---------------------------------------------------------------------------
# Platform & Config
# ---------------------------------------------------------------------------
class TestSignalPlatformEnum:
    """``Platform`` exposes a SIGNAL member whose value is "signal"."""

    def test_signal_enum_exists(self):
        member = Platform.SIGNAL
        assert member.value == "signal"

    def test_signal_in_platform_list(self):
        all_values = [member.value for member in Platform]
        assert "signal" in all_values
class TestSignalConfigLoading:
    """Signal platform config is derived from ``SIGNAL_*`` environment variables."""

    def test_apply_env_overrides_signal(self, monkeypatch):
        """With URL and account both set, Signal is enabled and both land in ``extra``."""
        monkeypatch.setenv("SIGNAL_HTTP_URL", "http://localhost:9090")
        monkeypatch.setenv("SIGNAL_ACCOUNT", "+15551234567")
        from gateway.config import GatewayConfig, _apply_env_overrides
        config = GatewayConfig()
        _apply_env_overrides(config)
        assert Platform.SIGNAL in config.platforms
        sc = config.platforms[Platform.SIGNAL]
        assert sc.enabled is True
        assert sc.extra["http_url"] == "http://localhost:9090"
        assert sc.extra["account"] == "+15551234567"

    def test_signal_not_loaded_without_both_vars(self, monkeypatch):
        """The URL alone must not be enough to register the Signal platform."""
        monkeypatch.setenv("SIGNAL_HTTP_URL", "http://localhost:9090")
        # Explicitly remove SIGNAL_ACCOUNT so a value inherited from the
        # developer's shell or CI environment cannot make this test fail.
        monkeypatch.delenv("SIGNAL_ACCOUNT", raising=False)
        from gateway.config import GatewayConfig, _apply_env_overrides
        config = GatewayConfig()
        _apply_env_overrides(config)
        assert Platform.SIGNAL not in config.platforms

    def test_connected_platforms_includes_signal(self, monkeypatch):
        """A fully configured Signal platform is reported by ``get_connected_platforms``."""
        monkeypatch.setenv("SIGNAL_HTTP_URL", "http://localhost:8080")
        monkeypatch.setenv("SIGNAL_ACCOUNT", "+15551234567")
        from gateway.config import GatewayConfig, _apply_env_overrides
        config = GatewayConfig()
        _apply_env_overrides(config)
        connected = config.get_connected_platforms()
        assert Platform.SIGNAL in connected
# ---------------------------------------------------------------------------
# Adapter Init & Helpers
# ---------------------------------------------------------------------------
class TestSignalAdapterInit:
    """Construction-time parsing done by ``SignalAdapter``."""

    def _make_config(self, **extra):
        """Return an enabled ``PlatformConfig``; *extra* overrides the default URL/account."""
        settings = {
            "http_url": "http://localhost:8080",
            "account": "+15551234567",
        }
        settings.update(extra)
        config = PlatformConfig()
        config.enabled = True
        config.extra = settings
        return config

    def test_init_parses_config(self, monkeypatch):
        monkeypatch.setenv("SIGNAL_GROUP_ALLOWED_USERS", "group123,group456")
        from gateway.platforms.signal import SignalAdapter

        signal_adapter = SignalAdapter(self._make_config())
        assert signal_adapter.http_url == "http://localhost:8080"
        assert signal_adapter.account == "+15551234567"
        assert "group123" in signal_adapter.group_allow_from

    def test_init_empty_allowlist(self, monkeypatch):
        monkeypatch.setenv("SIGNAL_GROUP_ALLOWED_USERS", "")
        from gateway.platforms.signal import SignalAdapter

        signal_adapter = SignalAdapter(self._make_config())
        assert len(signal_adapter.group_allow_from) == 0

    def test_init_strips_trailing_slash(self, monkeypatch):
        monkeypatch.setenv("SIGNAL_GROUP_ALLOWED_USERS", "")
        from gateway.platforms.signal import SignalAdapter

        cfg = self._make_config(http_url="http://localhost:8080/")
        signal_adapter = SignalAdapter(cfg)
        assert signal_adapter.http_url == "http://localhost:8080"

    def test_self_message_filtering(self, monkeypatch):
        monkeypatch.setenv("SIGNAL_GROUP_ALLOWED_USERS", "")
        from gateway.platforms.signal import SignalAdapter

        signal_adapter = SignalAdapter(self._make_config())
        assert signal_adapter._account_normalized == "+15551234567"
class TestSignalHelpers:
    """Module-level helper functions of ``gateway.platforms.signal``."""

    def test_redact_phone_long(self):
        from gateway.platforms.signal import _redact_phone

        redacted = _redact_phone("+15551234567")
        assert redacted == "+155****4567"

    def test_redact_phone_short(self):
        from gateway.platforms.signal import _redact_phone

        redacted = _redact_phone("+12345")
        assert redacted == "+1****45"

    def test_redact_phone_empty(self):
        from gateway.platforms.signal import _redact_phone

        redacted = _redact_phone("")
        assert redacted == "<none>"

    def test_parse_comma_list(self):
        from gateway.platforms.signal import _parse_comma_list

        parsed = _parse_comma_list("+1234, +5678 , +9012")
        assert parsed == ["+1234", "+5678", "+9012"]
        assert _parse_comma_list("") == []
        assert _parse_comma_list(" , , ") == []

    def test_guess_extension_png(self):
        from gateway.platforms.signal import _guess_extension

        payload = b"\x89PNG\r\n\x1a\n" + b"\x00" * 100
        assert _guess_extension(payload) == ".png"

    def test_guess_extension_jpeg(self):
        from gateway.platforms.signal import _guess_extension

        payload = b"\xff\xd8\xff\xe0" + b"\x00" * 100
        assert _guess_extension(payload) == ".jpg"

    def test_guess_extension_pdf(self):
        from gateway.platforms.signal import _guess_extension

        payload = b"%PDF-1.4" + b"\x00" * 100
        assert _guess_extension(payload) == ".pdf"

    def test_guess_extension_zip(self):
        from gateway.platforms.signal import _guess_extension

        payload = b"PK\x03\x04" + b"\x00" * 100
        assert _guess_extension(payload) == ".zip"

    def test_guess_extension_mp4(self):
        from gateway.platforms.signal import _guess_extension

        payload = b"\x00\x00\x00\x18ftypisom" + b"\x00" * 100
        assert _guess_extension(payload) == ".mp4"

    def test_guess_extension_unknown(self):
        from gateway.platforms.signal import _guess_extension

        payload = b"\x00\x01\x02\x03" * 10
        assert _guess_extension(payload) == ".bin"

    def test_is_image_ext(self):
        from gateway.platforms.signal import _is_image_ext

        assert _is_image_ext(".png") is True
        assert _is_image_ext(".jpg") is True
        assert _is_image_ext(".gif") is True
        assert _is_image_ext(".pdf") is False

    def test_is_audio_ext(self):
        from gateway.platforms.signal import _is_audio_ext

        assert _is_audio_ext(".mp3") is True
        assert _is_audio_ext(".ogg") is True
        assert _is_audio_ext(".png") is False

    def test_check_requirements(self, monkeypatch):
        from gateway.platforms.signal import check_signal_requirements

        monkeypatch.setenv("SIGNAL_HTTP_URL", "http://localhost:8080")
        monkeypatch.setenv("SIGNAL_ACCOUNT", "+15551234567")
        satisfied = check_signal_requirements()
        assert satisfied is True

    def test_render_mentions(self):
        """The U+FFFC placeholder is replaced by ``@<number>``."""
        from gateway.platforms.signal import _render_mentions

        text = "Hello \uFFFC, how are you?"
        mentions = [{"start": 6, "length": 1, "number": "+15559999999"}]
        rendered = _render_mentions(text, mentions)
        assert "@+15559999999" in rendered
        assert "\uFFFC" not in rendered

    def test_render_mentions_no_mentions(self):
        from gateway.platforms.signal import _render_mentions

        rendered = _render_mentions("Hello world", [])
        assert rendered == "Hello world"

    def test_check_requirements_missing(self, monkeypatch):
        from gateway.platforms.signal import check_signal_requirements

        monkeypatch.delenv("SIGNAL_HTTP_URL", raising=False)
        monkeypatch.delenv("SIGNAL_ACCOUNT", raising=False)
        satisfied = check_signal_requirements()
        assert satisfied is False
# ---------------------------------------------------------------------------
# Session Source
# ---------------------------------------------------------------------------
class TestSignalSessionSource:
    """Serialisation of the ``*_alt`` identifier fields on ``SessionSource``."""

    def test_session_source_alt_fields(self):
        from gateway.session import SessionSource

        dm_source = SessionSource(
            platform=Platform.SIGNAL,
            chat_id="+15551234567",
            user_id="+15551234567",
            user_id_alt="uuid:abc-123",
            chat_id_alt=None,
        )
        payload = dm_source.to_dict()
        assert payload["user_id_alt"] == "uuid:abc-123"
        assert "chat_id_alt" not in payload  # None fields excluded

    def test_session_source_roundtrip(self):
        from gateway.session import SessionSource

        group_source = SessionSource(
            platform=Platform.SIGNAL,
            chat_id="group:xyz",
            chat_type="group",
            user_id="+15551234567",
            user_id_alt="uuid:abc",
            chat_id_alt="xyz",
        )
        payload = group_source.to_dict()
        rebuilt = SessionSource.from_dict(payload)
        assert rebuilt.user_id_alt == "uuid:abc"
        assert rebuilt.chat_id_alt == "xyz"
        assert rebuilt.platform == Platform.SIGNAL
# ---------------------------------------------------------------------------
# Phone Redaction in agent/redact.py
# ---------------------------------------------------------------------------
class TestSignalPhoneRedaction:
    """Phone-number handling of ``agent.redact.redact_sensitive_text``."""

    @pytest.fixture(autouse=True)
    def _ensure_redaction_enabled(self, monkeypatch):
        # Clear HERMES_REDACT_SECRETS for every test in this class.
        monkeypatch.delenv("HERMES_REDACT_SECRETS", raising=False)

    def test_us_number(self):
        from agent.redact import redact_sensitive_text

        redacted = redact_sensitive_text("Call +15551234567 now")
        assert "+15551234567" not in redacted
        assert "+155" in redacted  # Prefix preserved
        assert "4567" in redacted  # Suffix preserved

    def test_uk_number(self):
        from agent.redact import redact_sensitive_text

        redacted = redact_sensitive_text("UK: +442071838750")
        assert "+442071838750" not in redacted
        assert "****" in redacted

    def test_multiple_numbers(self):
        from agent.redact import redact_sensitive_text

        redacted = redact_sensitive_text("From +15551234567 to +442071838750")
        assert "+15551234567" not in redacted
        assert "+442071838750" not in redacted

    def test_short_number_not_matched(self):
        from agent.redact import redact_sensitive_text

        redacted = redact_sensitive_text("Code: +12345")
        # Five digits after the "+" is under the 7-digit minimum, so the
        # text is expected to pass through untouched.
        assert "+12345" in redacted  # Too short to redact
# ---------------------------------------------------------------------------
# Authorization in run.py
# ---------------------------------------------------------------------------
class TestSignalAuthorization:
    """Authorization of Signal users by ``GatewayRunner._is_user_authorized``."""

    def test_signal_in_allowlist_maps(self):
        """Signal should be in the platform auth maps."""
        from gateway.run import GatewayRunner
        from gateway.config import GatewayConfig

        # Bare runner: no __init__, a default config, and a pairing store
        # that reports the user as not approved.
        gw = GatewayRunner.__new__(GatewayRunner)
        gw.config = GatewayConfig()
        pairing = MagicMock()
        pairing.is_approved.return_value = False
        gw.pairing_store = pairing

        source = MagicMock()
        source.platform = Platform.SIGNAL
        source.user_id = "+15559999999"

        # Empty environment: no allowlists and no GATEWAY_ALLOW_ALL_USERS.
        with patch.dict("os.environ", {}, clear=True):
            authorized = gw._is_user_authorized(source)
        assert authorized is False
# ---------------------------------------------------------------------------
# Send Message Tool
# ---------------------------------------------------------------------------
class TestSignalSendMessage:
    """Import smoke check for the send_message tool alongside Signal."""

    def test_signal_in_platform_map(self):
        """``tools.send_message_tool`` imports cleanly and ``Platform.SIGNAL`` is "signal".

        NOTE(review): despite its name, this test does not inspect the
        tool's platform map; it only checks the import and the enum value.
        """
        from tools.send_message_tool import send_message_tool
        from gateway.config import Platform

        signal_value = Platform.SIGNAL.value
        assert signal_value == "signal"

View file

@ -0,0 +1,948 @@
"""
Tests for Slack platform adapter.
Covers: app_mention handler, send_document, send_video,
incoming document handling, message routing.
Note: slack-bolt may not be installed in the test environment.
We mock the slack modules at import time to avoid collection errors.
"""
import asyncio
import os
import sys
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from gateway.config import Platform, PlatformConfig
from gateway.platforms.base import (
MessageEvent,
MessageType,
SendResult,
SUPPORTED_DOCUMENT_TYPES,
)
# ---------------------------------------------------------------------------
# Mock the slack-bolt package if it's not installed
# ---------------------------------------------------------------------------
def _ensure_slack_mock():
    """Register stand-in ``slack_bolt`` / ``slack_sdk`` modules in ``sys.modules``.

    Does nothing when ``sys.modules`` already holds a ``slack_bolt`` entry
    that has a ``__file__`` attribute (taken to be the real library).
    Otherwise each module path below is registered with ``setdefault``,
    so entries that already exist are left untouched.
    """
    loaded = sys.modules.get("slack_bolt")
    if loaded is not None and hasattr(loaded, "__file__"):
        return  # Real library installed

    bolt = MagicMock()
    bolt.async_app.AsyncApp = MagicMock
    bolt.adapter.socket_mode.async_handler.AsyncSocketModeHandler = MagicMock
    sdk = MagicMock()
    sdk.web.async_client.AsyncWebClient = MagicMock

    # Dotted module path -> mock attribute that stands in for it.
    stubs = {
        "slack_bolt": bolt,
        "slack_bolt.async_app": bolt.async_app,
        "slack_bolt.adapter": bolt.adapter,
        "slack_bolt.adapter.socket_mode": bolt.adapter.socket_mode,
        "slack_bolt.adapter.socket_mode.async_handler": bolt.adapter.socket_mode.async_handler,
        "slack_sdk": sdk,
        "slack_sdk.web": sdk.web,
        "slack_sdk.web.async_client": sdk.web.async_client,
    }
    for module_name, stub in stubs.items():
        sys.modules.setdefault(module_name, stub)
# Install the slack stand-ins before gateway.platforms.slack is imported below.
_ensure_slack_mock()
# Import the adapter module, force its SLACK_AVAILABLE flag on, and only then
# pull the SlackAdapter name out of it. These statements are order-dependent.
import gateway.platforms.slack as _slack_mod  # noqa: E402
_slack_mod.SLACK_AVAILABLE = True
from gateway.platforms.slack import SlackAdapter  # noqa: E402
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture()
def adapter():
    """Return a ``SlackAdapter`` wired to mocks instead of a live Slack app.

    The app and its client are mocks, the bot user id is ``U_BOT``, the
    adapter is marked as running, and ``handle_message`` is an
    ``AsyncMock`` so incoming events are captured rather than processed.
    """
    slack_adapter = SlackAdapter(
        PlatformConfig(enabled=True, token="xoxb-fake-token")
    )
    # Stand-in for the Slack app with an awaitable client.
    fake_app = MagicMock()
    fake_app.client = AsyncMock()
    slack_adapter._app = fake_app
    slack_adapter._bot_user_id = "U_BOT"
    slack_adapter._running = True
    # Record dispatched events instead of running the real handler.
    slack_adapter.handle_message = AsyncMock()
    return slack_adapter
@pytest.fixture(autouse=True)
def _redirect_cache(tmp_path, monkeypatch):
    """Point document cache to tmp_path so tests don't touch ~/.hermes."""
    cache_dir = tmp_path / "doc_cache"
    monkeypatch.setattr("gateway.platforms.base.DOCUMENT_CACHE_DIR", cache_dir)
# ---------------------------------------------------------------------------
# TestAppMentionHandler
# ---------------------------------------------------------------------------
class TestAppMentionHandler:
    """Verify that the app_mention event handler is registered."""

    def test_app_mention_registered_on_connect(self):
        """connect() should register both 'message' and 'app_mention' handlers."""
        slack_adapter = SlackAdapter(PlatformConfig(enabled=True, token="xoxb-fake"))

        # Names passed to app.event(...) / app.command(...) during connect().
        registered_events = []
        registered_commands = []

        def mock_event(event_type):
            def decorator(fn):
                registered_events.append(event_type)
                return fn

            return decorator

        def mock_command(cmd):
            def decorator(fn):
                registered_commands.append(cmd)
                return fn

            return decorator

        auth_reply = {
            "user_id": "U_BOT",
            "user": "testbot",
        }
        mock_app = MagicMock()
        mock_app.event = mock_event
        mock_app.command = mock_command
        mock_app.client = AsyncMock()
        mock_app.client.auth_test = AsyncMock(return_value=auth_reply)

        # Swap in the fake app and socket handler, provide an app token, and
        # replace asyncio.create_task with a mock while connect() runs.
        with patch.object(_slack_mod, "AsyncApp", return_value=mock_app):
            with patch.object(_slack_mod, "AsyncSocketModeHandler", return_value=MagicMock()):
                with patch.dict(os.environ, {"SLACK_APP_TOKEN": "xapp-fake"}):
                    with patch("asyncio.create_task"):
                        asyncio.run(slack_adapter.connect())

        assert "message" in registered_events
        assert "app_mention" in registered_events
        assert "/hermes" in registered_commands
# ---------------------------------------------------------------------------
# TestSendDocument
# ---------------------------------------------------------------------------
class TestSendDocument:
    """Outbound file uploads made through ``send_document``."""

    @pytest.mark.asyncio
    async def test_send_document_success(self, adapter, tmp_path):
        """The upload call receives channel, path, filename and caption."""
        pdf_path = tmp_path / "report.pdf"
        pdf_path.write_bytes(b"%PDF-1.4 fake content")
        upload = AsyncMock(return_value={"ok": True})
        adapter._app.client.files_upload_v2 = upload

        outcome = await adapter.send_document(
            chat_id="C123",
            file_path=str(pdf_path),
            caption="Here's the report",
        )

        assert outcome.success
        upload.assert_called_once()
        sent = upload.call_args.kwargs
        assert sent["channel"] == "C123"
        assert sent["file"] == str(pdf_path)
        assert sent["filename"] == "report.pdf"
        assert sent["initial_comment"] == "Here's the report"

    @pytest.mark.asyncio
    async def test_send_document_custom_name(self, adapter, tmp_path):
        """An explicit ``file_name`` overrides the on-disk name."""
        csv_path = tmp_path / "data.csv"
        csv_path.write_bytes(b"a,b,c\n1,2,3")
        upload = AsyncMock(return_value={"ok": True})
        adapter._app.client.files_upload_v2 = upload

        outcome = await adapter.send_document(
            chat_id="C123",
            file_path=str(csv_path),
            file_name="quarterly-report.csv",
        )

        assert outcome.success
        assert upload.call_args.kwargs["filename"] == "quarterly-report.csv"

    @pytest.mark.asyncio
    async def test_send_document_missing_file(self, adapter):
        """A path that does not exist yields a 'not found' failure."""
        outcome = await adapter.send_document(
            chat_id="C123",
            file_path="/nonexistent/file.pdf",
        )

        assert not outcome.success
        assert "not found" in outcome.error.lower()

    @pytest.mark.asyncio
    async def test_send_document_not_connected(self, adapter):
        """Without an app object the call fails with 'Not connected'."""
        adapter._app = None

        outcome = await adapter.send_document(
            chat_id="C123",
            file_path="/some/file.pdf",
        )

        assert not outcome.success
        assert "Not connected" in outcome.error

    @pytest.mark.asyncio
    async def test_send_document_api_error_falls_back(self, adapter, tmp_path):
        """An upload error should fall back to a plain text message."""
        doc_path = tmp_path / "doc.pdf"
        doc_path.write_bytes(b"content")
        adapter._app.client.files_upload_v2 = AsyncMock(
            side_effect=RuntimeError("Slack API error")
        )

        # Should fall back to base class (text message)
        await adapter.send_document(chat_id="C123", file_path=str(doc_path))

        # Base class send() is also mocked, so check it was attempted
        adapter._app.client.chat_postMessage.assert_called_once()

    @pytest.mark.asyncio
    async def test_send_document_with_thread(self, adapter, tmp_path):
        """``reply_to`` is passed to the upload as ``thread_ts``."""
        notes_path = tmp_path / "notes.txt"
        notes_path.write_bytes(b"some notes")
        upload = AsyncMock(return_value={"ok": True})
        adapter._app.client.files_upload_v2 = upload

        outcome = await adapter.send_document(
            chat_id="C123",
            file_path=str(notes_path),
            reply_to="1234567890.123456",
        )

        assert outcome.success
        assert upload.call_args.kwargs["thread_ts"] == "1234567890.123456"
# ---------------------------------------------------------------------------
# TestSendVideo
# ---------------------------------------------------------------------------
class TestSendVideo:
    """Outbound video uploads made through ``send_video``."""

    @pytest.mark.asyncio
    async def test_send_video_success(self, adapter, tmp_path):
        """The upload call carries the file name and the caption."""
        clip_path = tmp_path / "clip.mp4"
        clip_path.write_bytes(b"fake video data")
        upload = AsyncMock(return_value={"ok": True})
        adapter._app.client.files_upload_v2 = upload

        outcome = await adapter.send_video(
            chat_id="C123",
            video_path=str(clip_path),
            caption="Check this out",
        )

        assert outcome.success
        sent = upload.call_args.kwargs
        assert sent["filename"] == "clip.mp4"
        assert sent["initial_comment"] == "Check this out"

    @pytest.mark.asyncio
    async def test_send_video_missing_file(self, adapter):
        """A path that does not exist yields a 'not found' failure."""
        outcome = await adapter.send_video(
            chat_id="C123",
            video_path="/nonexistent/video.mp4",
        )

        assert not outcome.success
        assert "not found" in outcome.error.lower()

    @pytest.mark.asyncio
    async def test_send_video_not_connected(self, adapter):
        """Without an app object the call fails with 'Not connected'."""
        adapter._app = None

        outcome = await adapter.send_video(
            chat_id="C123",
            video_path="/some/video.mp4",
        )

        assert not outcome.success
        assert "Not connected" in outcome.error

    @pytest.mark.asyncio
    async def test_send_video_api_error_falls_back(self, adapter, tmp_path):
        """An upload error should fall back to a plain text message."""
        clip_path = tmp_path / "clip.mp4"
        clip_path.write_bytes(b"fake video")
        adapter._app.client.files_upload_v2 = AsyncMock(
            side_effect=RuntimeError("Slack API error")
        )

        # Should fall back to base class (text message)
        await adapter.send_video(chat_id="C123", video_path=str(clip_path))

        adapter._app.client.chat_postMessage.assert_called_once()
# ---------------------------------------------------------------------------
# TestIncomingDocumentHandling
# ---------------------------------------------------------------------------
class TestIncomingDocumentHandling:
    """Inbound Slack messages that carry file attachments."""

    @staticmethod
    def _attachment(name, mimetype, size):
        """Return a Slack file entry whose download URL is derived from ``name``."""
        return {
            "mimetype": mimetype,
            "name": name,
            "url_private_download": f"https://files.slack.com/{name}",
            "size": size,
        }

    def _make_event(self, files=None, text="hello", channel_type="im"):
        """Build a mock Slack message event with file attachments."""
        attachments = files or []
        return dict(
            text=text,
            user="U_USER",
            channel="C123",
            channel_type=channel_type,
            ts="1234567890.000001",
            files=attachments,
        )

    @pytest.mark.asyncio
    async def test_pdf_document_cached(self, adapter):
        """A PDF attachment should be downloaded, cached, and set as DOCUMENT type."""
        payload = b"%PDF-1.4 fake content"
        incoming = self._make_event(
            files=[self._attachment("report.pdf", "application/pdf", len(payload))]
        )
        downloader = AsyncMock(return_value=payload)
        with patch.object(adapter, "_download_slack_file_bytes", new=downloader):
            await adapter._handle_slack_message(incoming)

        delivered = adapter.handle_message.call_args.args[0]
        assert delivered.message_type == MessageType.DOCUMENT
        assert len(delivered.media_urls) == 1
        assert os.path.exists(delivered.media_urls[0])
        assert delivered.media_types == ["application/pdf"]

    @pytest.mark.asyncio
    async def test_txt_document_injects_content(self, adapter):
        """A .txt file under 100KB should have its content injected into event text."""
        payload = b"Hello from a text file"
        incoming = self._make_event(
            text="summarize this",
            files=[self._attachment("notes.txt", "text/plain", len(payload))],
        )
        downloader = AsyncMock(return_value=payload)
        with patch.object(adapter, "_download_slack_file_bytes", new=downloader):
            await adapter._handle_slack_message(incoming)

        delivered = adapter.handle_message.call_args.args[0]
        assert "Hello from a text file" in delivered.text
        assert "[Content of notes.txt]" in delivered.text
        assert "summarize this" in delivered.text

    @pytest.mark.asyncio
    async def test_md_document_injects_content(self, adapter):
        """A .md file under 100KB should have its content injected."""
        payload = b"# Title\nSome markdown content"
        incoming = self._make_event(
            files=[self._attachment("readme.md", "text/markdown", len(payload))],
            text="",
        )
        downloader = AsyncMock(return_value=payload)
        with patch.object(adapter, "_download_slack_file_bytes", new=downloader):
            await adapter._handle_slack_message(incoming)

        delivered = adapter.handle_message.call_args.args[0]
        assert "# Title" in delivered.text

    @pytest.mark.asyncio
    async def test_large_txt_not_injected(self, adapter):
        """A .txt file over 100KB should be cached but NOT injected."""
        payload = b"x" * (200 * 1024)
        incoming = self._make_event(
            files=[self._attachment("big.txt", "text/plain", len(payload))],
            text="",
        )
        downloader = AsyncMock(return_value=payload)
        with patch.object(adapter, "_download_slack_file_bytes", new=downloader):
            await adapter._handle_slack_message(incoming)

        delivered = adapter.handle_message.call_args.args[0]
        assert len(delivered.media_urls) == 1
        assert "[Content of" not in (delivered.text or "")

    @pytest.mark.asyncio
    async def test_unsupported_file_type_skipped(self, adapter):
        """A .zip file should be silently skipped."""
        incoming = self._make_event(
            files=[self._attachment("archive.zip", "application/zip", 1024)]
        )

        await adapter._handle_slack_message(incoming)

        delivered = adapter.handle_message.call_args.args[0]
        assert delivered.message_type == MessageType.TEXT
        assert len(delivered.media_urls) == 0

    @pytest.mark.asyncio
    async def test_oversized_document_skipped(self, adapter):
        """A document over 20MB should be skipped."""
        incoming = self._make_event(
            files=[self._attachment("huge.pdf", "application/pdf", 25 * 1024 * 1024)]
        )

        await adapter._handle_slack_message(incoming)

        delivered = adapter.handle_message.call_args.args[0]
        assert len(delivered.media_urls) == 0

    @pytest.mark.asyncio
    async def test_document_download_error_handled(self, adapter):
        """If document download fails, handler should not crash."""
        incoming = self._make_event(
            files=[self._attachment("report.pdf", "application/pdf", 1024)]
        )
        failing = AsyncMock(side_effect=RuntimeError("download failed"))
        with patch.object(adapter, "_download_slack_file_bytes", new=failing):
            await adapter._handle_slack_message(incoming)

        # Handler should still be called (the exception is caught)
        adapter.handle_message.assert_called_once()

    @pytest.mark.asyncio
    async def test_image_still_handled(self, adapter):
        """Image attachments should still go through the image path, not document."""
        incoming = self._make_event(
            files=[self._attachment("photo.jpg", "image/jpeg", 1024)]
        )
        image_dl = AsyncMock(return_value="/tmp/cached_image.jpg")
        with patch.object(adapter, "_download_slack_file", new=image_dl):
            await adapter._handle_slack_message(incoming)

        delivered = adapter.handle_message.call_args.args[0]
        assert delivered.message_type == MessageType.PHOTO
# ---------------------------------------------------------------------------
# TestMessageRouting
# ---------------------------------------------------------------------------
class TestMessageRouting:
    """Which inbound Slack events reach ``handle_message``."""

    @pytest.mark.asyncio
    async def test_dm_processed_without_mention(self, adapter):
        """DM messages should be processed without requiring a bot mention."""
        incoming = dict(
            text="hello",
            user="U_USER",
            channel="D123",
            channel_type="im",
            ts="1234567890.000001",
        )

        await adapter._handle_slack_message(incoming)

        adapter.handle_message.assert_called_once()

    @pytest.mark.asyncio
    async def test_channel_message_requires_mention(self, adapter):
        """Channel messages without a bot mention should be ignored."""
        incoming = dict(
            text="just talking",
            user="U_USER",
            channel="C123",
            channel_type="channel",
            ts="1234567890.000001",
        )

        await adapter._handle_slack_message(incoming)

        adapter.handle_message.assert_not_called()

    @pytest.mark.asyncio
    async def test_channel_mention_strips_bot_id(self, adapter):
        """When mentioned in a channel, the bot mention should be stripped."""
        incoming = dict(
            text="<@U_BOT> what's the weather?",
            user="U_USER",
            channel="C123",
            channel_type="channel",
            ts="1234567890.000001",
        )

        await adapter._handle_slack_message(incoming)

        delivered = adapter.handle_message.call_args.args[0]
        assert delivered.text == "what's the weather?"
        assert "<@U_BOT>" not in delivered.text

    @pytest.mark.asyncio
    async def test_bot_messages_ignored(self, adapter):
        """Messages from bots should be ignored."""
        incoming = dict(
            text="bot response",
            bot_id="B_OTHER",
            channel="C123",
            channel_type="im",
            ts="1234567890.000001",
        )

        await adapter._handle_slack_message(incoming)

        adapter.handle_message.assert_not_called()

    @pytest.mark.asyncio
    async def test_message_edits_ignored(self, adapter):
        """Message edits should be ignored."""
        incoming = dict(
            text="edited message",
            user="U_USER",
            channel="C123",
            channel_type="im",
            ts="1234567890.000001",
            subtype="message_changed",
        )

        await adapter._handle_slack_message(incoming)

        adapter.handle_message.assert_not_called()
# ---------------------------------------------------------------------------
# TestSendTyping — assistant.threads.setStatus
# ---------------------------------------------------------------------------
class TestSendTyping:
    """Test typing indicator via assistant.threads.setStatus."""

    @pytest.mark.asyncio
    async def test_sets_status_in_thread(self, adapter):
        """A ``thread_id`` in metadata becomes the ``thread_ts`` of the status call."""
        set_status = AsyncMock()
        adapter._app.client.assistant_threads_setStatus = set_status

        await adapter.send_typing("C123", metadata={"thread_id": "parent_ts"})

        set_status.assert_called_once_with(
            channel_id="C123",
            thread_ts="parent_ts",
            status="is thinking...",
        )

    @pytest.mark.asyncio
    async def test_noop_without_thread(self, adapter):
        """No metadata means no status call."""
        set_status = AsyncMock()
        adapter._app.client.assistant_threads_setStatus = set_status

        await adapter.send_typing("C123")

        set_status.assert_not_called()

    @pytest.mark.asyncio
    async def test_handles_missing_scope_gracefully(self, adapter):
        """An exception from the status call must not propagate."""
        adapter._app.client.assistant_threads_setStatus = AsyncMock(
            side_effect=Exception("missing_scope")
        )

        # Should not raise
        await adapter.send_typing("C123", metadata={"thread_id": "ts1"})

    @pytest.mark.asyncio
    async def test_uses_thread_ts_fallback(self, adapter):
        """``thread_ts`` in metadata is used when ``thread_id`` is absent."""
        set_status = AsyncMock()
        adapter._app.client.assistant_threads_setStatus = set_status

        await adapter.send_typing("C123", metadata={"thread_ts": "fallback_ts"})

        set_status.assert_called_once_with(
            channel_id="C123",
            thread_ts="fallback_ts",
            status="is thinking...",
        )
# ---------------------------------------------------------------------------
# TestFormatMessage — Markdown → mrkdwn conversion
# ---------------------------------------------------------------------------
class TestFormatMessage:
    """Test markdown to Slack mrkdwn conversion."""

    def test_bold_conversion(self, adapter):
        """``**x**`` becomes ``*x*``."""
        rendered = adapter.format_message("**hello**")
        assert rendered == "*hello*"

    def test_italic_asterisk_conversion(self, adapter):
        """``*x*`` becomes ``_x_``."""
        rendered = adapter.format_message("*hello*")
        assert rendered == "_hello_"

    def test_italic_underscore_preserved(self, adapter):
        """``_x_`` is left unchanged."""
        rendered = adapter.format_message("_hello_")
        assert rendered == "_hello_"

    def test_header_to_bold(self, adapter):
        """A ``##`` header becomes a bold line."""
        rendered = adapter.format_message("## Section Title")
        assert rendered == "*Section Title*"

    def test_header_with_bold_content(self, adapter):
        """Bold text inside a header is wrapped only once."""
        # **bold** inside a header should not double-wrap
        rendered = adapter.format_message("## **Title**")
        assert rendered == "*Title*"

    def test_link_conversion(self, adapter):
        """``[text](url)`` becomes ``<url|text>``."""
        rendered = adapter.format_message("[click here](https://example.com)")
        assert rendered == "<https://example.com|click here>"

    def test_strikethrough(self, adapter):
        """``~~x~~`` becomes ``~x~``."""
        rendered = adapter.format_message("~~deleted~~")
        assert rendered == "~deleted~"

    def test_code_block_preserved(self, adapter):
        """Fenced code is returned untouched."""
        fenced = "```python\nx = **not bold**\n```"
        assert adapter.format_message(fenced) == fenced

    def test_inline_code_preserved(self, adapter):
        """Markdown inside backticks is not converted."""
        rendered = adapter.format_message("Use `**raw**` syntax")
        assert rendered == "Use `**raw**` syntax"

    def test_mixed_content(self, adapter):
        """Bold, italic and inline code are each handled in a single string."""
        rendered = adapter.format_message("**Bold** and *italic* with `code`")
        assert "*Bold*" in rendered
        assert "_italic_" in rendered
        assert "`code`" in rendered

    def test_empty_string(self, adapter):
        """An empty string stays empty."""
        rendered = adapter.format_message("")
        assert rendered == ""

    def test_none_passthrough(self, adapter):
        """``None`` is returned as ``None``."""
        rendered = adapter.format_message(None)
        assert rendered is None
# ---------------------------------------------------------------------------
# TestReactions
# ---------------------------------------------------------------------------
class TestReactions:
    """Test emoji reaction methods."""

    @pytest.mark.asyncio
    async def test_add_reaction_calls_api(self, adapter):
        """``_add_reaction`` returns True and forwards channel, timestamp, name."""
        reactions_add = AsyncMock()
        adapter._app.client.reactions_add = reactions_add

        added = await adapter._add_reaction("C123", "ts1", "eyes")

        assert added is True
        reactions_add.assert_called_once_with(
            channel="C123", timestamp="ts1", name="eyes"
        )

    @pytest.mark.asyncio
    async def test_add_reaction_handles_error(self, adapter):
        """An API exception is reported as a False return value."""
        adapter._app.client.reactions_add = AsyncMock(
            side_effect=Exception("already_reacted")
        )

        added = await adapter._add_reaction("C123", "ts1", "eyes")

        assert added is False

    @pytest.mark.asyncio
    async def test_remove_reaction_calls_api(self, adapter):
        """``_remove_reaction`` returns True when the API call succeeds."""
        adapter._app.client.reactions_remove = AsyncMock()

        removed = await adapter._remove_reaction("C123", "ts1", "eyes")

        assert removed is True

    @pytest.mark.asyncio
    async def test_reactions_in_message_flow(self, adapter):
        """Reactions should be added on receipt and swapped on completion."""
        reactions_add = AsyncMock()
        reactions_remove = AsyncMock()
        adapter._app.client.reactions_add = reactions_add
        adapter._app.client.reactions_remove = reactions_remove
        adapter._app.client.users_info = AsyncMock(
            return_value={"user": {"profile": {"display_name": "Tyler"}}}
        )
        incoming = dict(
            text="hello",
            user="U_USER",
            channel="C123",
            channel_type="im",
            ts="1234567890.000001",
        )

        await adapter._handle_slack_message(incoming)

        # Should have added 👀, then removed 👀, then added ✅
        additions = reactions_add.call_args_list
        removals = reactions_remove.call_args_list
        assert len(additions) == 2
        first_add, second_add = additions
        assert first_add.kwargs["name"] == "eyes"
        assert second_add.kwargs["name"] == "white_check_mark"
        assert len(removals) == 1
        assert removals[0].kwargs["name"] == "eyes"
# ---------------------------------------------------------------------------
# TestUserNameResolution
# ---------------------------------------------------------------------------
class TestUserNameResolution:
    """Test user identity resolution."""

    @staticmethod
    def _users_info(**profile):
        """Return an AsyncMock that answers ``users_info`` with ``profile``."""
        return AsyncMock(return_value={"user": {"profile": profile}})

    @pytest.mark.asyncio
    async def test_resolves_display_name(self, adapter):
        """A non-empty display name is returned in preference to the real name."""
        adapter._app.client.users_info = self._users_info(
            display_name="Tyler", real_name="Tyler B"
        )

        resolved = await adapter._resolve_user_name("U123")

        assert resolved == "Tyler"

    @pytest.mark.asyncio
    async def test_falls_back_to_real_name(self, adapter):
        """An empty display name falls back to the real name."""
        adapter._app.client.users_info = self._users_info(
            display_name="", real_name="Tyler B"
        )

        resolved = await adapter._resolve_user_name("U123")

        assert resolved == "Tyler B"

    @pytest.mark.asyncio
    async def test_caches_result(self, adapter):
        """Repeated lookups of the same user hit the API only once."""
        lookup = self._users_info(display_name="Tyler")
        adapter._app.client.users_info = lookup

        for _ in range(2):
            await adapter._resolve_user_name("U123")

        # Only one API call despite two lookups
        assert lookup.call_count == 1

    @pytest.mark.asyncio
    async def test_handles_api_error(self, adapter):
        """An API failure returns the raw user id."""
        adapter._app.client.users_info = AsyncMock(
            side_effect=Exception("rate limited")
        )

        resolved = await adapter._resolve_user_name("U123")

        assert resolved == "U123"  # Falls back to user_id

    @pytest.mark.asyncio
    async def test_user_name_in_message_source(self, adapter):
        """Message source should include resolved user name."""
        adapter._app.client.users_info = self._users_info(display_name="Tyler")
        adapter._app.client.reactions_add = AsyncMock()
        adapter._app.client.reactions_remove = AsyncMock()
        incoming = dict(
            text="hello",
            user="U_USER",
            channel="C123",
            channel_type="im",
            ts="1234567890.000001",
        )

        await adapter._handle_slack_message(incoming)

        # Check the source in the MessageEvent passed to handle_message
        delivered = adapter.handle_message.call_args.args[0]
        assert delivered.source.user_name == "Tyler"
# ---------------------------------------------------------------------------
# TestSlashCommands — expanded command set
# ---------------------------------------------------------------------------
class TestSlashCommands:
    """Test slash command routing."""

    @staticmethod
    async def _routed_text(adapter, text):
        """Run ``text`` as a slash command and return the forwarded message text."""
        await adapter._handle_slash_command(
            {"text": text, "user_id": "U1", "channel_id": "C1"}
        )
        forwarded = adapter.handle_message.call_args.args[0]
        return forwarded.text

    @pytest.mark.asyncio
    async def test_compact_maps_to_compress(self, adapter):
        """``compact`` is forwarded as ``/compress``."""
        assert await self._routed_text(adapter, "compact") == "/compress"

    @pytest.mark.asyncio
    async def test_resume_command(self, adapter):
        """``resume`` keeps its arguments."""
        routed = await self._routed_text(adapter, "resume my session")
        assert routed == "/resume my session"

    @pytest.mark.asyncio
    async def test_background_command(self, adapter):
        """``background`` keeps its arguments."""
        routed = await self._routed_text(adapter, "background run tests")
        assert routed == "/background run tests"

    @pytest.mark.asyncio
    async def test_usage_command(self, adapter):
        """``usage`` is forwarded as ``/usage``."""
        assert await self._routed_text(adapter, "usage") == "/usage"

    @pytest.mark.asyncio
    async def test_reasoning_command(self, adapter):
        """``reasoning`` is forwarded as ``/reasoning``."""
        assert await self._routed_text(adapter, "reasoning") == "/reasoning"
# ---------------------------------------------------------------------------
# TestMessageSplitting
# ---------------------------------------------------------------------------
class TestMessageSplitting:
    """Test that long messages are split before sending."""

    @pytest.mark.asyncio
    async def test_long_message_split_into_chunks(self, adapter):
        """Messages over MAX_MESSAGE_LENGTH should be split."""
        post = AsyncMock(return_value={"ts": "ts1"})
        adapter._app.client.chat_postMessage = post
        oversized = "x" * 45000  # Over Slack's 40k API limit

        await adapter.send("C123", oversized)

        # Should have been called multiple times
        assert post.call_count >= 2

    @pytest.mark.asyncio
    async def test_short_message_single_send(self, adapter):
        """Short messages should be sent in one call."""
        post = AsyncMock(return_value={"ts": "ts1"})
        adapter._app.client.chat_postMessage = post

        await adapter.send("C123", "hello world")

        assert post.call_count == 1
# ---------------------------------------------------------------------------
# TestReplyBroadcast
# ---------------------------------------------------------------------------
class TestReplyBroadcast:
    """Test reply_broadcast config option."""

    @pytest.mark.asyncio
    async def test_broadcast_disabled_by_default(self, adapter):
        """Without the option, threaded sends omit ``reply_broadcast``."""
        post = AsyncMock(return_value={"ts": "ts1"})
        adapter._app.client.chat_postMessage = post

        await adapter.send("C123", "hi", metadata={"thread_id": "parent_ts"})

        assert "reply_broadcast" not in post.call_args.kwargs

    @pytest.mark.asyncio
    async def test_broadcast_enabled_via_config(self, adapter):
        """Setting ``reply_broadcast`` in ``config.extra`` passes it to the API."""
        adapter.config.extra["reply_broadcast"] = True
        post = AsyncMock(return_value={"ts": "ts1"})
        adapter._app.client.chat_postMessage = post

        await adapter.send("C123", "hi", metadata={"thread_id": "parent_ts"})

        assert post.call_args.kwargs.get("reply_broadcast") is True
# ---------------------------------------------------------------------------
# TestFallbackPreservesThreadContext
# ---------------------------------------------------------------------------
class TestFallbackPreservesThreadContext:
    """Bug fix: file upload fallbacks lost thread context (metadata) when
    calling super() without metadata, causing replies to appear outside
    the thread."""

    @staticmethod
    def _fail_uploads(adapter):
        """Make uploads raise; return the ``chat_postMessage`` mock used instead."""
        adapter._app.client.files_upload_v2 = AsyncMock(
            side_effect=Exception("upload failed")
        )
        post = AsyncMock(return_value={"ts": "msg_ts"})
        adapter._app.client.chat_postMessage = post
        return post

    @pytest.mark.asyncio
    async def test_send_image_file_fallback_preserves_thread(self, adapter, tmp_path):
        """The image fallback posts into the thread named in metadata."""
        image = tmp_path / "photo.jpg"
        image.write_bytes(b"\xff\xd8\xff\xe0")
        post = self._fail_uploads(adapter)

        await adapter.send_image_file(
            chat_id="C123",
            image_path=str(image),
            caption="test image",
            metadata={"thread_id": "parent_ts_123"},
        )

        assert post.call_args.kwargs.get("thread_ts") == "parent_ts_123"

    @pytest.mark.asyncio
    async def test_send_video_fallback_preserves_thread(self, adapter, tmp_path):
        """The video fallback posts into the thread named in metadata."""
        clip = tmp_path / "clip.mp4"
        clip.write_bytes(b"\x00\x00\x00\x1c")
        post = self._fail_uploads(adapter)

        await adapter.send_video(
            chat_id="C123",
            video_path=str(clip),
            metadata={"thread_id": "parent_ts_456"},
        )

        assert post.call_args.kwargs.get("thread_ts") == "parent_ts_456"

    @pytest.mark.asyncio
    async def test_send_document_fallback_preserves_thread(self, adapter, tmp_path):
        """The document fallback posts into the thread named in metadata."""
        report = tmp_path / "report.pdf"
        report.write_bytes(b"%PDF-1.4")
        post = self._fail_uploads(adapter)

        await adapter.send_document(
            chat_id="C123",
            file_path=str(report),
            caption="report",
            metadata={"thread_id": "parent_ts_789"},
        )

        assert post.call_args.kwargs.get("thread_ts") == "parent_ts_789"

    @pytest.mark.asyncio
    async def test_send_image_file_fallback_includes_caption(self, adapter, tmp_path):
        """The caption appears in the fallback message text."""
        image = tmp_path / "photo.jpg"
        image.write_bytes(b"\xff\xd8\xff\xe0")
        post = self._fail_uploads(adapter)

        await adapter.send_image_file(
            chat_id="C123",
            image_path=str(image),
            caption="important screenshot",
        )

        assert "important screenshot" in post.call_args.kwargs["text"]

View file

@ -0,0 +1,215 @@
"""Tests for SMS (Twilio) platform integration.
Covers config loading, format/truncate, echo prevention,
requirements check, and toolset verification.
"""
import os
from unittest.mock import patch
import pytest
from gateway.config import Platform, PlatformConfig, HomeChannel
# ── Config loading ──────────────────────────────────────────────────
class TestSmsConfigLoading:
    """Verify _apply_env_overrides wires SMS correctly."""

    def test_sms_platform_enum_exists(self):
        """``Platform.SMS`` has the value ``"sms"``."""
        assert Platform.SMS.value == "sms"

    def test_env_overrides_create_sms_config(self):
        """Twilio variables yield an enabled SMS platform keyed by the auth token."""
        from gateway.config import load_gateway_config

        twilio_env = dict(
            TWILIO_ACCOUNT_SID="ACtest123",
            TWILIO_AUTH_TOKEN="token_abc",
            TWILIO_PHONE_NUMBER="+15551234567",
        )
        with patch.dict(os.environ, twilio_env, clear=False):
            loaded = load_gateway_config()
            assert Platform.SMS in loaded.platforms
            sms_cfg = loaded.platforms[Platform.SMS]
            assert sms_cfg.enabled is True
            assert sms_cfg.api_key == "token_abc"

    def test_env_overrides_set_home_channel(self):
        """``SMS_HOME_CHANNEL*`` variables populate the platform's home channel."""
        from gateway.config import load_gateway_config

        twilio_env = dict(
            TWILIO_ACCOUNT_SID="ACtest123",
            TWILIO_AUTH_TOKEN="token_abc",
            TWILIO_PHONE_NUMBER="+15551234567",
            SMS_HOME_CHANNEL="+15559876543",
            SMS_HOME_CHANNEL_NAME="My Phone",
        )
        with patch.dict(os.environ, twilio_env, clear=False):
            loaded = load_gateway_config()
            home = loaded.platforms[Platform.SMS].home_channel
            assert home is not None
            assert home.chat_id == "+15559876543"
            assert home.name == "My Phone"
            assert home.platform == Platform.SMS

    def test_sms_in_connected_platforms(self):
        """SID and token alone make SMS count as a connected platform."""
        from gateway.config import load_gateway_config

        twilio_env = dict(
            TWILIO_ACCOUNT_SID="ACtest123",
            TWILIO_AUTH_TOKEN="token_abc",
        )
        with patch.dict(os.environ, twilio_env, clear=False):
            loaded = load_gateway_config()
            assert Platform.SMS in loaded.get_connected_platforms()
# ── Format / truncate ───────────────────────────────────────────────
class TestSmsFormatAndTruncate:
    """Test SmsAdapter.format_message strips markdown."""

    def _make_adapter(self):
        """Return an ``SmsAdapter`` created without running ``__init__``."""
        from gateway.platforms.sms import SmsAdapter

        twilio_env = dict(
            TWILIO_ACCOUNT_SID="ACtest",
            TWILIO_AUTH_TOKEN="tok",
            TWILIO_PHONE_NUMBER="+15550001111",
        )
        with patch.dict(os.environ, twilio_env):
            platform_cfg = PlatformConfig(enabled=True, api_key="tok")
            # Bypass the constructor and set the attributes by hand.
            sms = object.__new__(SmsAdapter)
            sms.config = platform_cfg
            sms._platform = Platform.SMS
            sms._account_sid = "ACtest"
            sms._auth_token = "tok"
            sms._from_number = "+15550001111"
            return sms

    def test_strips_bold(self):
        """Bold markers are removed."""
        rendered = self._make_adapter().format_message("**hello**")
        assert rendered == "hello"

    def test_strips_italic(self):
        """Italic markers are removed."""
        rendered = self._make_adapter().format_message("*world*")
        assert rendered == "world"

    def test_strips_code_blocks(self):
        """Fences are removed and the code itself is kept."""
        rendered = self._make_adapter().format_message("```python\nprint('hi')\n```")
        assert "```" not in rendered
        assert "print('hi')" in rendered

    def test_strips_inline_code(self):
        """Backticks around inline code are removed."""
        rendered = self._make_adapter().format_message("`code`")
        assert rendered == "code"

    def test_strips_headers(self):
        """Header markers are removed."""
        rendered = self._make_adapter().format_message("## Title")
        assert rendered == "Title"

    def test_strips_links(self):
        """A markdown link is reduced to its label."""
        rendered = self._make_adapter().format_message("[click](https://example.com)")
        assert rendered == "click"

    def test_collapses_newlines(self):
        """Four consecutive newlines collapse to two."""
        rendered = self._make_adapter().format_message("a\n\n\n\nb")
        assert rendered == "a\n\nb"
# ── Echo prevention ────────────────────────────────────────────────
class TestSmsEchoPrevention:
    """Adapter should ignore messages from its own number."""

    def test_own_number_detection(self):
        """The adapter stores _from_number for echo prevention."""
        from gateway.platforms.sms import SmsAdapter

        twilio_env = dict(
            TWILIO_ACCOUNT_SID="ACtest",
            TWILIO_AUTH_TOKEN="tok",
            TWILIO_PHONE_NUMBER="+15550001111",
        )
        # Construct the adapter while the Twilio variables are patched in.
        with patch.dict(os.environ, twilio_env):
            sms = SmsAdapter(PlatformConfig(enabled=True, api_key="tok"))
            assert sms._from_number == "+15550001111"
# ── Requirements check ─────────────────────────────────────────────
class TestSmsRequirements:
    """``check_sms_requirements`` under different environments."""

    def test_check_sms_requirements_missing_sid(self):
        """An environment holding only the auth token fails the check."""
        from gateway.platforms.sms import check_sms_requirements

        with patch.dict(os.environ, {"TWILIO_AUTH_TOKEN": "tok"}, clear=True):
            outcome = check_sms_requirements()
            assert outcome is False

    def test_check_sms_requirements_missing_token(self):
        """An environment holding only the account SID fails the check."""
        from gateway.platforms.sms import check_sms_requirements

        with patch.dict(os.environ, {"TWILIO_ACCOUNT_SID": "ACtest"}, clear=True):
            outcome = check_sms_requirements()
            assert outcome is False

    def test_check_sms_requirements_both_set(self):
        """With both variables set, the result matches whether aiohttp imports."""
        from gateway.platforms.sms import check_sms_requirements

        twilio_env = dict(TWILIO_ACCOUNT_SID="ACtest", TWILIO_AUTH_TOKEN="tok")
        with patch.dict(os.environ, twilio_env, clear=False):
            # Only returns True if aiohttp is also importable
            outcome = check_sms_requirements()
            try:
                import aiohttp  # noqa: F401
            except ImportError:
                expected = False
            else:
                expected = True
            assert outcome is expected
# ── Toolset verification ───────────────────────────────────────────
class TestSmsToolset:
    """SMS is registered in toolsets, prompt hints and tool schemas."""

    def test_hermes_sms_toolset_exists(self):
        """``hermes-sms`` resolves to a toolset that has a ``tools`` entry."""
        from toolsets import get_toolset

        sms_toolset = get_toolset("hermes-sms")
        assert sms_toolset is not None
        assert "tools" in sms_toolset

    def test_hermes_sms_in_gateway_includes(self):
        """``hermes-gateway`` lists ``hermes-sms`` among its includes."""
        from toolsets import get_toolset

        gateway_toolset = get_toolset("hermes-gateway")
        assert gateway_toolset is not None
        assert "hermes-sms" in gateway_toolset["includes"]

    def test_sms_platform_hint_exists(self):
        """The SMS platform hint exists and mentions being concise."""
        from agent.prompt_builder import PLATFORM_HINTS

        assert "sms" in PLATFORM_HINTS
        hint = PLATFORM_HINTS["sms"]
        assert "concise" in hint.lower()

    def test_sms_in_scheduler_platform_map(self):
        """Verify cron scheduler recognizes 'sms' as a valid platform."""
        # Just check the Platform enum has SMS — the scheduler imports it dynamically
        assert Platform.SMS.value == "sms"

    def test_sms_in_send_message_platform_map(self):
        """Verify send_message_tool recognizes 'sms'."""
        # The platform_map is built inside _handle_send; verify SMS enum exists
        assert hasattr(Platform, "SMS")

    def test_sms_in_cronjob_deliver_description(self):
        """Verify cronjob_tools mentions sms in deliver description."""
        from tools.cronjob_tools import CRONJOB_SCHEMA

        deliver = CRONJOB_SCHEMA["parameters"]["properties"]["deliver"]
        assert "sms" in deliver["description"].lower()

View file

@ -0,0 +1,81 @@
"""Tests for SSL certificate auto-detection in gateway/run.py."""
import importlib
import os
from unittest.mock import patch, MagicMock
def _load_ensure_ssl():
    """Build a standalone copy of ``_ensure_ssl_certs`` for testing.

    gateway/run.py has heavy dependencies, so rather than importing the
    whole gateway the detection logic is replicated here as source text and
    executed inside a throwaway module. This also keeps the tests free of
    gateway import side effects.

    Returns:
        The ``_ensure_ssl_certs`` callable defined in the throwaway module.
        Calling it sets ``os.environ["SSL_CERT_FILE"]`` unless it is already
        set or no candidate certificate bundle can be found.
    """
    from types import ModuleType
    import ssl as _ssl  # noqa: F401
    import textwrap

    # The embedded source must keep its block indentation: textwrap.dedent
    # only strips the common margin, and exec() would raise IndentationError
    # on a flattened function body.
    code = textwrap.dedent("""\
        import os, ssl
        def _ensure_ssl_certs():
            if "SSL_CERT_FILE" in os.environ:
                return
            paths = ssl.get_default_verify_paths()
            for candidate in (paths.cafile, paths.openssl_cafile):
                if candidate and os.path.exists(candidate):
                    os.environ["SSL_CERT_FILE"] = candidate
                    return
            try:
                import certifi
                os.environ["SSL_CERT_FILE"] = certifi.where()
                return
            except ImportError:
                pass
            for candidate in (
                "/etc/ssl/certs/ca-certificates.crt",
                "/etc/ssl/cert.pem",
            ):
                if os.path.exists(candidate):
                    os.environ["SSL_CERT_FILE"] = candidate
                    return
    """)
    mod = ModuleType("_ssl_helper")
    exec(code, mod.__dict__)
    return mod._ensure_ssl_certs
class TestEnsureSslCerts:
    """Exercise the replicated ``_ensure_ssl_certs`` under patched environments."""

    def test_respects_existing_env_var(self):
        """A pre-set SSL_CERT_FILE is left untouched."""
        ensure = _load_ensure_ssl()
        with patch.dict(os.environ, {"SSL_CERT_FILE": "/custom/ca.pem"}):
            ensure()
            assert os.environ["SSL_CERT_FILE"] == "/custom/ca.pem"

    def test_sets_from_ssl_default_paths(self, tmp_path):
        """An existing ``cafile`` from the ssl default paths is exported."""
        ensure = _load_ensure_ssl()
        cert = tmp_path / "ca.crt"
        cert.write_text("FAKE CERT")

        verify_paths = MagicMock()
        verify_paths.cafile = str(cert)
        verify_paths.openssl_cafile = None

        # Run with a copy of the environment that lacks SSL_CERT_FILE.
        clean_env = dict(os.environ)
        clean_env.pop("SSL_CERT_FILE", None)

        with patch.dict(os.environ, clean_env, clear=True):
            with patch("ssl.get_default_verify_paths", return_value=verify_paths):
                ensure()
                assert os.environ.get("SSL_CERT_FILE") == str(cert)

    def test_no_op_when_nothing_found(self):
        """With no default paths, no certifi and no files, nothing is set."""
        ensure = _load_ensure_ssl()

        verify_paths = MagicMock()
        verify_paths.cafile = None
        verify_paths.openssl_cafile = None

        clean_env = dict(os.environ)
        clean_env.pop("SSL_CERT_FILE", None)

        with patch.dict(os.environ, clean_env, clear=True):
            with patch("ssl.get_default_verify_paths", return_value=verify_paths):
                with patch("os.path.exists", return_value=False):
                    # A None entry in sys.modules makes ``import certifi``
                    # raise ImportError.
                    with patch.dict("sys.modules", {"certifi": None}):
                        ensure()
                        assert "SSL_CERT_FILE" not in os.environ

View file

@ -0,0 +1,157 @@
"""Tests for gateway runtime status tracking."""
import json
import os
from gateway import status
class TestGatewayPidState:
    """PID-file writing and validation via ``gateway.status``."""

    def test_write_pid_file_records_gateway_metadata(self, tmp_path, monkeypatch):
        """``write_pid_file`` stores pid, kind and a non-empty argv list as JSON."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))

        status.write_pid_file()

        record = json.loads((tmp_path / "gateway.pid").read_text())
        assert record["pid"] == os.getpid()
        assert record["kind"] == "hermes-gateway"
        assert isinstance(record["argv"], list)
        assert record["argv"]

    def test_get_running_pid_rejects_live_non_gateway_pid(self, tmp_path, monkeypatch):
        """A bare PID without gateway metadata is rejected and the file removed."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        pid_file = tmp_path / "gateway.pid"
        pid_file.write_text(str(os.getpid()))

        assert status.get_running_pid() is None
        assert not pid_file.exists()

    def test_get_running_pid_accepts_gateway_metadata_when_cmdline_unavailable(self, tmp_path, monkeypatch):
        """Gateway metadata is accepted when the process cmdline cannot be read."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        record = {
            "pid": os.getpid(),
            "kind": "hermes-gateway",
            "argv": ["python", "-m", "hermes_cli.main", "gateway"],
            "start_time": 123,
        }
        pid_file = tmp_path / "gateway.pid"
        pid_file.write_text(json.dumps(record))

        monkeypatch.setattr(status.os, "kill", lambda pid, sig: None)
        monkeypatch.setattr(status, "_get_process_start_time", lambda pid: 123)
        monkeypatch.setattr(status, "_read_process_cmdline", lambda pid: None)

        assert status.get_running_pid() == os.getpid()

    def test_get_running_pid_accepts_script_style_gateway_cmdline(self, tmp_path, monkeypatch):
        """A script-style ``main.py gateway run`` command line is accepted."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        record = {
            "pid": os.getpid(),
            "kind": "hermes-gateway",
            "argv": ["/venv/bin/python", "/repo/hermes_cli/main.py", "gateway", "run", "--replace"],
            "start_time": 123,
        }
        pid_file = tmp_path / "gateway.pid"
        pid_file.write_text(json.dumps(record))

        monkeypatch.setattr(status.os, "kill", lambda pid, sig: None)
        monkeypatch.setattr(status, "_get_process_start_time", lambda pid: 123)
        monkeypatch.setattr(
            status,
            "_read_process_cmdline",
            lambda pid: "/venv/bin/python /repo/hermes_cli/main.py gateway run --replace",
        )

        assert status.get_running_pid() == os.getpid()
class TestGatewayRuntimeStatus:
    """Round-trips of the runtime status file written by ``gateway.status``."""

    def test_write_runtime_status_overwrites_stale_pid_on_restart(self, tmp_path, monkeypatch):
        """Regression: setdefault() preserved stale PID from previous process (#1631)."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))

        # Simulate a previous gateway run that left a state file behind.
        stale_state = {
            "pid": 99999,
            "start_time": 1000.0,
            "kind": "hermes-gateway",
            "platforms": {},
            "updated_at": "2025-01-01T00:00:00Z",
        }
        (tmp_path / "gateway_state.json").write_text(json.dumps(stale_state))

        status.write_runtime_status(gateway_state="running")

        current = status.read_runtime_status()
        assert current["pid"] == os.getpid(), "PID should be overwritten, not preserved via setdefault"
        assert current["start_time"] != 1000.0, "start_time should be overwritten on restart"

    def test_write_runtime_status_records_platform_failure(self, tmp_path, monkeypatch):
        """Platform-level failure details are stored under ``platforms[<name>]``."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))

        status.write_runtime_status(
            gateway_state="startup_failed",
            exit_reason="telegram conflict",
            platform="telegram",
            platform_state="fatal",
            error_code="telegram_polling_conflict",
            error_message="another poller is active",
        )

        current = status.read_runtime_status()
        assert current["gateway_state"] == "startup_failed"
        assert current["exit_reason"] == "telegram conflict"

        telegram_state = current["platforms"]["telegram"]
        assert telegram_state["state"] == "fatal"
        assert telegram_state["error_code"] == "telegram_polling_conflict"
        assert telegram_state["error_message"] == "another poller is active"
class TestScopedLocks:
    """Acquire/release behaviour of scoped lock files in ``gateway.status``."""

    # Lock file expected for scope "telegram-bot-token" and identity "secret".
    # NOTE(review): the suffix matches the first 16 hex digits of
    # SHA-256("secret"); presumably that is how status derives it — confirm.
    _LOCK_FILE = "telegram-bot-token-2bb80d537b1da3e3.lock"

    def test_acquire_scoped_lock_rejects_live_other_process(self, tmp_path, monkeypatch):
        """A lock held by a live process with a matching start time is refused."""
        lock_dir = tmp_path / "locks"
        monkeypatch.setenv("HERMES_GATEWAY_LOCK_DIR", str(lock_dir))
        lock_file = lock_dir / self._LOCK_FILE
        lock_file.parent.mkdir(parents=True, exist_ok=True)
        holder = {
            "pid": 99999,
            "start_time": 123,
            "kind": "hermes-gateway",
        }
        lock_file.write_text(json.dumps(holder))

        # Make the recorded holder look alive with the same start time.
        monkeypatch.setattr(status.os, "kill", lambda pid, sig: None)
        monkeypatch.setattr(status, "_get_process_start_time", lambda pid: 123)

        acquired, existing = status.acquire_scoped_lock("telegram-bot-token", "secret", metadata={"platform": "telegram"})

        assert acquired is False
        assert existing["pid"] == 99999

    def test_acquire_scoped_lock_replaces_stale_record(self, tmp_path, monkeypatch):
        """A lock whose recorded PID no longer exists is taken over."""
        lock_dir = tmp_path / "locks"
        monkeypatch.setenv("HERMES_GATEWAY_LOCK_DIR", str(lock_dir))
        lock_file = lock_dir / self._LOCK_FILE
        lock_file.parent.mkdir(parents=True, exist_ok=True)
        holder = {
            "pid": 99999,
            "start_time": 123,
            "kind": "hermes-gateway",
        }
        lock_file.write_text(json.dumps(holder))

        def fake_kill(pid, sig):
            raise ProcessLookupError

        monkeypatch.setattr(status.os, "kill", fake_kill)

        acquired, existing = status.acquire_scoped_lock("telegram-bot-token", "secret", metadata={"platform": "telegram"})

        assert acquired is True
        record = json.loads(lock_file.read_text())
        assert record["pid"] == os.getpid()
        assert record["metadata"]["platform"] == "telegram"

    def test_release_scoped_lock_only_removes_current_owner(self, tmp_path, monkeypatch):
        """Releasing a lock acquired by this process deletes the lock file."""
        lock_dir = tmp_path / "locks"
        monkeypatch.setenv("HERMES_GATEWAY_LOCK_DIR", str(lock_dir))

        acquired, _ = status.acquire_scoped_lock("telegram-bot-token", "secret", metadata={"platform": "telegram"})
        assert acquired is True

        lock_file = lock_dir / self._LOCK_FILE
        assert lock_file.exists()

        status.release_scoped_lock("telegram-bot-token", "secret")
        assert not lock_file.exists()

View file

@ -0,0 +1,140 @@
"""Tests for gateway /status behavior and token persistence."""
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
import pytest
from gateway.config import GatewayConfig, Platform, PlatformConfig
from gateway.platforms.base import MessageEvent
from gateway.session import SessionEntry, SessionSource, build_session_key
def _make_source() -> SessionSource:
    """Return the fixed Telegram DM source shared by the tests in this module."""
    fields = {
        "platform": Platform.TELEGRAM,
        "user_id": "u1",
        "chat_id": "c1",
        "user_name": "tester",
        "chat_type": "dm",
    }
    return SessionSource(**fields)
def _make_event(text: str) -> MessageEvent:
    """Wrap *text* in a MessageEvent from the shared test source (id ``m1``)."""
    source = _make_source()
    return MessageEvent(text=text, source=source, message_id="m1")
def _make_runner(session_entry: SessionEntry):
    """Build a GatewayRunner without running ``__init__`` and stub its collaborators.

    The session store always hands back *session_entry*; adapters, hooks and
    helper callbacks are replaced with mocks or no-op lambdas.
    """
    from gateway.run import GatewayRunner

    runner = object.__new__(GatewayRunner)

    # Configuration and a single mocked Telegram adapter.
    runner.config = GatewayConfig(
        platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")}
    )
    telegram_adapter = MagicMock()
    telegram_adapter.send = AsyncMock()
    runner.adapters = {Platform.TELEGRAM: telegram_adapter}
    runner._voice_mode = {}
    runner.hooks = SimpleNamespace(emit=AsyncMock(), loaded_hooks=False)

    # Session store stub.
    store = MagicMock()
    store.get_or_create_session.return_value = session_entry
    store.load_transcript.return_value = []
    store.has_any_sessions.return_value = True
    store.append_to_transcript = MagicMock()
    store.rewrite_transcript = MagicMock()
    store.update_session = MagicMock()
    runner.session_store = store

    # Plain runtime state.
    runner._running_agents = {}
    runner._pending_messages = {}
    runner._pending_approvals = {}
    runner._session_db = None
    runner._reasoning_config = None
    runner._provider_routing = {}
    runner._fallback_model = None
    runner._show_reasoning = False

    # Behavioural hooks replaced with trivial stand-ins.
    runner._is_user_authorized = lambda _source: True
    runner._set_session_env = lambda _context: None
    runner._should_send_voice_reply = lambda *_args, **_kwargs: False
    runner._send_voice_reply = AsyncMock()
    runner._capture_gateway_honcho_if_configured = lambda *args, **kwargs: None
    runner._emit_gateway_run_progress = AsyncMock()
    return runner
@pytest.mark.asyncio
async def test_status_command_reports_running_agent_without_interrupt(monkeypatch):
    """``/status`` reports tokens and a running agent without interrupting it."""
    entry_fields = dict(
        session_key=build_session_key(_make_source()),
        session_id="sess-1",
        created_at=datetime.now(),
        updated_at=datetime.now(),
        platform=Platform.TELEGRAM,
        chat_type="dm",
        total_tokens=321,
    )
    session_entry = SessionEntry(**entry_fields)
    runner = _make_runner(session_entry)

    # Register an agent as running for this session.
    running_agent = MagicMock()
    runner._running_agents[build_session_key(_make_source())] = running_agent

    reply = await runner._handle_message(_make_event("/status"))

    assert "**Tokens:** 321" in reply
    assert "**Agent Running:** Yes ⚡" in reply
    running_agent.interrupt.assert_not_called()
    assert runner._pending_messages == {}
@pytest.mark.asyncio
async def test_handle_message_persists_agent_token_counts(monkeypatch):
    """Token counts from the agent result are passed to ``update_session``."""
    import gateway.run as gateway_run

    session_entry = SessionEntry(
        session_key=build_session_key(_make_source()),
        session_id="sess-1",
        created_at=datetime.now(),
        updated_at=datetime.now(),
        platform=Platform.TELEGRAM,
        chat_type="dm",
    )
    runner = _make_runner(session_entry)
    runner.session_store.load_transcript.return_value = [{"role": "user", "content": "earlier"}]

    # Canned agent result carrying the usage numbers under test.
    agent_result = {
        "final_response": "ok",
        "messages": [],
        "tools": [],
        "history_offset": 0,
        "last_prompt_tokens": 80,
        "input_tokens": 120,
        "output_tokens": 45,
        "model": "openai/test-model",
    }
    runner._run_agent = AsyncMock(return_value=agent_result)

    monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})
    monkeypatch.setattr(
        "agent.model_metadata.get_model_context_length",
        lambda *_args, **_kwargs: 100000,
    )

    reply = await runner._handle_message(_make_event("hello"))

    assert reply == "ok"
    expected_kwargs = dict(
        input_tokens=120,
        output_tokens=45,
        cache_read_tokens=0,
        cache_write_tokens=0,
        last_prompt_tokens=80,
        model="openai/test-model",
        estimated_cost_usd=None,
        cost_status=None,
        cost_source=None,
        provider=None,
        base_url=None,
    )
    runner.session_store.update_session.assert_called_once_with(
        session_entry.session_key,
        **expected_kwargs,
    )

View file

@ -0,0 +1,127 @@
"""Tests for gateway/sticker_cache.py — sticker description cache."""
import json
import time
from unittest.mock import patch
from gateway.sticker_cache import (
_load_cache,
_save_cache,
get_cached_description,
cache_sticker_description,
build_sticker_injection,
build_animated_sticker_injection,
STICKER_VISION_PROMPT,
)
class TestLoadSaveCache:
    """Loading and saving the sticker cache file with ``CACHE_PATH`` patched."""

    def test_load_missing_file(self, tmp_path):
        """A non-existent cache file loads as an empty dict."""
        missing = tmp_path / "nope.json"
        with patch("gateway.sticker_cache.CACHE_PATH", missing):
            assert _load_cache() == {}

    def test_load_corrupt_file(self, tmp_path):
        """Unparseable JSON loads as an empty dict."""
        corrupt = tmp_path / "bad.json"
        corrupt.write_text("not json{{{")
        with patch("gateway.sticker_cache.CACHE_PATH", corrupt):
            assert _load_cache() == {}

    def test_save_and_load_roundtrip(self, tmp_path):
        """Data written by ``_save_cache`` is returned unchanged by ``_load_cache``."""
        target = tmp_path / "cache.json"
        entries = {"abc123": {"description": "A cat", "emoji": "", "set_name": "", "cached_at": 1.0}}
        with patch("gateway.sticker_cache.CACHE_PATH", target):
            _save_cache(entries)
            restored = _load_cache()
        assert restored == entries

    def test_save_creates_parent_dirs(self, tmp_path):
        """Saving into a nested, not-yet-existing directory creates the file."""
        nested = tmp_path / "sub" / "dir" / "cache.json"
        with patch("gateway.sticker_cache.CACHE_PATH", nested):
            _save_cache({"key": "value"})
        assert nested.exists()
class TestCacheSticker:
    """Storing and retrieving sticker descriptions by unique id."""

    def test_cache_and_retrieve(self, tmp_path):
        """A cached entry returns its description, emoji, set name and timestamp."""
        target = tmp_path / "cache.json"
        with patch("gateway.sticker_cache.CACHE_PATH", target):
            cache_sticker_description("uid_1", "A happy dog", emoji="🐕", set_name="Dogs")
            entry = get_cached_description("uid_1")
        assert entry is not None
        assert entry["description"] == "A happy dog"
        assert entry["emoji"] == "🐕"
        assert entry["set_name"] == "Dogs"
        assert "cached_at" in entry

    def test_missing_sticker_returns_none(self, tmp_path):
        """Looking up an unknown id yields ``None``."""
        target = tmp_path / "cache.json"
        with patch("gateway.sticker_cache.CACHE_PATH", target):
            entry = get_cached_description("nonexistent")
        assert entry is None

    def test_overwrite_existing(self, tmp_path):
        """Caching the same id twice keeps the most recent description."""
        target = tmp_path / "cache.json"
        with patch("gateway.sticker_cache.CACHE_PATH", target):
            for description in ("Old description", "New description"):
                cache_sticker_description("uid_1", description)
            entry = get_cached_description("uid_1")
        assert entry["description"] == "New description"

    def test_multiple_stickers(self, tmp_path):
        """Entries for different ids are stored independently."""
        target = tmp_path / "cache.json"
        with patch("gateway.sticker_cache.CACHE_PATH", target):
            cache_sticker_description("uid_1", "Cat")
            cache_sticker_description("uid_2", "Dog")
            first = get_cached_description("uid_1")
            second = get_cached_description("uid_2")
        assert first["description"] == "Cat"
        assert second["description"] == "Dog"
class TestBuildStickerInjection:
    """Exact text produced by ``build_sticker_injection``."""

    def test_exact_format_no_context(self):
        """Description only: no emoji or pack name appears."""
        expected = '[The user sent a sticker~ It shows: "A cat waving" (=^.w.^=)]'
        assert build_sticker_injection("A cat waving") == expected

    def test_exact_format_emoji_only(self):
        """The emoji is placed right after the word 'sticker'."""
        expected = '[The user sent a sticker 😀~ It shows: "A cat" (=^.w.^=)]'
        assert build_sticker_injection("A cat", emoji="😀") == expected

    def test_exact_format_emoji_and_set_name(self):
        """Emoji plus set name adds a 'from "<set>"' clause."""
        expected = '[The user sent a sticker 😀 from "MyPack"~ It shows: "A cat" (=^.w.^=)]'
        assert build_sticker_injection("A cat", emoji="😀", set_name="MyPack") == expected

    def test_set_name_without_emoji_ignored(self):
        """set_name alone (no emoji) produces no context — only emoji+set_name triggers 'from' clause."""
        text = build_sticker_injection("A cat", set_name="MyPack")
        assert text == '[The user sent a sticker~ It shows: "A cat" (=^.w.^=)]'
        assert "MyPack" not in text

    def test_description_with_quotes(self):
        """Double quotes inside the description are passed through unescaped."""
        text = build_sticker_injection('A "happy" dog')
        assert '"A \\"happy\\" dog"' not in text  # no escaping happens
        assert 'A "happy" dog' in text

    def test_empty_description(self):
        """An empty description yields an empty quoted string."""
        expected = '[The user sent a sticker~ It shows: "" (=^.w.^=)]'
        assert build_sticker_injection("") == expected
class TestBuildAnimatedStickerInjection:
    """Exact text produced by ``build_animated_sticker_injection``."""

    def test_exact_format_with_emoji(self):
        """With an emoji, the text names it twice: after 'sticker' and as the hint."""
        expected = (
            "[The user sent an animated sticker 🎉~ "
            "I can't see animated ones yet, but the emoji suggests: 🎉]"
        )
        assert build_animated_sticker_injection(emoji="🎉") == expected

    def test_exact_format_without_emoji(self):
        """Without an emoji, only the 'can't see' notice is produced."""
        expected = "[The user sent an animated sticker~ I can't see animated ones yet]"
        assert build_animated_sticker_injection() == expected

    def test_empty_emoji_same_as_no_emoji(self):
        """An empty emoji string behaves like omitting the argument."""
        with_empty = build_animated_sticker_injection(emoji="")
        assert with_empty == build_animated_sticker_injection()

View file

@ -0,0 +1,77 @@
"""Gateway STT config tests — honor stt.enabled: false from config.yaml."""
from pathlib import Path
from unittest.mock import AsyncMock, patch
import pytest
import yaml
from gateway.config import GatewayConfig, load_gateway_config
def test_gateway_config_stt_disabled_from_dict_nested():
    """A nested ``stt.enabled: False`` mapping turns ``stt_enabled`` off."""
    raw = {"stt": {"enabled": False}}
    parsed = GatewayConfig.from_dict(raw)
    assert parsed.stt_enabled is False
def test_load_gateway_config_bridges_stt_enabled_from_config_yaml(tmp_path, monkeypatch):
    """``stt.enabled: false`` in config.yaml reaches the loaded gateway config."""
    hermes_home = tmp_path / ".hermes"
    hermes_home.mkdir()
    config_file = hermes_home / "config.yaml"
    config_file.write_text(
        yaml.dump({"stt": {"enabled": False}}),
        encoding="utf-8",
    )

    # Point both HERMES_HOME and Path.home() at the temporary tree.
    monkeypatch.setenv("HERMES_HOME", str(hermes_home))
    monkeypatch.setattr(Path, "home", lambda: tmp_path)

    loaded = load_gateway_config()

    assert loaded.stt_enabled is False
@pytest.mark.asyncio
async def test_enrich_message_with_transcription_skips_when_stt_disabled():
    """With STT disabled, transcription is never attempted and the caption is kept."""
    from gateway.run import GatewayRunner

    runner = GatewayRunner.__new__(GatewayRunner)
    runner.config = GatewayConfig(stt_enabled=False)

    # Any call to transcribe_audio fails the test immediately.
    forbid_transcribe = patch(
        "tools.transcription_tools.transcribe_audio",
        side_effect=AssertionError("transcribe_audio should not be called when STT is disabled"),
    )
    no_stt_model = patch(
        "tools.transcription_tools.get_stt_model_from_config",
        return_value=None,
    )
    with forbid_transcribe, no_stt_model:
        enriched = await runner._enrich_message_with_transcription(
            "caption",
            ["/tmp/voice.ogg"],
        )

    assert "transcription is disabled" in enriched.lower()
    assert "caption" in enriched
@pytest.mark.asyncio
async def test_enrich_message_with_transcription_avoids_bogus_no_provider_message_for_backend_key_errors():
    """A backend key error is reported as a transcription problem, not as 'no provider'."""
    from gateway.run import GatewayRunner

    runner = GatewayRunner.__new__(GatewayRunner)
    runner.config = GatewayConfig(stt_enabled=True)

    failing_transcribe = patch(
        "tools.transcription_tools.transcribe_audio",
        return_value={"success": False, "error": "VOICE_TOOLS_OPENAI_KEY not set"},
    )
    no_stt_model = patch(
        "tools.transcription_tools.get_stt_model_from_config",
        return_value=None,
    )
    with failing_transcribe, no_stt_model:
        enriched = await runner._enrich_message_with_transcription(
            "caption",
            ["/tmp/voice.ogg"],
        )

    assert "No STT provider is configured" not in enriched
    assert "trouble transcribing" in enriched
    assert "caption" in enriched

View file

@ -0,0 +1,241 @@
import asyncio
import sys
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
import pytest
from gateway.config import PlatformConfig
def _ensure_telegram_mock():
    """Install a mock ``telegram`` package so TelegramAdapter can be imported.

    Does nothing when ``sys.modules["telegram"]`` already has a ``__file__``
    attribute, i.e. the real library has been imported. A MagicMock never
    exposes ``__file__``, so a previously installed mock does not satisfy
    this check; the ``setdefault`` calls below then keep the first mock.
    """
    if "telegram" in sys.modules and hasattr(sys.modules["telegram"], "__file__"):
        return

    telegram_mod = MagicMock()
    telegram_mod.ext.ContextTypes.DEFAULT_TYPE = type(None)
    telegram_mod.constants.ParseMode.MARKDOWN_V2 = "MarkdownV2"
    telegram_mod.constants.ChatType.GROUP = "group"
    telegram_mod.constants.ChatType.SUPERGROUP = "supergroup"
    telegram_mod.constants.ChatType.CHANNEL = "channel"
    telegram_mod.constants.ChatType.PRIVATE = "private"

    # NOTE(review): the same mock object is registered under all three module
    # names, so ``from telegram.constants import ChatType`` resolves to
    # ``telegram_mod.ChatType`` rather than the ``telegram_mod.constants.ChatType``
    # configured above — confirm how the adapter imports these names.
    for name in ("telegram", "telegram.ext", "telegram.constants"):
        sys.modules.setdefault(name, telegram_mod)


_ensure_telegram_mock()
from gateway.platforms.telegram import TelegramAdapter # noqa: E402
@pytest.mark.asyncio
async def test_connect_rejects_same_host_token_lock(monkeypatch):
    """``connect`` fails with a token-lock fatal error when the lock is already held."""
    adapter = TelegramAdapter(PlatformConfig(enabled=True, token="secret-token"))

    def lock_held_elsewhere(scope, identity, metadata=None):
        # Report the lock as owned by another process (pid 4242).
        return (False, {"pid": 4242})

    monkeypatch.setattr("gateway.status.acquire_scoped_lock", lock_held_elsewhere)

    connected = await adapter.connect()

    assert connected is False
    assert adapter.fatal_error_code == "telegram_token_lock"
    assert adapter.has_fatal_error is True
    assert "already using this Telegram bot token" in adapter.fatal_error_message
@pytest.mark.asyncio
async def test_polling_conflict_retries_before_fatal(monkeypatch):
    """A single 409 should trigger a retry, not an immediate fatal error."""
    adapter = TelegramAdapter(PlatformConfig(enabled=True, token="***"))
    fatal_handler = AsyncMock()
    adapter.set_fatal_error_handler(fatal_handler)

    monkeypatch.setattr(
        "gateway.status.acquire_scoped_lock",
        lambda scope, identity, metadata=None: (True, None),
    )
    monkeypatch.setattr(
        "gateway.status.release_scoped_lock",
        lambda scope, identity: None,
    )

    captured = {}

    async def fake_start_polling(**kwargs):
        # Remember the callback the adapter registers with the updater.
        captured["error_callback"] = kwargs["error_callback"]

    updater = SimpleNamespace(
        start_polling=AsyncMock(side_effect=fake_start_polling),
        stop=AsyncMock(),
        running=True,
    )
    bot = SimpleNamespace(set_my_commands=AsyncMock())
    app = SimpleNamespace(
        bot=bot,
        updater=updater,
        add_handler=MagicMock(),
        initialize=AsyncMock(),
        start=AsyncMock(),
    )
    builder = MagicMock()
    builder.token.return_value = builder
    builder.build.return_value = app
    monkeypatch.setattr("gateway.platforms.telegram.Application", SimpleNamespace(builder=MagicMock(return_value=builder)))

    # Speed up retries for testing. Keep a handle on the real coroutine first:
    # awaiting the AsyncMock replacement returns immediately without
    # suspending, so it cannot be used to yield control to the event loop.
    real_sleep = asyncio.sleep
    monkeypatch.setattr("asyncio.sleep", AsyncMock())

    ok = await adapter.connect()
    assert ok is True
    assert callable(captured["error_callback"])

    conflict = type("Conflict", (Exception,), {})

    # First conflict: should retry, NOT be fatal
    captured["error_callback"](conflict("Conflict: terminated by other getUpdates request"))

    # Give the scheduled task a chance to run (real yields to the loop).
    for _ in range(12):
        await real_sleep(0)

    assert adapter.has_fatal_error is False, "First conflict should not be fatal"
    assert adapter._polling_conflict_count == 0, "Count should reset after successful retry"
@pytest.mark.asyncio
async def test_polling_conflict_becomes_fatal_after_retries(monkeypatch):
    """After exhausting retries, the conflict should become fatal."""
    adapter = TelegramAdapter(PlatformConfig(enabled=True, token="***"))
    fatal_handler = AsyncMock()
    adapter.set_fatal_error_handler(fatal_handler)

    monkeypatch.setattr(
        "gateway.status.acquire_scoped_lock",
        lambda scope, identity, metadata=None: (True, None),
    )
    monkeypatch.setattr(
        "gateway.status.release_scoped_lock",
        lambda scope, identity: None,
    )

    captured = {}

    # Make start_polling fail on retries to exhaust retries
    call_count = {"n": 0}

    async def failing_start_polling(**kwargs):
        call_count["n"] += 1
        if call_count["n"] == 1:
            # First call (initial connect) succeeds
            captured["error_callback"] = kwargs["error_callback"]
        else:
            # Retry calls fail
            raise Exception("Connection refused")

    updater = SimpleNamespace(
        start_polling=AsyncMock(side_effect=failing_start_polling),
        stop=AsyncMock(),
        running=True,
    )
    bot = SimpleNamespace(set_my_commands=AsyncMock())
    app = SimpleNamespace(
        bot=bot,
        updater=updater,
        add_handler=MagicMock(),
        initialize=AsyncMock(),
        start=AsyncMock(),
    )
    builder = MagicMock()
    builder.token.return_value = builder
    builder.build.return_value = app
    monkeypatch.setattr("gateway.platforms.telegram.Application", SimpleNamespace(builder=MagicMock(return_value=builder)))

    # Speed up retries for testing
    monkeypatch.setattr("asyncio.sleep", AsyncMock())

    ok = await adapter.connect()
    assert ok is True

    conflict = type("Conflict", (Exception,), {})

    # Directly call _handle_polling_conflict to avoid event-loop scheduling
    # complexity. Each call simulates one 409 from Telegram.
    for _ in range(4):
        await adapter._handle_polling_conflict(
            conflict("Conflict: terminated by other getUpdates request")
        )

    # After 3 failed retries (count 1-3 each enter the retry branch but
    # start_polling raises), the 4th conflict pushes count to 4 which
    # exceeds MAX_CONFLICT_RETRIES (3), entering the fatal branch.
    assert adapter.fatal_error_code == "telegram_polling_conflict", (
        f"Expected fatal after 4 conflicts, got code={adapter.fatal_error_code}, "
        f"count={adapter._polling_conflict_count}"
    )
    assert adapter.has_fatal_error is True
    fatal_handler.assert_awaited_once()
@pytest.mark.asyncio
async def test_connect_marks_retryable_fatal_error_for_startup_network_failure(monkeypatch):
    """A failure in ``Application.initialize`` yields a retryable connect error."""
    adapter = TelegramAdapter(PlatformConfig(enabled=True, token="***"))

    # The scoped lock is always granted and released silently.
    monkeypatch.setattr(
        "gateway.status.acquire_scoped_lock",
        lambda scope, identity, metadata=None: (True, None),
    )
    monkeypatch.setattr(
        "gateway.status.release_scoped_lock",
        lambda scope, identity: None,
    )

    # Application whose initialize() raises a DNS-style error.
    failing_init = AsyncMock(side_effect=RuntimeError("Temporary failure in name resolution"))
    app = SimpleNamespace(
        bot=SimpleNamespace(),
        updater=SimpleNamespace(),
        add_handler=MagicMock(),
        initialize=failing_init,
        start=AsyncMock(),
    )
    builder = MagicMock()
    builder.token.return_value = builder
    builder.build.return_value = app
    monkeypatch.setattr("gateway.platforms.telegram.Application", SimpleNamespace(builder=MagicMock(return_value=builder)))

    connected = await adapter.connect()

    assert connected is False
    assert adapter.fatal_error_code == "telegram_connect_error"
    assert adapter.fatal_error_retryable is True
    assert "Temporary failure in name resolution" in adapter.fatal_error_message
@pytest.mark.asyncio
async def test_disconnect_skips_inactive_updater_and_app(monkeypatch):
    """``disconnect`` only shuts down when neither updater nor app is running."""
    adapter = TelegramAdapter(PlatformConfig(enabled=True, token="***"))

    idle_updater = SimpleNamespace(running=False, stop=AsyncMock())
    idle_app = SimpleNamespace(
        updater=idle_updater,
        running=False,
        stop=AsyncMock(),
        shutdown=AsyncMock(),
    )
    adapter._app = idle_app

    # Capture warnings so the test can assert none were logged.
    warning_spy = MagicMock()
    monkeypatch.setattr("gateway.platforms.telegram.logger.warning", warning_spy)

    await adapter.disconnect()

    idle_updater.stop.assert_not_awaited()
    idle_app.stop.assert_not_awaited()
    idle_app.shutdown.assert_awaited_once()
    warning_spy.assert_not_called()

View file

@ -0,0 +1,694 @@
"""
Tests for Telegram document handling in gateway/platforms/telegram.py.
Covers: document type detection, download/cache flow, size limits,
text injection, error handling.
Note: python-telegram-bot may not be installed in the test environment.
We mock the telegram module at import time to avoid collection errors.
"""
import asyncio
import importlib
import os
import sys
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from gateway.config import Platform, PlatformConfig
from gateway.platforms.base import (
MessageEvent,
MessageType,
SendResult,
SUPPORTED_DOCUMENT_TYPES,
)
# ---------------------------------------------------------------------------
# Mock the telegram package if it's not installed
# ---------------------------------------------------------------------------
def _ensure_telegram_mock():
    """Install mock telegram modules so TelegramAdapter can be imported.

    Does nothing when ``sys.modules["telegram"]`` already has a ``__file__``
    attribute, i.e. the real library has been imported. A MagicMock never
    exposes ``__file__``, so a previously installed mock does not satisfy
    this check; the ``setdefault`` calls below then keep the first mock.
    """
    if "telegram" in sys.modules and hasattr(sys.modules["telegram"], "__file__"):
        # Real library is installed — no mocking needed
        return

    telegram_mod = MagicMock()
    # ContextTypes needs DEFAULT_TYPE as an actual attribute for the annotation
    telegram_mod.ext.ContextTypes.DEFAULT_TYPE = type(None)
    telegram_mod.constants.ParseMode.MARKDOWN_V2 = "MarkdownV2"
    telegram_mod.constants.ChatType.GROUP = "group"
    telegram_mod.constants.ChatType.SUPERGROUP = "supergroup"
    telegram_mod.constants.ChatType.CHANNEL = "channel"
    telegram_mod.constants.ChatType.PRIVATE = "private"

    # NOTE(review): the same mock object is registered under all three module
    # names, so ``from telegram.constants import ChatType`` resolves to
    # ``telegram_mod.ChatType`` rather than the ``telegram_mod.constants.ChatType``
    # configured above — confirm how the adapter imports these names.
    for name in ("telegram", "telegram.ext", "telegram.constants"):
        sys.modules.setdefault(name, telegram_mod)


_ensure_telegram_mock()
# Now we can safely import
from gateway.platforms.telegram import TelegramAdapter # noqa: E402
# ---------------------------------------------------------------------------
# Helpers to build mock Telegram objects
# ---------------------------------------------------------------------------
def _make_file_obj(data: bytes = b"hello") -> AsyncMock:
    """Create a mock Telegram File with download_as_bytearray.

    Args:
        data: Bytes that awaiting ``download_as_bytearray()`` returns,
            wrapped in a ``bytearray``.

    Returns:
        An AsyncMock whose ``file_path`` is ``"documents/file.pdf"``.
    """
    f = AsyncMock()
    f.download_as_bytearray = AsyncMock(return_value=bytearray(data))
    f.file_path = "documents/file.pdf"
    return f
def _make_document(
    file_name="report.pdf",
    mime_type="application/pdf",
    file_size=1024,
    file_obj=None,
) -> MagicMock:
    """Create a mock Telegram Document object.

    Args:
        file_name: Value for ``doc.file_name`` (may be ``None``).
        mime_type: Value for ``doc.mime_type``.
        file_size: Value for ``doc.file_size`` (may be ``None``).
        file_obj: Object returned by awaiting ``doc.get_file()``. When
            ``None``, a default mock file from ``_make_file_obj()`` is used.

    Returns:
        A MagicMock carrying the attributes above and an async ``get_file``.
    """
    doc = MagicMock()
    doc.file_name = file_name
    doc.mime_type = mime_type
    doc.file_size = file_size
    # Only fall back to the default when no object was supplied; an explicit
    # (even falsy) file object is passed through unchanged.
    resolved_file = file_obj if file_obj is not None else _make_file_obj()
    doc.get_file = AsyncMock(return_value=resolved_file)
    return doc
def _make_message(document=None, caption=None, media_group_id=None, photo=None) -> MagicMock:
    """Build a mock Telegram Message with the given document/photo.

    Args:
        document: Value for ``msg.document``.
        caption: Value for ``msg.caption``; also used for ``msg.text``
            (empty string when no caption is given).
        media_group_id: Value for ``msg.media_group_id``.
        photo: Value for ``msg.photo``.

    Returns:
        A MagicMock with message id 42, a private chat (id 100) and a
        sender (id 1) both named "Test User".
    """
    msg = MagicMock()
    msg.message_id = 42
    msg.text = caption or ""
    msg.caption = caption
    msg.date = None

    # Media flags — all None except explicit payload
    msg.photo = photo
    msg.video = None
    msg.audio = None
    msg.voice = None
    msg.sticker = None
    msg.document = document
    msg.media_group_id = media_group_id

    # Chat / user
    msg.chat = MagicMock()
    msg.chat.id = 100
    msg.chat.type = "private"
    msg.chat.title = None
    msg.chat.full_name = "Test User"
    msg.from_user = MagicMock()
    msg.from_user.id = 1
    msg.from_user.full_name = "Test User"
    msg.message_thread_id = None
    return msg
def _make_update(msg) -> MagicMock:
    """Wrap a message in a mock Update.

    Returns:
        A MagicMock whose ``message`` attribute is *msg*.
    """
    update = MagicMock()
    update.message = msg
    return update
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture()
def adapter():
    """Provide a TelegramAdapter whose ``handle_message`` records events instead of processing them."""
    telegram_adapter = TelegramAdapter(PlatformConfig(enabled=True, token="fake-token"))
    # Swap in a mock so tests can inspect the emitted event via call_args.
    telegram_adapter.handle_message = AsyncMock()
    return telegram_adapter
@pytest.fixture(autouse=True)
def _redirect_cache(tmp_path, monkeypatch):
    """Point document cache to tmp_path so tests don't touch ~/.hermes."""
    cache_dir = tmp_path / "doc_cache"
    monkeypatch.setattr("gateway.platforms.base.DOCUMENT_CACHE_DIR", cache_dir)
# ---------------------------------------------------------------------------
# TestDocumentTypeDetection
# ---------------------------------------------------------------------------
class TestDocumentTypeDetection:
    """Message type assigned by ``_handle_media_message``."""

    @pytest.mark.asyncio
    async def test_document_detected_explicitly(self, adapter):
        """A message carrying a document is classified as DOCUMENT."""
        update = _make_update(_make_message(document=_make_document()))

        await adapter._handle_media_message(update, MagicMock())

        # First positional argument of the recorded handle_message call.
        event = adapter.handle_message.call_args.args[0]
        assert event.message_type == MessageType.DOCUMENT

    @pytest.mark.asyncio
    async def test_fallback_is_document(self, adapter):
        """When no specific media attr is set, message_type defaults to DOCUMENT."""
        msg = _make_message()
        msg.document = None  # no media at all
        update = _make_update(msg)

        await adapter._handle_media_message(update, MagicMock())

        event = adapter.handle_message.call_args.args[0]
        assert event.message_type == MessageType.DOCUMENT
# ---------------------------------------------------------------------------
# TestDocumentDownloadBlock
# ---------------------------------------------------------------------------
def _make_photo(file_obj=None) -> MagicMock:
    """Create a mock Telegram photo object with an async ``get_file``.

    Args:
        file_obj: Object returned by awaiting ``photo.get_file()``. When
            ``None``, a mock file holding ``b"photo-bytes"`` is used.
    """
    photo = MagicMock()
    # Only fall back to the default when no object was supplied.
    resolved_file = file_obj if file_obj is not None else _make_file_obj(b"photo-bytes")
    photo.get_file = AsyncMock(return_value=resolved_file)
    return photo
class TestDocumentDownloadBlock:
    """Inbound documents: caching, text-content injection, and rejection paths."""

    @staticmethod
    async def _dispatch(adapter, document, **message_kwargs):
        """Feed *document* to ``_handle_media_message`` and return the forwarded event.

        The event is the first positional argument of the most recent
        ``adapter.handle_message`` call.
        """
        update = _make_update(_make_message(document=document, **message_kwargs))
        await adapter._handle_media_message(update, MagicMock())
        return adapter.handle_message.call_args[0][0]

    @pytest.mark.asyncio
    async def test_supported_pdf_is_cached(self, adapter):
        """A PDF yields exactly one existing cached path typed ``application/pdf``."""
        document = _make_document(
            file_name="report.pdf",
            file_size=1024,
            file_obj=_make_file_obj(b"%PDF-1.4 fake"),
        )
        event = await self._dispatch(adapter, document)
        assert len(event.media_urls) == 1
        assert os.path.exists(event.media_urls[0])
        assert event.media_types == ["application/pdf"]

    @pytest.mark.asyncio
    async def test_supported_txt_injects_content(self, adapter):
        """The body of a .txt file appears in the event text under a content header."""
        payload = b"Hello from a text file"
        document = _make_document(
            file_name="notes.txt",
            mime_type="text/plain",
            file_size=len(payload),
            file_obj=_make_file_obj(payload),
        )
        event = await self._dispatch(adapter, document)
        for fragment in ("Hello from a text file", "[Content of notes.txt]"):
            assert fragment in event.text

    @pytest.mark.asyncio
    async def test_supported_md_injects_content(self, adapter):
        """Markdown file content is injected into the event text."""
        payload = b"# Title\nSome markdown"
        document = _make_document(
            file_name="readme.md",
            mime_type="text/markdown",
            file_size=len(payload),
            file_obj=_make_file_obj(payload),
        )
        event = await self._dispatch(adapter, document)
        assert "# Title" in event.text

    @pytest.mark.asyncio
    async def test_caption_preserved_with_injection(self, adapter):
        """Both the user's caption and the injected file text end up in the event."""
        payload = b"file text"
        document = _make_document(
            file_name="doc.txt",
            mime_type="text/plain",
            file_size=len(payload),
            file_obj=_make_file_obj(payload),
        )
        event = await self._dispatch(adapter, document, caption="Please summarize")
        for fragment in ("file text", "Please summarize"):
            assert fragment in event.text

    @pytest.mark.asyncio
    async def test_unsupported_type_rejected(self, adapter):
        """A .zip document is reported as an unsupported type, naming the extension."""
        document = _make_document(
            file_name="archive.zip", mime_type="application/zip", file_size=100
        )
        event = await self._dispatch(adapter, document)
        assert "Unsupported document type" in event.text
        assert ".zip" in event.text

    @pytest.mark.asyncio
    async def test_oversized_file_rejected(self, adapter):
        """A 25 MiB document is reported as too large."""
        document = _make_document(file_name="huge.pdf", file_size=25 * 1024 * 1024)
        event = await self._dispatch(adapter, document)
        assert "too large" in event.text

    @pytest.mark.asyncio
    async def test_none_file_size_rejected(self, adapter):
        """Security fix: a document with file_size=None must be rejected, not let through."""
        document = _make_document(file_name="tricky.pdf", file_size=None)
        event = await self._dispatch(adapter, document)
        rejection_text = event.text
        assert "too large" in rejection_text or "could not be verified" in rejection_text

    @pytest.mark.asyncio
    async def test_missing_filename_uses_mime_lookup(self, adapter):
        """With no file_name, a valid mime_type alone is enough to accept the file."""
        payload = b"some pdf bytes"
        document = _make_document(
            file_name=None,
            mime_type="application/pdf",
            file_size=len(payload),
            file_obj=_make_file_obj(payload),
        )
        event = await self._dispatch(adapter, document)
        assert len(event.media_urls) == 1
        assert event.media_types == ["application/pdf"]

    @pytest.mark.asyncio
    async def test_missing_filename_and_mime_rejected(self, adapter):
        """With neither file_name nor mime_type the document is rejected as unsupported."""
        document = _make_document(file_name=None, mime_type=None, file_size=100)
        event = await self._dispatch(adapter, document)
        assert "Unsupported" in event.text

    @pytest.mark.asyncio
    async def test_unicode_decode_error_handled(self, adapter):
        """Non-UTF-8 bytes in a .txt: the file is cached but no content is injected."""
        undecodable = bytes(range(128, 256))  # not valid UTF-8
        document = _make_document(
            file_name="binary.txt",
            mime_type="text/plain",
            file_size=len(undecodable),
            file_obj=_make_file_obj(undecodable),
        )
        event = await self._dispatch(adapter, document)
        # The download itself must still have been cached on disk.
        assert len(event.media_urls) == 1
        assert os.path.exists(event.media_urls[0])
        # No caption was set, so text may be empty/None; either way no header.
        assert "[Content of" not in (event.text or "")

    @pytest.mark.asyncio
    async def test_text_injection_capped(self, adapter):
        """A .txt file over 100 KB should NOT have its content injected."""
        oversized_text = b"x" * (200 * 1024)  # 200 KB
        document = _make_document(
            file_name="big.txt",
            mime_type="text/plain",
            file_size=len(oversized_text),
            file_obj=_make_file_obj(oversized_text),
        )
        event = await self._dispatch(adapter, document)
        # Cached as an attachment...
        assert len(event.media_urls) == 1
        # ...but not inlined into the message text.
        assert "[Content of" not in (event.text or "")

    @pytest.mark.asyncio
    async def test_download_exception_handled(self, adapter):
        """If get_file() raises, the handler must not crash and still forwards an event."""
        document = _make_document(file_name="crash.pdf", file_size=100)
        document.get_file = AsyncMock(side_effect=RuntimeError("Telegram API down"))
        update = _make_update(_make_message(document=document))
        # Must complete without propagating the RuntimeError.
        await adapter._handle_media_message(update, MagicMock())
        # The handler is expected to catch the failure and still call handle_message.
        adapter.handle_message.assert_called_once()
# ---------------------------------------------------------------------------
# TestMediaGroups — media group (album) buffering
# ---------------------------------------------------------------------------
class TestMediaGroups:
    """Buffering of photo bursts and albums into a single combined event."""

    @pytest.mark.asyncio
    async def test_non_album_photo_burst_is_buffered_and_combined(self, adapter):
        """Two photos without a media_group_id are delivered as one event after the wait."""
        photo_a = _make_photo(_make_file_obj(b"first"))
        photo_b = _make_photo(_make_file_obj(b"second"))
        captioned = _make_message(caption="two images", photo=[photo_a])
        uncaptioned = _make_message(photo=[photo_b])
        cache_patch = patch(
            "gateway.platforms.telegram.cache_image_from_bytes",
            side_effect=["/tmp/burst-one.jpg", "/tmp/burst-two.jpg"],
        )
        with cache_patch:
            for message in (captioned, uncaptioned):
                await adapter._handle_media_message(_make_update(message), MagicMock())
            # Nothing is forwarded until the buffering window elapses.
            assert adapter.handle_message.await_count == 0
            await asyncio.sleep(adapter.MEDIA_GROUP_WAIT_SECONDS + 0.05)
        adapter.handle_message.assert_awaited_once()
        combined = adapter.handle_message.await_args.args[0]
        assert combined.text == "two images"
        assert combined.media_urls == ["/tmp/burst-one.jpg", "/tmp/burst-two.jpg"]
        assert len(combined.media_types) == 2

    @pytest.mark.asyncio
    async def test_photo_album_is_buffered_and_combined(self, adapter):
        """Two photos sharing a media_group_id are delivered as one event after the wait."""
        photo_a = _make_photo(_make_file_obj(b"first"))
        photo_b = _make_photo(_make_file_obj(b"second"))
        captioned = _make_message(
            caption="two images", media_group_id="album-1", photo=[photo_a]
        )
        uncaptioned = _make_message(media_group_id="album-1", photo=[photo_b])
        cache_patch = patch(
            "gateway.platforms.telegram.cache_image_from_bytes",
            side_effect=["/tmp/one.jpg", "/tmp/two.jpg"],
        )
        with cache_patch:
            for message in (captioned, uncaptioned):
                await adapter._handle_media_message(_make_update(message), MagicMock())
            # Nothing is forwarded until the buffering window elapses.
            assert adapter.handle_message.await_count == 0
            await asyncio.sleep(adapter.MEDIA_GROUP_WAIT_SECONDS + 0.05)
        adapter.handle_message.assert_awaited_once()
        combined = adapter.handle_message.call_args[0][0]
        assert combined.text == "two images"
        assert combined.media_urls == ["/tmp/one.jpg", "/tmp/two.jpg"]
        assert len(combined.media_types) == 2

    @pytest.mark.asyncio
    async def test_disconnect_cancels_pending_media_group_flush(self, adapter):
        """disconnect() clears buffered album state and no event is ever forwarded."""
        lone_photo = _make_photo(_make_file_obj(b"first"))
        album_message = _make_message(
            caption="two images", media_group_id="album-2", photo=[lone_photo]
        )
        cache_patch = patch(
            "gateway.platforms.telegram.cache_image_from_bytes",
            return_value="/tmp/one.jpg",
        )
        with cache_patch:
            await adapter._handle_media_message(_make_update(album_message), MagicMock())
        # The album is buffered and a flush task is registered for it.
        assert "album-2" in adapter._media_group_events
        assert "album-2" in adapter._media_group_tasks
        await adapter.disconnect()
        # Wait past the flush window to show the flush never fires.
        await asyncio.sleep(adapter.MEDIA_GROUP_WAIT_SECONDS + 0.05)
        assert adapter._media_group_events == {}
        assert adapter._media_group_tasks == {}
        adapter.handle_message.assert_not_awaited()
# ---------------------------------------------------------------------------
# TestSendDocument — outbound file attachment delivery
# ---------------------------------------------------------------------------
class TestSendDocument:
    """Tests for TelegramAdapter.send_document() — sending files to users."""

    @pytest.fixture()
    def connected_adapter(self, adapter):
        """Return the adapter with an AsyncMock installed as its ``_bot``."""
        adapter._bot = AsyncMock()
        return adapter

    @staticmethod
    def _stub_bot_reply(adapter, message_id):
        """Make ``_bot.send_document`` an AsyncMock returning a message with *message_id*."""
        reply = MagicMock()
        reply.message_id = message_id
        adapter._bot.send_document = AsyncMock(return_value=reply)

    @pytest.mark.asyncio
    async def test_send_document_success(self, connected_adapter, tmp_path):
        """A local file is sent via bot.send_document and the result reports success."""
        # A real file on disk is required.
        pdf_path = tmp_path / "report.pdf"
        pdf_path.write_bytes(b"%PDF-1.4 fake content")
        self._stub_bot_reply(connected_adapter, 99)
        outcome = await connected_adapter.send_document(
            chat_id="12345",
            file_path=str(pdf_path),
            caption="Here's the report",
        )
        assert outcome.success is True
        assert outcome.message_id == "99"
        bot_send = connected_adapter._bot.send_document
        bot_send.assert_called_once()
        sent_kwargs = bot_send.call_args[1]
        assert sent_kwargs["chat_id"] == 12345
        assert sent_kwargs["filename"] == "report.pdf"
        assert sent_kwargs["caption"] == "Here's the report"

    @pytest.mark.asyncio
    async def test_send_document_custom_filename(self, connected_adapter, tmp_path):
        """An explicit file_name overrides the on-disk basename in the bot call."""
        csv_path = tmp_path / "doc_abc123_ugly.csv"
        csv_path.write_bytes(b"a,b,c\n1,2,3")
        self._stub_bot_reply(connected_adapter, 100)
        outcome = await connected_adapter.send_document(
            chat_id="12345",
            file_path=str(csv_path),
            file_name="clean_data.csv",
        )
        assert outcome.success is True
        sent_kwargs = connected_adapter._bot.send_document.call_args[1]
        assert sent_kwargs["filename"] == "clean_data.csv"

    @pytest.mark.asyncio
    async def test_send_document_file_not_found(self, connected_adapter):
        """A missing file yields an error result and the bot is never called."""
        outcome = await connected_adapter.send_document(
            chat_id="12345",
            file_path="/nonexistent/file.pdf",
        )
        assert outcome.success is False
        assert "not found" in outcome.error.lower()
        connected_adapter._bot.send_document.assert_not_called()

    @pytest.mark.asyncio
    async def test_send_document_not_connected(self, adapter):
        """Without a bot attached, the result carries a 'Not connected' error."""
        outcome = await adapter.send_document(
            chat_id="12345",
            file_path="/some/file.pdf",
        )
        assert outcome.success is False
        assert "Not connected" in outcome.error

    @pytest.mark.asyncio
    async def test_send_document_caption_truncated(self, connected_adapter, tmp_path):
        """A 2000-character caption reaches the bot cut down to 1024 characters."""
        json_path = tmp_path / "data.json"
        json_path.write_bytes(b"{}")
        self._stub_bot_reply(connected_adapter, 101)
        await connected_adapter.send_document(
            chat_id="12345",
            file_path=str(json_path),
            caption="x" * 2000,
        )
        sent_kwargs = connected_adapter._bot.send_document.call_args[1]
        assert len(sent_kwargs["caption"]) == 1024

    @pytest.mark.asyncio
    async def test_send_document_api_error_falls_back(self, connected_adapter, tmp_path):
        """When bot.send_document raises, the result comes from the ``send`` fallback."""
        pdf_path = tmp_path / "file.pdf"
        pdf_path.write_bytes(b"data")
        connected_adapter._bot.send_document = AsyncMock(
            side_effect=RuntimeError("Telegram API error")
        )
        # The fallback path goes through self.send(); stub it so the failure
        # does not cascade into further bot calls.
        fallback_result = SendResult(success=True, message_id="fallback")
        connected_adapter.send = AsyncMock(return_value=fallback_result)
        outcome = await connected_adapter.send_document(
            chat_id="12345",
            file_path=str(pdf_path),
        )
        # The returned result is the stubbed fallback's.
        assert outcome.success is True
        assert outcome.message_id == "fallback"

    @pytest.mark.asyncio
    async def test_send_document_reply_to(self, connected_adapter, tmp_path):
        """reply_to is forwarded to the bot as an integer reply_to_message_id."""
        md_path = tmp_path / "spec.md"
        md_path.write_bytes(b"# Spec")
        self._stub_bot_reply(connected_adapter, 102)
        await connected_adapter.send_document(
            chat_id="12345",
            file_path=str(md_path),
            reply_to="50",
        )
        sent_kwargs = connected_adapter._bot.send_document.call_args[1]
        assert sent_kwargs["reply_to_message_id"] == 50

    @pytest.mark.asyncio
    async def test_send_document_thread_id(self, connected_adapter, tmp_path):
        """metadata thread_id is forwarded as message_thread_id (required for Telegram forum groups)."""
        pdf_path = tmp_path / "report.pdf"
        pdf_path.write_bytes(b"%PDF-1.4 data")
        self._stub_bot_reply(connected_adapter, 103)
        await connected_adapter.send_document(
            chat_id="12345",
            file_path=str(pdf_path),
            metadata={"thread_id": "789"},
        )
        sent_kwargs = connected_adapter._bot.send_document.call_args[1]
        assert sent_kwargs["message_thread_id"] == 789
class TestTelegramPhotoBatching:
    """Lifecycle of pending photo-batch tasks on flush and disconnect."""

    @pytest.mark.asyncio
    async def test_flush_photo_batch_does_not_drop_newer_scheduled_task(self, adapter):
        """A flush running as an older task leaves the newer registered task in place."""
        stale_task = MagicMock()
        fresh_task = MagicMock()
        key = "session:photo-burst"
        adapter._pending_photo_batch_tasks[key] = fresh_task
        adapter._pending_photo_batches[key] = MessageEvent(
            text="",
            message_type=MessageType.PHOTO,
            source=SimpleNamespace(channel_id="chat-1"),
            media_urls=["/tmp/a.jpg"],
            media_types=["image/jpeg"],
        )
        # The flush sees itself as the stale task, and sleep is stubbed out.
        current_task_patch = patch(
            "gateway.platforms.telegram.asyncio.current_task", return_value=stale_task
        )
        sleep_patch = patch("gateway.platforms.telegram.asyncio.sleep", new=AsyncMock())
        with current_task_patch, sleep_patch:
            await adapter._flush_photo_batch(key)
        assert adapter._pending_photo_batch_tasks[key] is fresh_task

    @pytest.mark.asyncio
    async def test_disconnect_cancels_pending_photo_batch_tasks(self, adapter):
        """disconnect() cancels an unfinished batch task and empties both pending maps."""
        pending_task = MagicMock()
        pending_task.done.return_value = False
        adapter._pending_photo_batch_tasks["session:photo-burst"] = pending_task
        adapter._pending_photo_batches["session:photo-burst"] = MessageEvent(
            text="",
            message_type=MessageType.PHOTO,
            source=SimpleNamespace(channel_id="chat-1"),
        )
        # Give disconnect() an application object whose shutdown calls are awaitable.
        app = MagicMock()
        adapter._app = app
        app.updater.stop = AsyncMock()
        app.stop = AsyncMock()
        app.shutdown = AsyncMock()
        await adapter.disconnect()
        pending_task.cancel.assert_called_once()
        assert adapter._pending_photo_batch_tasks == {}
        assert adapter._pending_photo_batches == {}
# ---------------------------------------------------------------------------
# TestSendVideo — outbound video delivery
# ---------------------------------------------------------------------------
class TestSendVideo:
    """Tests for TelegramAdapter.send_video() — sending videos to users."""

    @pytest.fixture()
    def connected_adapter(self, adapter):
        """Return the adapter with an AsyncMock installed as its ``_bot``."""
        adapter._bot = AsyncMock()
        return adapter

    @staticmethod
    def _write_clip(directory):
        """Create ``clip.mp4`` with the test's fixed byte payload in *directory* and return its path."""
        clip_path = directory / "clip.mp4"
        clip_path.write_bytes(b"\x00\x00\x00\x1c" + b"ftyp" + b"\x00" * 100)
        return clip_path

    @staticmethod
    def _stub_bot_reply(adapter, message_id):
        """Make ``_bot.send_video`` an AsyncMock returning a message with *message_id*."""
        reply = MagicMock()
        reply.message_id = message_id
        adapter._bot.send_video = AsyncMock(return_value=reply)

    @pytest.mark.asyncio
    async def test_send_video_success(self, connected_adapter, tmp_path):
        """A local video is sent through the bot once and the result reports success."""
        clip_path = self._write_clip(tmp_path)
        self._stub_bot_reply(connected_adapter, 200)
        outcome = await connected_adapter.send_video(
            chat_id="12345",
            video_path=str(clip_path),
            caption="Check this out",
        )
        assert outcome.success is True
        assert outcome.message_id == "200"
        connected_adapter._bot.send_video.assert_called_once()

    @pytest.mark.asyncio
    async def test_send_video_file_not_found(self, connected_adapter):
        """A missing video path yields a 'not found' error result."""
        outcome = await connected_adapter.send_video(
            chat_id="12345",
            video_path="/nonexistent/video.mp4",
        )
        assert outcome.success is False
        assert "not found" in outcome.error.lower()

    @pytest.mark.asyncio
    async def test_send_video_not_connected(self, adapter):
        """Without a bot attached, the result carries a 'Not connected' error."""
        outcome = await adapter.send_video(
            chat_id="12345",
            video_path="/some/video.mp4",
        )
        assert outcome.success is False
        assert "Not connected" in outcome.error

    @pytest.mark.asyncio
    async def test_send_video_thread_id(self, connected_adapter, tmp_path):
        """metadata thread_id is forwarded as message_thread_id (required for Telegram forum groups)."""
        clip_path = self._write_clip(tmp_path)
        self._stub_bot_reply(connected_adapter, 201)
        await connected_adapter.send_video(
            chat_id="12345",
            video_path=str(clip_path),
            metadata={"thread_id": "789"},
        )
        sent_kwargs = connected_adapter._bot.send_video.call_args[1]
        assert sent_kwargs["message_thread_id"] == 789

View file

@ -0,0 +1,538 @@
"""Tests for Telegram MarkdownV2 formatting in gateway/platforms/telegram.py.
Covers: _escape_mdv2 (pure function), format_message (markdown-to-MarkdownV2
conversion pipeline), and edge cases that could produce invalid MarkdownV2
or corrupt user-visible content.
"""
import re
import sys
from unittest.mock import AsyncMock, MagicMock
import pytest
from gateway.config import PlatformConfig
# ---------------------------------------------------------------------------
# Mock the telegram package if it's not installed
# ---------------------------------------------------------------------------
def _ensure_telegram_mock():
    """Register a MagicMock stand-in for the ``telegram`` package when needed.

    Does nothing if ``sys.modules["telegram"]`` already exists and has a
    ``__file__`` attribute (used here as the sign of a real package being
    loaded). Otherwise one shared mock is registered, via ``setdefault``,
    under ``telegram``, ``telegram.ext`` and ``telegram.constants``.
    """
    loaded = sys.modules.get("telegram")
    if loaded is not None and hasattr(loaded, "__file__"):
        return

    fake = MagicMock()
    fake.ext.ContextTypes.DEFAULT_TYPE = type(None)
    fake.constants.ParseMode.MARKDOWN_V2 = "MarkdownV2"
    # Chat-type constants are plain strings rather than auto-generated mocks.
    for attr, value in (
        ("GROUP", "group"),
        ("SUPERGROUP", "supergroup"),
        ("CHANNEL", "channel"),
        ("PRIVATE", "private"),
    ):
        setattr(fake.constants.ChatType, attr, value)

    # setdefault keeps any entry that is already registered.
    for module_name in ("telegram", "telegram.ext", "telegram.constants"):
        sys.modules.setdefault(module_name, fake)
_ensure_telegram_mock()
from gateway.platforms.telegram import TelegramAdapter, _escape_mdv2, _strip_mdv2 # noqa: E402
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture()
def adapter():
    """Provide a TelegramAdapter built from an enabled PlatformConfig with a fake token."""
    return TelegramAdapter(PlatformConfig(enabled=True, token="fake-token"))
# =========================================================================
# _escape_mdv2
# =========================================================================
class TestEscapeMdv2:
    """Unit tests for the ``_escape_mdv2`` helper."""

    def test_escapes_all_special_characters(self):
        """Each special character appears backslash-prefixed in the escaped output."""
        special = r'_*[]()~`>#+-=|{}.!\ '
        escaped = _escape_mdv2(special)
        # The space in the literal only keeps the raw string from ending in a
        # backslash; it is skipped rather than checked.
        for ch in special:
            if ch != ' ':
                assert f'\\{ch}' in escaped

    def test_empty_string(self):
        """The empty string is returned unchanged."""
        assert _escape_mdv2("") == ""

    def test_no_special_characters(self):
        """Text with only letters, digits and spaces is returned unchanged."""
        plain = "hello world 123"
        assert _escape_mdv2(plain) == "hello world 123"

    def test_backslash_escaped(self):
        """A single backslash is doubled."""
        assert _escape_mdv2("a\\b") == "a\\\\b"

    def test_dot_escaped(self):
        """A dot gets a backslash prefix."""
        assert _escape_mdv2("v2.0") == "v2\\.0"

    def test_exclamation_escaped(self):
        """An exclamation mark gets a backslash prefix."""
        assert _escape_mdv2("wow!") == "wow\\!"

    def test_mixed_text_and_specials(self):
        """Only the special characters in mixed text are escaped."""
        assert _escape_mdv2("Hello (world)!") == "Hello \\(world\\)\\!"
# =========================================================================
# format_message - basic conversions
# =========================================================================
class TestFormatMessageBasic:
    """``format_message`` on empty, None and plain-text input."""

    def test_empty_string(self, adapter):
        """An empty string comes back as an empty string."""
        assert adapter.format_message("") == ""

    def test_none_input(self, adapter):
        """None is falsy content and is handed back as-is."""
        assert adapter.format_message(None) is None

    def test_plain_text_specials_escaped(self, adapter):
        """Dots and exclamation marks in plain text are backslash-escaped."""
        formatted = adapter.format_message("Price is $5.00!")
        for escaped_char in ("\\.", "\\!"):
            assert escaped_char in formatted

    def test_plain_text_no_markdown(self, adapter):
        """Text without special characters passes through unchanged."""
        formatted = adapter.format_message("Hello world")
        assert formatted == "Hello world"
# =========================================================================
# format_message - code blocks
# =========================================================================
class TestFormatMessageCodeBlocks:
    """``format_message`` handling of fenced code blocks and inline code."""

    def test_fenced_code_block_preserved(self, adapter):
        """A fenced block survives verbatim; surrounding text is still present."""
        source = "Before\n```python\nprint('hello')\n```\nAfter"
        formatted = adapter.format_message(source)
        # The block's contents must not pick up escapes.
        assert "```python\nprint('hello')\n```" in formatted
        # "After" is plain text needing no escaping.
        assert "After" in formatted

    def test_inline_code_preserved(self, adapter):
        """Inline code keeps its underscore unescaped."""
        formatted = adapter.format_message("Use `my_var` here")
        # The code span's content must not pick up escapes.
        assert "`my_var`" in formatted
        # Surrounding plain text is still there.
        assert "Use" in formatted

    def test_code_block_special_chars_not_escaped(self, adapter):
        """Characters like >, ! and { stay unescaped inside a fenced block."""
        source = "```\nif (x > 0) { return !x; }\n```"
        formatted = adapter.format_message(source)
        assert "if (x > 0) { return !x; }" in formatted

    def test_inline_code_special_chars_not_escaped(self, adapter):
        """Special characters stay unescaped inside inline code."""
        formatted = adapter.format_message("Run `rm -rf ./*` carefully")
        assert "`rm -rf ./*`" in formatted

    def test_multiple_code_blocks(self, adapter):
        """Two fenced blocks and the text between them all survive."""
        source = "```\nblock1\n```\ntext\n```\nblock2\n```"
        formatted = adapter.format_message(source)
        # Both blocks plus the plain "text" separating them.
        for fragment in ("block1", "block2", "text"):
            assert fragment in formatted

    def test_inline_code_backslashes_escaped(self, adapter):
        """Backslashes in inline code must be escaped for MarkdownV2."""
        source = r"Check `C:\ProgramData\VMware\` path"
        formatted = adapter.format_message(source)
        assert r"`C:\\ProgramData\\VMware\\`" in formatted

    def test_fenced_code_block_backslashes_escaped(self, adapter):
        """Backslashes in fenced code blocks must be escaped for MarkdownV2."""
        source = "```\npath = r'C:\\Users\\test'\n```"
        formatted = adapter.format_message(source)
        assert r"C:\\Users\\test" in formatted

    def test_fenced_code_block_backticks_escaped(self, adapter):
        """Backticks inside fenced code blocks must be escaped for MarkdownV2."""
        source = "```\necho `hostname`\n```"
        formatted = adapter.format_message(source)
        assert r"echo \`hostname\`" in formatted

    def test_inline_code_no_double_escape(self, adapter):
        """Already-escaped backslashes should not be quadruple-escaped."""
        source = r"Use `\\server\share`"
        formatted = adapter.format_message(source)
        # Each input backslash is escaped exactly once: two become four, one becomes two.
        assert r"`\\\\server\\share`" in formatted
# =========================================================================
# format_message - bold and italic
# =========================================================================
class TestFormatMessageBoldItalic:
    """``format_message`` conversion of ``**bold**`` and ``*italic*`` markers."""

    def test_bold_converted(self, adapter):
        """Double-asterisk bold becomes single-asterisk MarkdownV2 bold."""
        formatted = adapter.format_message("This is **bold** text")
        assert "*bold*" in formatted
        # No double asterisks may remain.
        assert "**" not in formatted

    def test_italic_converted(self, adapter):
        """Single-asterisk italic becomes underscore MarkdownV2 italic."""
        formatted = adapter.format_message("This is *italic* text")
        assert "_italic_" in formatted

    def test_bold_with_special_chars(self, adapter):
        """Special characters inside bold are escaped."""
        formatted = adapter.format_message("**hello.world!**")
        assert "*hello\\.world\\!*" in formatted

    def test_italic_with_special_chars(self, adapter):
        """Special characters inside italic are escaped."""
        formatted = adapter.format_message("*hello.world*")
        assert "_hello\\.world_" in formatted

    def test_bold_and_italic_in_same_line(self, adapter):
        """Bold and italic on one line are both converted."""
        formatted = adapter.format_message("**bold** and *italic*")
        for marker in ("*bold*", "_italic_"):
            assert marker in formatted
# =========================================================================
# format_message - headers
# =========================================================================
class TestFormatMessageHeaders:
    """``format_message`` conversion of Markdown headers to bold."""

    def test_h1_converted_to_bold(self, adapter):
        """A level-1 header becomes bold and loses its hash."""
        formatted = adapter.format_message("# Title")
        assert "*Title*" in formatted
        assert "#" not in formatted

    def test_h2_converted(self, adapter):
        """A level-2 header becomes bold."""
        formatted = adapter.format_message("## Subtitle")
        assert "*Subtitle*" in formatted

    def test_header_with_inner_bold_stripped(self, adapter):
        """Redundant ``**...**`` inside a header collapses to a single bold pair."""
        formatted = adapter.format_message("## **Important**")
        # Expect *Important*, not ***Important***.
        assert "*Important*" in formatted
        # Exactly one opening and one closing asterisk.
        assert formatted.count("*") == 2

    def test_header_with_special_chars(self, adapter):
        """Parentheses and exclamation marks in a header are escaped."""
        formatted = adapter.format_message("# Hello (World)!")
        for escaped_char in ("\\(", "\\)", "\\!"):
            assert escaped_char in formatted

    def test_multiline_headers(self, adapter):
        """Headers on separate lines are each converted; body text is kept."""
        source = "# First\nSome text\n## Second"
        formatted = adapter.format_message(source)
        for fragment in ("*First*", "*Second*", "Some text"):
            assert fragment in formatted
# =========================================================================
# format_message - links
# =========================================================================
class TestFormatMessageLinks:
    """``format_message`` handling of ``[text](url)`` links."""

    def test_markdown_link_converted(self, adapter):
        """A simple link is present unchanged in the output."""
        formatted = adapter.format_message("[Click here](https://example.com)")
        assert "[Click here](https://example.com)" in formatted

    def test_link_display_text_escaped(self, adapter):
        """An exclamation mark in the link's display text is escaped."""
        formatted = adapter.format_message("[Hello!](https://example.com)")
        assert "Hello\\!" in formatted

    def test_link_url_parentheses_escaped(self, adapter):
        """A closing parenthesis inside the URL is escaped."""
        formatted = adapter.format_message("[link](https://example.com/path_(1))")
        assert "\\)" in formatted

    def test_link_with_surrounding_text(self, adapter):
        """The link is kept intact while the trailing dot in plain text is escaped."""
        formatted = adapter.format_message("Visit [Google](https://google.com) today.")
        assert "[Google](https://google.com)" in formatted
        assert "today\\." in formatted
# =========================================================================
# format_message - BUG: italic regex spans newlines
# =========================================================================
class TestItalicNewlineBug:
    r"""Regression tests: italic regex ``\*([^*]+)\*`` matched across newlines.

    That corrupted bullet lists using ``*`` markers and any text where ``*``
    ends one line and starts another.
    """

    def test_bullet_list_not_corrupted(self, adapter):
        """Bullet list items using * must NOT be merged into italic."""
        source = "* Item one\n* Item two\n* Item three"
        formatted = adapter.format_message(source)
        # Every item must survive the conversion.
        for item in ("Item one", "Item two", "Item three"):
            assert item in formatted
        # If any underscore was produced, the text after the first one must
        # not contain a list item.
        if "_" in formatted:
            assert "Item" not in formatted.split("_")[1]

    def test_asterisk_list_items_preserved(self, adapter):
        """Each * list item should remain as a separate line, not become italic."""
        formatted = adapter.format_message("* Alpha\n* Beta")
        for item in ("Alpha", "Beta"):
            assert item in formatted
        # The two items must still be on separate lines.
        assert len(formatted.split("\n")) >= 2

    def test_italic_does_not_span_lines(self, adapter):
        """*text on\nmultiple lines* should NOT become italic."""
        formatted = adapter.format_message("Start *across\nlines* end")
        # A failure here means the italic regex matched across the newline.
        assert "_across\nlines_" not in formatted

    def test_single_line_italic_still_works(self, adapter):
        """Normal single-line italic must still convert correctly."""
        formatted = adapter.format_message("This is *italic* text")
        assert "_italic_" in formatted
# =========================================================================
# format_message - strikethrough
# =========================================================================
class TestFormatMessageStrikethrough:
    """``format_message`` conversion of ``~~strikethrough~~``."""

    def test_strikethrough_converted(self, adapter):
        """Double tildes collapse to single-tilde MarkdownV2 strikethrough."""
        formatted = adapter.format_message("This is ~~deleted~~ text")
        assert "~deleted~" in formatted
        assert "~~" not in formatted

    def test_strikethrough_with_special_chars(self, adapter):
        """Special characters inside strikethrough are escaped."""
        formatted = adapter.format_message("~~hello.world!~~")
        assert "~hello\\.world\\!~" in formatted

    def test_strikethrough_in_code_not_converted(self, adapter):
        """Double tildes inside inline code are left untouched."""
        formatted = adapter.format_message("`~~not struck~~`")
        assert "`~~not struck~~`" in formatted

    def test_strikethrough_with_bold(self, adapter):
        """Bold and strikethrough on one line are both converted."""
        formatted = adapter.format_message("**bold** and ~~struck~~")
        for marker in ("*bold*", "~struck~"):
            assert marker in formatted
# =========================================================================
# format_message - spoiler
# =========================================================================
class TestFormatMessageSpoiler:
    """``format_message`` handling of ``||spoiler||`` spans."""

    def test_spoiler_converted(self, adapter):
        """A spoiler span is present with its double-pipe delimiters."""
        formatted = adapter.format_message("This is ||hidden|| text")
        assert "||hidden||" in formatted

    def test_spoiler_with_special_chars(self, adapter):
        """Special characters inside a spoiler are escaped."""
        formatted = adapter.format_message("||hello.world!||")
        assert "||hello\\.world\\!||" in formatted

    def test_spoiler_in_code_not_converted(self, adapter):
        """Double pipes inside inline code are left untouched."""
        formatted = adapter.format_message("`||not spoiler||`")
        assert "`||not spoiler||`" in formatted

    def test_spoiler_pipes_not_escaped(self, adapter):
        """The || delimiters must not come out as backslash-escaped pipes."""
        formatted = adapter.format_message("||secret||")
        assert "\\|\\|" not in formatted
        assert "||secret||" in formatted
# =========================================================================
# format_message - blockquote
# =========================================================================
class TestFormatMessageBlockquote:
    """``format_message`` handling of ``>`` blockquotes."""

    def test_blockquote_converted(self, adapter):
        """A leading > is kept as a blockquote marker and not escaped."""
        formatted = adapter.format_message("> This is a quote")
        assert "> This is a quote" in formatted
        assert "\\>" not in formatted

    def test_blockquote_with_special_chars(self, adapter):
        """Quote content is escaped while the marker itself is not."""
        formatted = adapter.format_message("> Hello (world)!")
        assert "> Hello \\(world\\)\\!" in formatted
        assert "\\>" not in formatted

    def test_blockquote_multiline(self, adapter):
        """Each line of a multi-line quote keeps its unescaped marker."""
        formatted = adapter.format_message("> Line one\n> Line two")
        for quoted_line in ("> Line one", "> Line two"):
            assert quoted_line in formatted
        assert "\\>" not in formatted

    def test_blockquote_in_code_not_converted(self, adapter):
        """A > inside a fenced code block is left as-is."""
        formatted = adapter.format_message("```\n> not a quote\n```")
        assert "> not a quote" in formatted

    def test_nested_blockquote(self, adapter):
        """A doubled >> marker is kept and not escaped."""
        formatted = adapter.format_message(">> Nested quote")
        assert ">> Nested quote" in formatted
        assert "\\>" not in formatted

    def test_gt_in_middle_of_line_still_escaped(self, adapter):
        """Only > at line start is a blockquote; mid-line > should be escaped."""
        formatted = adapter.format_message("5 > 3")
        assert "\\>" in formatted
# =========================================================================
# format_message - mixed/complex
# =========================================================================
class TestFormatMessageComplex:
    """``format_message`` on inputs that mix several formatting elements."""

    def test_code_block_with_bold_outside(self, adapter):
        """Bold before a fenced block is converted; the block itself is kept verbatim."""
        source = "**Note:**\n```\ncode here\n```"
        formatted = adapter.format_message(source)
        # The colon may or may not be escaped; both forms are accepted.
        assert "*Note:*" in formatted or "*Note\\:*" in formatted
        assert "```\ncode here\n```" in formatted

    def test_bold_inside_code_not_converted(self, adapter):
        """Bold markers inside code blocks should not be converted."""
        formatted = adapter.format_message("```\n**not bold**\n```")
        assert "**not bold**" in formatted

    def test_link_inside_code_not_converted(self, adapter):
        """Link syntax inside inline code is left untouched."""
        formatted = adapter.format_message("`[not a link](url)`")
        assert "`[not a link](url)`" in formatted

    def test_header_after_code_block(self, adapter):
        """A header following a fenced block is converted; the block is kept verbatim."""
        formatted = adapter.format_message("```\ncode\n```\n## Title")
        assert "*Title*" in formatted
        assert "```\ncode\n```" in formatted

    def test_multiple_bold_segments(self, adapter):
        """Three bold segments yield at least six asterisks."""
        formatted = adapter.format_message("**a** and **b** and **c**")
        assert formatted.count("*") >= 6  # 3 bold pairs = 6 asterisks

    def test_special_chars_in_plain_text(self, adapter):
        """Dots, parentheses and exclamation marks in plain text are all escaped."""
        formatted = adapter.format_message("Price: $5.00 (50% off!)")
        for escaped_char in ("\\.", "\\(", "\\)", "\\!"):
            assert escaped_char in formatted

    def test_empty_bold(self, adapter):
        """**** (empty bold) should not crash."""
        formatted = adapter.format_message("****")
        assert formatted is not None

    def test_empty_code_block(self, adapter):
        """An empty fenced block keeps its fence."""
        formatted = adapter.format_message("```\n```")
        assert "```" in formatted

    def test_placeholder_collision(self, adapter):
        """Many formatting elements should not cause placeholder collisions."""
        source = (
            "# Header\n"
            "**bold1** *italic1* `code1`\n"
            "**bold2** *italic2* `code2`\n"
            "```\nblock\n```\n"
            "[link](https://url.com)"
        )
        formatted = adapter.format_message(source)
        # A NUL byte in the output would be a leaked placeholder token.
        assert "\x00" not in formatted
        # Content from the header, the fenced block and the link must be present.
        for fragment in ("Header", "block", "url.com"):
            assert fragment in formatted
# =========================================================================
# _strip_mdv2 — plaintext fallback
# =========================================================================
class TestStripMdv2:
    """Checks for ``_strip_mdv2``, the plaintext fallback for MarkdownV2."""

    def test_removes_escape_backslashes(self):
        """Backslash escapes are dropped, leaving the escaped characters."""
        stripped = _strip_mdv2(r"hello\.world\!")
        assert stripped == "hello.world!"

    def test_removes_bold_markers(self):
        """Surrounding ``*`` markers are removed."""
        stripped = _strip_mdv2("*bold text*")
        assert stripped == "bold text"

    def test_removes_italic_markers(self):
        """Surrounding ``_`` markers are removed."""
        stripped = _strip_mdv2("_italic text_")
        assert stripped == "italic text"

    def test_removes_both_bold_and_italic(self):
        """Bold and italic markers in one string are both removed."""
        assert _strip_mdv2("*bold* and _italic_") == "bold and italic"

    def test_preserves_snake_case(self):
        """Underscores inside an identifier are not treated as italics."""
        identifier = "my_variable_name"
        assert _strip_mdv2(identifier) == identifier

    def test_preserves_multi_underscore_identifier(self):
        """An identifier followed by more text keeps its underscores."""
        phrase = "some_func_call here"
        assert _strip_mdv2(phrase) == phrase

    def test_plain_text_unchanged(self):
        """Text without any markup passes through unchanged."""
        plain = "plain text"
        assert _strip_mdv2(plain) == plain

    def test_empty_string(self):
        """The empty string maps to the empty string."""
        assert _strip_mdv2("") == ""

    def test_removes_strikethrough_markers(self):
        """Surrounding ``~`` markers are removed."""
        stripped = _strip_mdv2("~struck text~")
        assert stripped == "struck text"

    def test_removes_spoiler_markers(self):
        """Surrounding ``||`` markers are removed."""
        stripped = _strip_mdv2("||hidden text||")
        assert stripped == "hidden text"
@pytest.mark.asyncio
async def test_send_escapes_chunk_indicator_for_markdownv2(adapter):
    """First and last chunk of a split send end with an escaped ``(i/n)``."""
    adapter.MAX_MESSAGE_LENGTH = 80
    adapter._bot = MagicMock()
    captured = []

    async def _record_send(**kwargs):
        # Record the outgoing text and hand back a message-like object.
        captured.append(kwargs["text"])
        sent = MagicMock()
        sent.message_id = len(captured)
        return sent

    adapter._bot.send_message = AsyncMock(side_effect=_record_send)

    # Twelve repetitions is long enough to exceed the 80-char limit set above.
    body = " ".join(["**bold** chunk content"] * 12)
    outcome = await adapter.send("123", body)

    assert outcome.success is True
    assert len(captured) > 1
    indicator = re.compile(r" \\\([0-9]+/[0-9]+\\\)$")
    assert indicator.search(captured[0])
    assert indicator.search(captured[-1])

View file

@ -0,0 +1,49 @@
import asyncio
from unittest.mock import MagicMock
import pytest
from gateway.config import GatewayConfig, Platform, PlatformConfig
from gateway.platforms.base import MessageEvent, MessageType
from gateway.session import SessionSource, build_session_key
from gateway.run import GatewayRunner
class _PendingAdapter:
    """Adapter stand-in that exposes only a ``_pending_messages`` mapping."""

    def __init__(self):
        # Starts empty; the test below reads it back by session key.
        self._pending_messages = dict()
def _make_runner():
    """Build a GatewayRunner without running ``__init__``.

    Only the attributes assigned here are set: a Telegram-only config, a
    ``_PendingAdapter`` for Telegram, empty state dicts and an authorization
    check that accepts every source.
    """
    runner = object.__new__(GatewayRunner)
    telegram_config = PlatformConfig(enabled=True, token="***")
    runner.config = GatewayConfig(platforms={Platform.TELEGRAM: telegram_config})
    runner.adapters = {Platform.TELEGRAM: _PendingAdapter()}
    for state_attr in (
        "_running_agents",
        "_pending_messages",
        "_pending_approvals",
        "_voice_mode",
    ):
        setattr(runner, state_attr, {})
    runner._is_user_authorized = lambda _source: True
    return runner
@pytest.mark.asyncio
async def test_handle_message_does_not_priority_interrupt_photo_followup():
    """A photo arriving while an agent runs is queued, not used to interrupt."""
    runner = _make_runner()
    dm_source = SessionSource(platform=Platform.TELEGRAM, chat_id="12345", chat_type="dm")
    key = build_session_key(dm_source)

    # Register a running agent for this session.
    active_agent = MagicMock()
    runner._running_agents[key] = active_agent

    photo_event = MessageEvent(
        text="caption",
        message_type=MessageType.PHOTO,
        source=dm_source,
        media_urls=["/tmp/photo-a.jpg"],
        media_types=["image/jpeg"],
    )

    outcome = await runner._handle_message(photo_event)

    assert outcome is None
    active_agent.interrupt.assert_not_called()
    queued = runner.adapters[Platform.TELEGRAM]._pending_messages
    assert queued[key] is photo_event

View file

@ -0,0 +1,121 @@
"""Tests for Telegram text message aggregation.
When a user sends a long message, Telegram clients split it into multiple
updates. The TelegramAdapter should buffer rapid successive text messages
from the same session and aggregate them before dispatching.
"""
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from gateway.config import Platform, PlatformConfig
from gateway.platforms.base import MessageEvent, MessageType, SessionSource
def _make_adapter():
    """Create a minimal TelegramAdapter for testing text batching."""
    from gateway.platforms.telegram import TelegramAdapter

    # Bypass __init__ and set only the attributes assigned below.
    adapter = object.__new__(TelegramAdapter)
    adapter._platform = Platform.TELEGRAM
    adapter.config = PlatformConfig(enabled=True, token="test-token")

    # Batching state, with a short delay so tests run quickly.
    adapter._pending_text_batches = {}
    adapter._pending_text_batch_tasks = {}
    adapter._text_batch_delay_seconds = 0.1

    adapter._active_sessions = {}
    adapter._pending_messages = {}
    adapter._message_handler = AsyncMock()
    adapter.handle_message = AsyncMock()
    return adapter
def _make_event(text: str, chat_id: str = "12345") -> MessageEvent:
    """Return a TEXT MessageEvent from a Telegram DM with the given chat id."""
    dm_source = SessionSource(
        platform=Platform.TELEGRAM,
        chat_id=chat_id,
        chat_type="dm",
    )
    return MessageEvent(text=text, message_type=MessageType.TEXT, source=dm_source)
class TestTextBatching:
    """Buffering and merging of rapid successive Telegram text messages."""

    @pytest.mark.asyncio
    async def test_single_message_dispatched_after_delay(self):
        """A lone message is held back, then dispatched unchanged."""
        adapter = _make_adapter()
        adapter._enqueue_text_event(_make_event("hello world"))

        # Not dispatched yet
        adapter.handle_message.assert_not_called()

        # Wait for flush
        await asyncio.sleep(0.2)

        adapter.handle_message.assert_called_once()
        delivered = adapter.handle_message.call_args.args[0]
        assert delivered.text == "hello world"

    @pytest.mark.asyncio
    async def test_split_messages_aggregated(self):
        """Two rapid messages from the same chat should be merged."""
        adapter = _make_adapter()
        first_half = _make_event("This is part one of a long")
        second_half = _make_event("message that was split by Telegram.")

        adapter._enqueue_text_event(first_half)
        await asyncio.sleep(0.02)  # small gap, within batch window
        adapter._enqueue_text_event(second_half)

        # Not dispatched yet (timer restarted)
        adapter.handle_message.assert_not_called()

        # Wait for flush
        await asyncio.sleep(0.2)

        adapter.handle_message.assert_called_once()
        merged_text = adapter.handle_message.call_args.args[0].text
        for fragment in ("part one", "split by Telegram"):
            assert fragment in merged_text

    @pytest.mark.asyncio
    async def test_three_way_split_aggregated(self):
        """Three rapid messages should all merge."""
        adapter = _make_adapter()
        chunks = ("chunk 1", "chunk 2", "chunk 3")

        adapter._enqueue_text_event(_make_event(chunks[0]))
        await asyncio.sleep(0.02)
        adapter._enqueue_text_event(_make_event(chunks[1]))
        await asyncio.sleep(0.02)
        adapter._enqueue_text_event(_make_event(chunks[2]))

        await asyncio.sleep(0.2)

        adapter.handle_message.assert_called_once()
        merged_text = adapter.handle_message.call_args.args[0].text
        for chunk in chunks:
            assert chunk in merged_text

    @pytest.mark.asyncio
    async def test_different_chats_not_merged(self):
        """Messages from different chats should be separate batches."""
        adapter = _make_adapter()
        for body, chat in (("from user A", "111"), ("from user B", "222")):
            adapter._enqueue_text_event(_make_event(body, chat_id=chat))

        await asyncio.sleep(0.2)

        assert adapter.handle_message.call_count == 2

    @pytest.mark.asyncio
    async def test_batch_cleans_up_after_flush(self):
        """After flushing, internal state should be clean."""
        adapter = _make_adapter()
        adapter._enqueue_text_event(_make_event("test"))

        await asyncio.sleep(0.2)

        assert len(adapter._pending_text_batches) == 0
        assert len(adapter._pending_text_batch_tasks) == 0

View file

@ -0,0 +1,208 @@
"""Tests for /title gateway slash command.
Tests the _handle_title_command handler (set/show session titles)
across all gateway messenger platforms.
"""
import os
from unittest.mock import MagicMock, patch
import pytest
from gateway.config import Platform
from gateway.platforms.base import MessageEvent
from gateway.session import SessionSource
def _make_event(text="/title", platform=Platform.TELEGRAM,
                user_id="12345", chat_id="67890"):
    """Build a MessageEvent for testing.

    The event's source carries the given platform, user id and chat id,
    plus the fixed user name ``"testuser"``.
    """
    return MessageEvent(
        text=text,
        source=SessionSource(
            platform=platform,
            user_id=user_id,
            chat_id=chat_id,
            user_name="testuser",
        ),
    )
def _make_runner(session_db=None):
    """Create a bare GatewayRunner with a mock session_store and optional session_db."""
    from gateway.run import GatewayRunner

    # The store always hands back one session entry with a known session_id.
    session_entry = MagicMock(
        session_id="test_session_123",
        session_key="telegram:12345:67890",
    )
    store = MagicMock()
    store.get_or_create_session.return_value = session_entry

    # Skip __init__; set only the attributes assigned below.
    runner = object.__new__(GatewayRunner)
    runner.adapters = {}
    runner._voice_mode = {}
    runner._session_db = session_db
    runner.session_store = store
    return runner
# ---------------------------------------------------------------------------
# _handle_title_command
# ---------------------------------------------------------------------------
class TestHandleTitleCommand:
    """Tests for GatewayRunner._handle_title_command.

    Every test that opens a ``SessionDB`` closes it in a ``finally`` block so
    the database handle is released even when an assertion fails.
    """

    @pytest.mark.asyncio
    async def test_set_title(self, tmp_path):
        """Setting a title returns confirmation and stores the title."""
        from hermes_state import SessionDB
        db = SessionDB(db_path=tmp_path / "state.db")
        try:
            db.create_session("test_session_123", "telegram")
            runner = _make_runner(session_db=db)
            event = _make_event(text="/title My Research Project")
            result = await runner._handle_title_command(event)
            assert "My Research Project" in result
            assert "✏️" in result
            # Verify in DB
            assert db.get_session_title("test_session_123") == "My Research Project"
        finally:
            db.close()

    @pytest.mark.asyncio
    async def test_show_title_when_set(self, tmp_path):
        """Showing title when one is set returns the title."""
        from hermes_state import SessionDB
        db = SessionDB(db_path=tmp_path / "state.db")
        try:
            db.create_session("test_session_123", "telegram")
            db.set_session_title("test_session_123", "Existing Title")
            runner = _make_runner(session_db=db)
            event = _make_event(text="/title")
            result = await runner._handle_title_command(event)
            assert "Existing Title" in result
            assert "📌" in result
        finally:
            db.close()

    @pytest.mark.asyncio
    async def test_show_title_when_not_set(self, tmp_path):
        """Showing title when none is set returns usage hint."""
        from hermes_state import SessionDB
        db = SessionDB(db_path=tmp_path / "state.db")
        try:
            db.create_session("test_session_123", "telegram")
            runner = _make_runner(session_db=db)
            event = _make_event(text="/title")
            result = await runner._handle_title_command(event)
            assert "No title set" in result
            assert "/title" in result
        finally:
            db.close()

    @pytest.mark.asyncio
    async def test_title_conflict(self, tmp_path):
        """Setting a title already used by another session returns error."""
        from hermes_state import SessionDB
        db = SessionDB(db_path=tmp_path / "state.db")
        try:
            # Another session already owns the title we are about to request.
            db.create_session("other_session", "telegram")
            db.set_session_title("other_session", "Taken Title")
            db.create_session("test_session_123", "telegram")
            runner = _make_runner(session_db=db)
            event = _make_event(text="/title Taken Title")
            result = await runner._handle_title_command(event)
            assert "already in use" in result
            assert "⚠️" in result
        finally:
            db.close()

    @pytest.mark.asyncio
    async def test_no_session_db(self):
        """Returns error when session database is not available."""
        runner = _make_runner(session_db=None)
        event = _make_event(text="/title My Title")
        result = await runner._handle_title_command(event)
        assert "not available" in result

    @pytest.mark.asyncio
    async def test_title_too_long(self, tmp_path):
        """Setting a title that exceeds max length returns error."""
        from hermes_state import SessionDB
        db = SessionDB(db_path=tmp_path / "state.db")
        try:
            db.create_session("test_session_123", "telegram")
            runner = _make_runner(session_db=db)
            long_title = "A" * 150
            event = _make_event(text=f"/title {long_title}")
            result = await runner._handle_title_command(event)
            assert "too long" in result
            assert "⚠️" in result
        finally:
            db.close()

    @pytest.mark.asyncio
    async def test_title_control_chars_sanitized(self, tmp_path):
        """Control characters are stripped and sanitized title is stored."""
        from hermes_state import SessionDB
        db = SessionDB(db_path=tmp_path / "state.db")
        try:
            db.create_session("test_session_123", "telegram")
            runner = _make_runner(session_db=db)
            event = _make_event(text="/title hello\x00world")
            result = await runner._handle_title_command(event)
            assert "helloworld" in result
            assert db.get_session_title("test_session_123") == "helloworld"
        finally:
            db.close()

    @pytest.mark.asyncio
    async def test_title_only_control_chars(self, tmp_path):
        """Title with only control chars returns empty error."""
        from hermes_state import SessionDB
        db = SessionDB(db_path=tmp_path / "state.db")
        try:
            db.create_session("test_session_123", "telegram")
            runner = _make_runner(session_db=db)
            event = _make_event(text="/title \x00\x01\x02")
            result = await runner._handle_title_command(event)
            assert "empty after cleanup" in result
        finally:
            db.close()

    @pytest.mark.asyncio
    async def test_works_across_platforms(self, tmp_path):
        """The /title command works on both Discord and Telegram."""
        from hermes_state import SessionDB
        for platform in [Platform.DISCORD, Platform.TELEGRAM]:
            # One database file per platform so the runs stay independent.
            db = SessionDB(db_path=tmp_path / f"state_{platform.value}.db")
            try:
                db.create_session("test_session_123", platform.value)
                runner = _make_runner(session_db=db)
                event = _make_event(text="/title Cross-Platform Test", platform=platform)
                result = await runner._handle_title_command(event)
                assert "Cross-Platform Test" in result
                assert db.get_session_title("test_session_123") == "Cross-Platform Test"
            finally:
                db.close()
# ---------------------------------------------------------------------------
# /title in help and known_commands
# ---------------------------------------------------------------------------
class TestTitleInHelp:
    """Verify /title appears in help text and known commands."""

    @pytest.mark.asyncio
    async def test_title_in_help_output(self):
        """The /help output includes /title."""
        from gateway.hooks import HookRegistry

        runner = _make_runner()
        # Need hooks for help command
        runner.hooks = HookRegistry()

        help_text = await runner._handle_help_command(_make_event(text="/help"))
        assert "/title" in help_text

    def test_title_is_known_command(self):
        """The source of ``_handle_message`` mentions the ``"title"`` command."""
        import inspect
        from gateway.run import GatewayRunner

        handler_source = inspect.getsource(GatewayRunner._handle_message)
        assert '"title"' in handler_source

View file

@ -0,0 +1,267 @@
"""Tests for transcript history offset fix.
Regression tests for a bug where the gateway transcript lost 1 message
per turn from turn 2 onwards. The raw transcript history includes
``session_meta`` entries that are filtered out before being passed to
the agent. The agent returns messages built from this filtered history
plus new messages from the current turn.
The old code used ``len(history)`` (raw count, includes session_meta)
to slice ``agent_messages``, which caused the slice to skip valid new
messages. The fix adds ``history_offset`` (the filtered history length)
to ``_run_agent``'s return dict and uses it for the slice.
"""
import pytest
# ---------------------------------------------------------------------------
# Helpers - replicate the filtering logic from _run_agent
# ---------------------------------------------------------------------------
def _filter_history(history: list) -> list:
    """Test-local copy of the agent_history filtering in GatewayRunner._run_agent.

    Rules applied to each entry, in order:

    * entries with a missing or empty ``role`` are dropped;
    * ``session_meta`` and ``system`` entries are dropped;
    * tool-related entries (role ``tool``, or carrying ``tool_calls`` or
      ``tool_call_id``) are kept as a copy without the ``timestamp`` key;
    * any other entry is kept as ``{"role", "content"}`` only when its
      content is truthy.

    The input list and its dicts are not modified.
    """
    dropped_roles = ("session_meta", "system")
    kept = []
    for entry in history:
        role = entry.get("role")
        if not role or role in dropped_roles:
            continue

        tool_related = (
            role == "tool" or "tool_calls" in entry or "tool_call_id" in entry
        )
        if tool_related:
            kept.append(
                {key: value for key, value in entry.items() if key != "timestamp"}
            )
            continue

        content = entry.get("content")
        if content:
            kept.append({"role": role, "content": content})
    return kept
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestTranscriptHistoryOffset:
    """Verify the transcript extraction uses the filtered history length."""

    @staticmethod
    def _legacy_new_messages(agent_messages, raw_offset):
        """Old extraction: slice past the raw offset, else return everything."""
        if len(agent_messages) > raw_offset:
            return agent_messages[raw_offset:]
        return agent_messages

    @staticmethod
    def _fixed_new_messages(agent_messages, history_offset):
        """Fixed extraction: slice past the filtered offset, else return []."""
        if len(agent_messages) > history_offset:
            return agent_messages[history_offset:]
        return []

    def test_session_meta_causes_offset_mismatch(self):
        """Turn 2: session_meta makes len(history) > len(agent_history).

        - history (raw): 1 session_meta + 2 conversation = 3 entries
        - agent_history (filtered): 2 entries
        - Agent returns 2 old + 2 new = 4 messages
        - OLD: agent_messages[3:] = 1 message (lost the user message)
        - FIX: agent_messages[2:] = 2 messages (correct)
        """
        history = [
            {"role": "session_meta", "tools": [], "model": "gpt-4",
             "platform": "telegram", "timestamp": "t0"},
            {"role": "user", "content": "Hello", "timestamp": "t1"},
            {"role": "assistant", "content": "Hi there!", "timestamp": "t1"},
        ]
        agent_history = _filter_history(history)
        assert len(agent_history) == 2  # session_meta stripped

        # Agent returns: filtered history (2) + new turn (2)
        agent_messages = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
            {"role": "user", "content": "What is Python?"},
            {"role": "assistant", "content": "A programming language."},
        ]

        # OLD behavior: len(history) = 3, skips too many
        legacy = self._legacy_new_messages(agent_messages, len(history))
        assert len(legacy) == 1  # BUG: lost the user message

        # FIXED behavior: history_offset = 2
        fixed = self._fixed_new_messages(agent_messages, len(agent_history))
        assert len(fixed) == 2
        assert fixed[0]["content"] == "What is Python?"
        assert fixed[1]["content"] == "A programming language."

    def test_no_session_meta_same_result(self):
        """First turn has no session_meta, so both approaches agree."""
        history = []
        agent_history = _filter_history(history)
        assert len(agent_history) == 0

        agent_messages = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi!"},
        ]

        legacy = self._legacy_new_messages(agent_messages, len(history))
        fixed = self._fixed_new_messages(agent_messages, len(agent_history))

        assert legacy == fixed
        assert len(fixed) == 2

    def test_multiple_session_meta_larger_drift(self):
        """Two session_meta entries double the offset error.

        This can happen when the session spans tool definition changes
        or model switches that each write a new session_meta record.
        """
        history = [
            {"role": "session_meta", "tools": [], "timestamp": "t0"},
            {"role": "user", "content": "msg1", "timestamp": "t1"},
            {"role": "assistant", "content": "reply1", "timestamp": "t1"},
            {"role": "session_meta", "tools": ["new_tool"], "timestamp": "t2"},
            {"role": "user", "content": "msg2", "timestamp": "t3"},
            {"role": "assistant", "content": "reply2", "timestamp": "t3"},
        ]
        agent_history = _filter_history(history)
        assert len(agent_history) == 4
        assert len(history) == 6  # 2 extra session_meta entries

        # Agent returns 4 old + 2 new = 6 total
        agent_messages = [
            {"role": "user", "content": "msg1"},
            {"role": "assistant", "content": "reply1"},
            {"role": "user", "content": "msg2"},
            {"role": "assistant", "content": "reply2"},
            {"role": "user", "content": "msg3"},
            {"role": "assistant", "content": "reply3"},
        ]

        # OLD: len(history) == len(agent_messages) == 6 -> else branch
        legacy = self._legacy_new_messages(agent_messages, len(history))
        # BUG: treats ALL messages as new (duplicates entire history)
        assert legacy == agent_messages

        # FIXED: history_offset = 4
        fixed = self._fixed_new_messages(agent_messages, len(agent_history))
        assert len(fixed) == 2
        assert fixed[0]["content"] == "msg3"
        assert fixed[1]["content"] == "reply3"

    def test_system_messages_also_filtered(self):
        """system messages in history are also stripped from agent_history."""
        history = [
            {"role": "session_meta", "tools": [], "timestamp": "t0"},
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hi", "timestamp": "t1"},
            {"role": "assistant", "content": "Hello!", "timestamp": "t1"},
        ]
        agent_history = _filter_history(history)
        assert len(agent_history) == 2  # only user + assistant

        agent_messages = [
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": "Hello!"},
            {"role": "user", "content": "New question"},
            {"role": "assistant", "content": "New answer"},
        ]

        # OLD: len(history) = 4, skips everything
        legacy = self._legacy_new_messages(agent_messages, len(history))
        assert legacy == agent_messages  # BUG: all treated as new

        # FIXED
        fixed = self._fixed_new_messages(agent_messages, len(agent_history))
        assert len(fixed) == 2
        assert fixed[0]["content"] == "New question"

    def test_else_branch_returns_empty_list(self):
        """When the agent has no messages beyond the offset, return [] not all.

        The old code had ``else agent_messages`` which would treat the
        entire message list as new when the agent compressed or dropped
        messages. The fix changes this to ``else []``, falling through
        to the simple user/assistant fallback path.
        """
        history = [
            {"role": "session_meta", "tools": [], "timestamp": "t0"},
            {"role": "user", "content": "Hello", "timestamp": "t1"},
            {"role": "assistant", "content": "Hi!", "timestamp": "t1"},
        ]

        # Agent returned no more messages than the filtered history holds
        agent_messages = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi!"},
        ]

        history_offset = len(_filter_history(history))  # 2
        new_messages = self._fixed_new_messages(agent_messages, history_offset)

        # 2 == 2, so no new messages - falls to fallback
        assert new_messages == []

    def test_tool_call_messages_preserved_in_filter(self):
        """Tool call messages pass through the filter, keeping offset correct."""
        history = [
            {"role": "session_meta", "tools": [], "timestamp": "t0"},
            {"role": "user", "content": "Search for cats", "timestamp": "t1"},
            {"role": "assistant", "content": None, "timestamp": "t1",
             "tool_calls": [{"id": "tc1", "function": {"name": "web_search"}}]},
            {"role": "tool", "tool_call_id": "tc1",
             "content": "Results about cats", "timestamp": "t1"},
            {"role": "assistant", "content": "Here are results.",
             "timestamp": "t1"},
        ]
        agent_history = _filter_history(history)
        # session_meta filtered, but tool_calls/tool messages kept
        assert len(agent_history) == 4
        assert len(history) == 5  # 1 session_meta extra

        agent_messages = [
            {"role": "user", "content": "Search for cats"},
            {"role": "assistant", "content": None,
             "tool_calls": [{"id": "tc1", "function": {"name": "web_search"}}]},
            {"role": "tool", "tool_call_id": "tc1", "content": "Results about cats"},
            {"role": "assistant", "content": "Here are results."},
            {"role": "user", "content": "Now search for dogs"},
            {"role": "assistant", "content": "Dog results here."},
        ]

        # OLD: len(history) = 5, agent_messages[5:] = 1 message (lost user msg)
        legacy = self._legacy_new_messages(agent_messages, len(history))
        assert len(legacy) == 1  # BUG

        # FIXED
        fixed = self._fixed_new_messages(agent_messages, len(agent_history))
        assert len(fixed) == 2
        assert fixed[0]["content"] == "Now search for dogs"
        assert fixed[1]["content"] == "Dog results here."

View file

@ -0,0 +1,137 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
import pytest
from gateway.config import GatewayConfig, Platform, PlatformConfig
from gateway.platforms.base import MessageEvent
from gateway.session import SessionSource
def _clear_auth_env(monkeypatch) -> None:
    """Delete every ``<SCOPE>_ALLOWED_USERS`` / ``<SCOPE>_ALLOW_ALL_USERS`` variable.

    Deletion goes through ``monkeypatch.delenv(..., raising=False)``, so
    variables that are not set are ignored. All ``ALLOWED_USERS`` names are
    removed first, then all ``ALLOW_ALL_USERS`` names.
    """
    scopes = (
        "TELEGRAM",
        "DISCORD",
        "WHATSAPP",
        "SLACK",
        "SIGNAL",
        "EMAIL",
        "SMS",
        "MATTERMOST",
        "MATRIX",
        "DINGTALK",
        "GATEWAY",
    )
    for suffix in ("ALLOWED_USERS", "ALLOW_ALL_USERS"):
        for scope in scopes:
            monkeypatch.delenv(f"{scope}_{suffix}", raising=False)
def _make_event(platform: Platform, user_id: str, chat_id: str) -> MessageEvent:
    """Return a ``"hello"`` DM event (message id ``m1``) from user ``tester``."""
    dm_source = SessionSource(
        platform=platform,
        user_id=user_id,
        chat_id=chat_id,
        user_name="tester",
        chat_type="dm",
    )
    return MessageEvent(text="hello", message_id="m1", source=dm_source)
def _make_runner(platform: Platform, config: GatewayConfig):
    """Return ``(runner, adapter)`` with a pairing store that approves nobody.

    The runner is created without calling ``__init__``. The adapter is a
    namespace whose only attribute is an ``AsyncMock`` ``send``.
    """
    from gateway.run import GatewayRunner

    adapter = SimpleNamespace(send=AsyncMock())
    pairing_store = MagicMock()
    pairing_store.is_approved.return_value = False

    runner = object.__new__(GatewayRunner)
    runner.config = config
    runner.adapters = {platform: adapter}
    runner.pairing_store = pairing_store
    return runner, adapter
@pytest.mark.asyncio
async def test_unauthorized_dm_pairs_by_default(monkeypatch):
    """With no override, an unauthorized DM gets a pairing code sent back."""
    _clear_auth_env(monkeypatch)
    whatsapp_id = "15551234567@s.whatsapp.net"
    config = GatewayConfig(
        platforms={Platform.WHATSAPP: PlatformConfig(enabled=True)},
    )
    runner, adapter = _make_runner(Platform.WHATSAPP, config)
    runner.pairing_store.generate_code.return_value = "ABC12DEF"

    event = _make_event(Platform.WHATSAPP, whatsapp_id, whatsapp_id)
    outcome = await runner._handle_message(event)

    assert outcome is None
    runner.pairing_store.generate_code.assert_called_once_with(
        "whatsapp",
        whatsapp_id,
        "tester",
    )
    adapter.send.assert_awaited_once()
    # The second positional argument of send() must contain the code.
    sent_text = adapter.send.await_args.args[1]
    assert "ABC12DEF" in sent_text
@pytest.mark.asyncio
async def test_unauthorized_whatsapp_dm_can_be_ignored(monkeypatch):
    """A per-platform ``"ignore"`` setting suppresses the pairing reply."""
    _clear_auth_env(monkeypatch)
    whatsapp_id = "15551234567@s.whatsapp.net"
    whatsapp_config = PlatformConfig(
        enabled=True,
        extra={"unauthorized_dm_behavior": "ignore"},
    )
    config = GatewayConfig(platforms={Platform.WHATSAPP: whatsapp_config})
    runner, adapter = _make_runner(Platform.WHATSAPP, config)

    event = _make_event(Platform.WHATSAPP, whatsapp_id, whatsapp_id)
    outcome = await runner._handle_message(event)

    assert outcome is None
    runner.pairing_store.generate_code.assert_not_called()
    adapter.send.assert_not_awaited()
@pytest.mark.asyncio
async def test_global_ignore_suppresses_pairing_reply(monkeypatch):
    """A gateway-wide ``"ignore"`` setting suppresses the pairing reply."""
    _clear_auth_env(monkeypatch)
    telegram_config = PlatformConfig(enabled=True, token="***")
    config = GatewayConfig(
        unauthorized_dm_behavior="ignore",
        platforms={Platform.TELEGRAM: telegram_config},
    )
    runner, adapter = _make_runner(Platform.TELEGRAM, config)

    event = _make_event(Platform.TELEGRAM, "12345", "12345")
    outcome = await runner._handle_message(event)

    assert outcome is None
    runner.pairing_store.generate_code.assert_not_called()
    adapter.send.assert_not_awaited()

View file

@ -0,0 +1,637 @@
"""Tests for /update gateway slash command.
Tests both the _handle_update_command handler (spawns update process) and
the _send_update_notification startup hook (sends results after restart).
"""
import json
import os
from pathlib import Path
from unittest.mock import patch, MagicMock, AsyncMock
import pytest
from gateway.config import Platform
from gateway.platforms.base import MessageEvent
from gateway.session import SessionSource
def _make_event(text="/update", platform=Platform.TELEGRAM,
                user_id="12345", chat_id="67890"):
    """Build a MessageEvent for testing.

    The event's source carries the given platform, user id and chat id,
    plus the fixed user name ``"testuser"``.
    """
    return MessageEvent(
        text=text,
        source=SessionSource(
            platform=platform,
            user_id=user_id,
            chat_id=chat_id,
            user_name="testuser",
        ),
    )
def _make_runner():
    """Create a bare GatewayRunner without calling __init__."""
    from gateway.run import GatewayRunner

    bare_runner = object.__new__(GatewayRunner)
    # Only these two attributes are set; both start empty.
    bare_runner.adapters, bare_runner._voice_mode = {}, {}
    return bare_runner
# ---------------------------------------------------------------------------
# _handle_update_command
# ---------------------------------------------------------------------------
class TestHandleUpdateCommand:
"""Tests for GatewayRunner._handle_update_command."""
@pytest.mark.asyncio
async def test_no_git_directory(self, tmp_path):
    """Returns an error when .git does not exist."""
    runner = _make_runner()
    event = _make_event()

    # Build a fake checkout with gateway/run.py but no .git directory.
    # Per the notes in the original version of this test, the handler
    # derives its project root from gateway.run.__file__, so pointing
    # __file__ at the fake run.py makes the root resolve to fake_root.
    fake_root = tmp_path / "project"
    fake_run_py = fake_root / "gateway" / "run.py"
    fake_run_py.parent.mkdir(parents=True)
    fake_run_py.touch()

    with patch("gateway.run._hermes_home", tmp_path), \
            patch("gateway.run.__file__", str(fake_run_py)):
        result = await runner._handle_update_command(event)

    assert "Not a git repository" in result
@pytest.mark.asyncio
async def test_no_hermes_binary(self, tmp_path):
    """Returns error when hermes is not on PATH and hermes_cli is not importable."""
    runner = _make_runner()
    event = _make_event()

    # Create project dir WITH .git
    project_dir = tmp_path / "project"
    gateway_dir = project_dir / "gateway"
    run_py = gateway_dir / "run.py"
    project_dir.mkdir()
    (project_dir / ".git").mkdir()
    gateway_dir.mkdir()
    run_py.touch()

    with patch("gateway.run._hermes_home", tmp_path), \
            patch("gateway.run.__file__", str(run_py)), \
            patch("shutil.which", return_value=None), \
            patch("importlib.util.find_spec", return_value=None):
        reply = await runner._handle_update_command(event)

    for expected in ("Could not locate", "hermes update"):
        assert expected in reply
@pytest.mark.asyncio
async def test_fallback_to_sys_executable(self, tmp_path):
    """Falls back to sys.executable -m hermes_cli.main when hermes not on PATH."""
    runner = _make_runner()
    event = _make_event()

    # Fake checkout that does contain a .git directory.
    project_dir = tmp_path / "project"
    gateway_dir = project_dir / "gateway"
    run_py = gateway_dir / "run.py"
    project_dir.mkdir()
    (project_dir / ".git").mkdir()
    gateway_dir.mkdir()
    run_py.touch()

    hermes_home = tmp_path / "hermes"
    hermes_home.mkdir()

    popen_mock = MagicMock()
    spec_stub = MagicMock()

    with patch("gateway.run._hermes_home", hermes_home), \
            patch("gateway.run.__file__", str(run_py)), \
            patch("shutil.which", return_value=None), \
            patch("importlib.util.find_spec", return_value=spec_stub), \
            patch("subprocess.Popen", popen_mock):
        reply = await runner._handle_update_command(event)

    assert "Starting Hermes update" in reply
    spawned = popen_mock.call_args[0][0]
    # The update_cmd uses sys.executable -m hermes_cli.main
    if isinstance(spawned, list):
        joined = " ".join(spawned)
    else:
        joined = spawned
    assert "hermes_cli.main" in joined or "bash" in spawned[0]
@pytest.mark.asyncio
async def test_resolve_hermes_bin_prefers_which(self, tmp_path):
    """_resolve_hermes_bin returns argv parts from shutil.which when available."""
    from gateway.run import _resolve_hermes_bin

    located = "/custom/path/hermes"
    with patch("shutil.which", return_value=located):
        argv = _resolve_hermes_bin()

    assert argv == ["/custom/path/hermes"]
@pytest.mark.asyncio
async def test_resolve_hermes_bin_fallback(self):
    """_resolve_hermes_bin falls back to sys.executable argv when which fails."""
    import sys
    from gateway.run import _resolve_hermes_bin

    spec_stub = MagicMock()
    with patch("shutil.which", return_value=None), \
            patch("importlib.util.find_spec", return_value=spec_stub):
        argv = _resolve_hermes_bin()

    expected_argv = [sys.executable, "-m", "hermes_cli.main"]
    assert argv == expected_argv
@pytest.mark.asyncio
async def test_resolve_hermes_bin_returns_none_when_both_fail(self):
    """_resolve_hermes_bin returns None when both strategies fail."""
    from gateway.run import _resolve_hermes_bin

    # Neither the PATH lookup nor the module spec lookup finds anything.
    with patch("shutil.which", return_value=None), \
            patch("importlib.util.find_spec", return_value=None):
        argv = _resolve_hermes_bin()

    assert argv is None
@pytest.mark.asyncio
async def test_writes_pending_marker(self, tmp_path):
    """Writes .update_pending.json with correct platform and chat info."""
    runner = _make_runner()
    event = _make_event(platform=Platform.TELEGRAM, chat_id="99999")

    # Fake checkout that does contain a .git directory.
    project_dir = tmp_path / "project"
    gateway_dir = project_dir / "gateway"
    run_py = gateway_dir / "run.py"
    project_dir.mkdir()
    (project_dir / ".git").mkdir()
    gateway_dir.mkdir()
    run_py.touch()

    hermes_home = tmp_path / "hermes"
    hermes_home.mkdir()

    def _fake_which(name):
        # "hermes" resolves to itself; every other lookup gets systemd-run.
        if name == "hermes":
            return "/usr/bin/hermes"
        return "/usr/bin/systemd-run"

    with patch("gateway.run._hermes_home", hermes_home), \
            patch("gateway.run.__file__", str(run_py)), \
            patch("shutil.which", side_effect=_fake_which), \
            patch("subprocess.Popen"):
        result = await runner._handle_update_command(event)

    marker_path = hermes_home / ".update_pending.json"
    assert marker_path.exists()
    marker = json.loads(marker_path.read_text())
    assert marker["platform"] == "telegram"
    assert marker["chat_id"] == "99999"
    assert "timestamp" in marker
    assert not (hermes_home / ".update_exit_code").exists()
@pytest.mark.asyncio
async def test_spawns_systemd_run(self, tmp_path):
    """When every ``shutil.which`` lookup succeeds, the spawn goes through systemd-run."""
    gateway_runner = _make_runner()
    update_event = _make_event()

    # Fake checkout with a ".git" directory and gateway/run.py.
    project_dir = tmp_path / "project"
    project_dir.mkdir()
    (project_dir / ".git").mkdir()
    run_py = project_dir / "gateway" / "run.py"
    run_py.parent.mkdir()
    run_py.touch()
    home_dir = tmp_path / "hermes"
    home_dir.mkdir()

    popen_mock = MagicMock()

    def everything_on_path(x):
        return f"/usr/bin/{x}"

    home_patch = patch("gateway.run._hermes_home", home_dir)
    file_patch = patch("gateway.run.__file__", str(run_py))
    which_patch = patch("shutil.which", side_effect=everything_on_path)
    popen_patch = patch("subprocess.Popen", popen_mock)
    with home_patch, file_patch, which_patch, popen_patch:
        reply = await gateway_runner._handle_update_command(update_event)

    # The first positional argument of the Popen call is the spawned argv.
    spawned = popen_mock.call_args[0][0]
    assert spawned[0] == "/usr/bin/systemd-run"
    assert "--scope" in spawned
    assert ".update_exit_code" in spawned[-1]
    assert "Starting Hermes update" in reply
@pytest.mark.asyncio
async def test_fallback_nohup_when_no_systemd_run(self, tmp_path):
    """Without systemd-run on PATH, the update is launched through ``bash`` with nohup."""
    gateway_runner = _make_runner()
    update_event = _make_event()

    # Fake checkout with a ".git" directory and gateway/run.py.
    project_dir = tmp_path / "project"
    project_dir.mkdir()
    (project_dir / ".git").mkdir()
    run_py = project_dir / "gateway" / "run.py"
    run_py.parent.mkdir()
    run_py.touch()
    home_dir = tmp_path / "hermes"
    home_dir.mkdir()

    popen_mock = MagicMock()

    def which_no_systemd(x):
        # Only "hermes" is found; systemd-run and anything else are missing.
        return "/usr/bin/hermes" if x == "hermes" else None

    home_patch = patch("gateway.run._hermes_home", home_dir)
    file_patch = patch("gateway.run.__file__", str(run_py))
    which_patch = patch("shutil.which", side_effect=which_no_systemd)
    popen_patch = patch("subprocess.Popen", popen_mock)
    with home_patch, file_patch, which_patch, popen_patch:
        reply = await gateway_runner._handle_update_command(update_event)

    # The spawned argv is expected to be bash plus a script at index 2.
    spawned = popen_mock.call_args[0][0]
    assert spawned[0] == "bash"
    shell_script = spawned[2]
    assert "nohup" in shell_script
    assert ".update_exit_code" in shell_script
    assert "Starting Hermes update" in reply
@pytest.mark.asyncio
async def test_popen_failure_cleans_up(self, tmp_path):
    """An ``OSError`` from Popen yields an error reply and leaves no marker files."""
    gateway_runner = _make_runner()
    update_event = _make_event()

    # Fake checkout with a ".git" directory and gateway/run.py.
    project_dir = tmp_path / "project"
    project_dir.mkdir()
    (project_dir / ".git").mkdir()
    run_py = project_dir / "gateway" / "run.py"
    run_py.parent.mkdir()
    run_py.touch()
    home_dir = tmp_path / "hermes"
    home_dir.mkdir()

    def everything_on_path(x):
        return f"/usr/bin/{x}"

    home_patch = patch("gateway.run._hermes_home", home_dir)
    file_patch = patch("gateway.run.__file__", str(run_py))
    which_patch = patch("shutil.which", side_effect=everything_on_path)
    failing_popen = patch("subprocess.Popen", side_effect=OSError("spawn failed"))
    with home_patch, file_patch, which_patch, failing_popen:
        reply = await gateway_runner._handle_update_command(update_event)

    assert "Failed to start update" in reply
    # Neither the pending marker nor the exit-code file may be left behind.
    for leftover in (".update_pending.json", ".update_exit_code"):
        assert not (home_dir / leftover).exists()
@pytest.mark.asyncio
async def test_returns_user_friendly_message(self, tmp_path):
    """A successful spawn replies with a promise to notify the user on completion."""
    gateway_runner = _make_runner()
    update_event = _make_event()

    # Fake checkout with a ".git" directory and gateway/run.py.
    project_dir = tmp_path / "project"
    project_dir.mkdir()
    (project_dir / ".git").mkdir()
    run_py = project_dir / "gateway" / "run.py"
    run_py.parent.mkdir()
    run_py.touch()
    home_dir = tmp_path / "hermes"
    home_dir.mkdir()

    def everything_on_path(x):
        return f"/usr/bin/{x}"

    home_patch = patch("gateway.run._hermes_home", home_dir)
    file_patch = patch("gateway.run.__file__", str(run_py))
    which_patch = patch("shutil.which", side_effect=everything_on_path)
    popen_patch = patch("subprocess.Popen")
    with home_patch, file_patch, which_patch, popen_patch:
        reply = await gateway_runner._handle_update_command(update_event)

    assert "notify you when it's done" in reply
# ---------------------------------------------------------------------------
# _send_update_notification
# ---------------------------------------------------------------------------
class TestSendUpdateNotification:
    """Tests for GatewayRunner._send_update_notification.

    Each test stages marker files in a temporary Hermes home, patches
    ``gateway.run._hermes_home`` to point at it, calls the method, and then
    checks what was sent through a mocked adapter and which files remain.
    """

    # Pending-marker payload shared by the tests that do not care about ids.
    _PENDING = {"platform": "telegram", "chat_id": "111", "user_id": "222"}

    @staticmethod
    def _stage_home(tmp_path, pending=None, output=None, exit_code=None,
                    marker_name=".update_pending.json"):
        """Create ``tmp_path / "hermes"`` and write whichever update files are given.

        Args:
            tmp_path: pytest temporary directory.
            pending: dict (JSON-encoded) or str (written verbatim) for the
                marker file; ``None`` writes no marker.
            output: text for ``.update_output.txt``; ``None`` writes no file.
            exit_code: text for ``.update_exit_code``; ``None`` writes no file.
            marker_name: file name used for the pending marker.

        Returns:
            The created home directory.
        """
        home = tmp_path / "hermes"
        home.mkdir()
        if pending is not None:
            marker_text = pending if isinstance(pending, str) else json.dumps(pending)
            (home / marker_name).write_text(marker_text)
        if output is not None:
            (home / ".update_output.txt").write_text(output)
        if exit_code is not None:
            (home / ".update_exit_code").write_text(exit_code)
        return home

    @staticmethod
    def _runner_with_telegram():
        """Return ``(runner, adapter)`` with an ``AsyncMock`` registered for Telegram."""
        runner = _make_runner()
        adapter = AsyncMock()
        runner.adapters = {Platform.TELEGRAM: adapter}
        return runner, adapter

    @pytest.mark.asyncio
    async def test_no_pending_file_is_noop(self, tmp_path):
        """Does nothing when no pending file exists."""
        runner = _make_runner()
        hermes_home = self._stage_home(tmp_path)

        with patch("gateway.run._hermes_home", hermes_home):
            # Should not raise
            await runner._send_update_notification()

    @pytest.mark.asyncio
    async def test_defers_notification_while_update_still_running(self, tmp_path):
        """Returns False and keeps the marker when no exit-code file exists yet."""
        runner, adapter = self._runner_with_telegram()
        hermes_home = self._stage_home(
            tmp_path,
            pending={"platform": "telegram", "chat_id": "67890", "user_id": "12345"},
            output="still running",
        )

        with patch("gateway.run._hermes_home", hermes_home):
            result = await runner._send_update_notification()

        assert result is False
        adapter.send.assert_not_called()
        assert (hermes_home / ".update_pending.json").exists()

    @pytest.mark.asyncio
    async def test_recovers_from_claimed_pending_file(self, tmp_path):
        """A claimed pending file from a crashed notifier is still deliverable."""
        runner, adapter = self._runner_with_telegram()
        hermes_home = self._stage_home(
            tmp_path,
            pending={"platform": "telegram", "chat_id": "67890", "user_id": "12345"},
            output="done",
            exit_code="0",
            marker_name=".update_pending.claimed.json",
        )

        with patch("gateway.run._hermes_home", hermes_home):
            result = await runner._send_update_notification()

        assert result is True
        adapter.send.assert_called_once()
        assert not (hermes_home / ".update_pending.claimed.json").exists()

    @pytest.mark.asyncio
    async def test_sends_notification_with_output(self, tmp_path):
        """Sends update output to the correct platform and chat."""
        runner, adapter = self._runner_with_telegram()
        hermes_home = self._stage_home(
            tmp_path,
            pending={
                "platform": "telegram",
                "chat_id": "67890",
                "user_id": "12345",
                "timestamp": "2026-03-04T21:00:00",
            },
            output="→ Found 3 new commit(s)\n✓ Code updated!\n✓ Update complete!",
            exit_code="0",
        )

        with patch("gateway.run._hermes_home", hermes_home):
            await runner._send_update_notification()

        adapter.send.assert_called_once()
        positional = adapter.send.call_args[0]
        assert positional[0] == "67890"  # chat_id
        assert "Update complete" in positional[1] or "update finished" in positional[1].lower()

    @pytest.mark.asyncio
    async def test_strips_ansi_codes(self, tmp_path):
        """ANSI escape codes are removed from output."""
        runner, adapter = self._runner_with_telegram()
        hermes_home = self._stage_home(
            tmp_path,
            pending=self._PENDING,
            output="\x1b[32m✓ Code updated!\x1b[0m\n\x1b[1mDone\x1b[0m",
            exit_code="0",
        )

        with patch("gateway.run._hermes_home", hermes_home):
            await runner._send_update_notification()

        sent_text = adapter.send.call_args[0][1]
        assert "\x1b[" not in sent_text
        assert "Code updated" in sent_text

    @pytest.mark.asyncio
    async def test_truncates_long_output(self, tmp_path):
        """Output longer than 3500 chars is truncated."""
        runner, adapter = self._runner_with_telegram()
        hermes_home = self._stage_home(
            tmp_path,
            pending=self._PENDING,
            output="x" * 5000,
            exit_code="0",
        )

        with patch("gateway.run._hermes_home", hermes_home):
            await runner._send_update_notification()

        sent_text = adapter.send.call_args[0][1]
        # Truncated output must carry the truncation marker. The previous check
        # was `"" in sent_text`, which is true for every string and so verified
        # nothing. Written as an escape so the character cannot be lost again.
        # NOTE(review): marker assumed to be the horizontal ellipsis (U+2026)
        # -- confirm against _send_update_notification.
        assert "\u2026" in sent_text
        # Total message should not be absurdly long
        assert len(sent_text) < 4500

    @pytest.mark.asyncio
    async def test_sends_failure_message_when_update_fails(self, tmp_path):
        """Non-zero exit codes produce a failure notification with captured output."""
        runner, adapter = self._runner_with_telegram()
        hermes_home = self._stage_home(
            tmp_path,
            pending=self._PENDING,
            output="Traceback: boom",
            exit_code="1",
        )

        with patch("gateway.run._hermes_home", hermes_home):
            result = await runner._send_update_notification()

        assert result is True
        sent_text = adapter.send.call_args[0][1]
        assert "update failed" in sent_text.lower()
        assert "Traceback: boom" in sent_text

    @pytest.mark.asyncio
    async def test_sends_generic_message_when_no_output(self, tmp_path):
        """Sends a success message even if the output file is missing."""
        runner, adapter = self._runner_with_telegram()
        # No .update_output.txt is staged here.
        hermes_home = self._stage_home(tmp_path, pending=self._PENDING, exit_code="0")

        with patch("gateway.run._hermes_home", hermes_home):
            await runner._send_update_notification()

        sent_text = adapter.send.call_args[0][1]
        assert "finished successfully" in sent_text

    @pytest.mark.asyncio
    async def test_cleans_up_files_after_notification(self, tmp_path):
        """Marker, output and exit-code files are deleted after notification."""
        runner, _adapter = self._runner_with_telegram()
        hermes_home = self._stage_home(
            tmp_path, pending=self._PENDING, output="✓ Done", exit_code="0",
        )

        with patch("gateway.run._hermes_home", hermes_home):
            await runner._send_update_notification()

        assert not (hermes_home / ".update_pending.json").exists()
        assert not (hermes_home / ".update_output.txt").exists()
        assert not (hermes_home / ".update_exit_code").exists()

    @pytest.mark.asyncio
    async def test_cleans_up_on_error(self, tmp_path):
        """Files are cleaned up even if notification fails."""
        runner, adapter = self._runner_with_telegram()
        # Adapter send raises
        adapter.send.side_effect = RuntimeError("network error")
        hermes_home = self._stage_home(
            tmp_path, pending=self._PENDING, output="✓ Done", exit_code="0",
        )

        with patch("gateway.run._hermes_home", hermes_home):
            await runner._send_update_notification()

        # Files should still be cleaned up despite the send failure
        assert not (hermes_home / ".update_pending.json").exists()
        assert not (hermes_home / ".update_output.txt").exists()
        assert not (hermes_home / ".update_exit_code").exists()

    @pytest.mark.asyncio
    async def test_handles_corrupt_pending_file(self, tmp_path):
        """Gracefully handles a malformed pending JSON file."""
        runner = _make_runner()
        # A str is written verbatim, so the marker holds invalid JSON.
        hermes_home = self._stage_home(tmp_path, pending="{corrupt json!!")

        with patch("gateway.run._hermes_home", hermes_home):
            # Should not raise
            await runner._send_update_notification()

        # File should be cleaned up
        assert not (hermes_home / ".update_pending.json").exists()

    @pytest.mark.asyncio
    async def test_no_adapter_for_platform(self, tmp_path):
        """Does not crash if the platform adapter is not connected."""
        # Only a telegram adapter is registered, but the marker says discord.
        runner, adapter = self._runner_with_telegram()
        hermes_home = self._stage_home(
            tmp_path,
            pending={"platform": "discord", "chat_id": "111", "user_id": "222"},
            output="Done",
            exit_code="0",
        )

        with patch("gateway.run._hermes_home", hermes_home):
            await runner._send_update_notification()

        # send should not have been called (wrong platform)
        adapter.send.assert_not_called()
        # Files should still be cleaned up
        assert not (hermes_home / ".update_pending.json").exists()
        assert not (hermes_home / ".update_exit_code").exists()
# ---------------------------------------------------------------------------
# /update in help and known_commands
# ---------------------------------------------------------------------------
class TestUpdateInHelp:
    """Verify /update appears in help text and known commands set."""

    @pytest.mark.asyncio
    async def test_update_in_help_output(self):
        """The reply to a ``/help`` event mentions ``/update``."""
        help_text = await _make_runner()._handle_help_command(_make_event(text="/help"))
        assert "/update" in help_text

    def test_update_is_known_command(self):
        """The literal ``"update"`` appears in the source of ``_handle_message``.

        ``_known_commands`` is local to ``_handle_message`` and cannot be read
        from outside, so the method's source text is searched as a proxy.
        """
        import inspect
        from gateway.run import GatewayRunner

        handler_source = inspect.getsource(GatewayRunner._handle_message)
        assert '"update"' in handler_source

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,619 @@
"""Unit tests for the generic webhook platform adapter.
Covers:
- HMAC signature validation (GitHub, GitLab, generic)
- Prompt rendering with dot-notation template variables
- Event type filtering
- HTTP handler behaviour (404, 202, health)
- Idempotency cache (duplicate delivery IDs)
- Rate limiting (fixed-window, per route)
- Body size limits
- INSECURE_NO_AUTH bypass
- Session isolation for concurrent webhooks
- Delivery info cleanup after send()
- connect / disconnect lifecycle
"""
import asyncio
import hashlib
import hmac
import json
import time
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from aiohttp import web
from aiohttp.test_utils import TestClient, TestServer
from gateway.config import Platform, PlatformConfig
from gateway.platforms.base import MessageEvent, MessageType, SendResult
from gateway.platforms.webhook import (
WebhookAdapter,
_INSECURE_NO_AUTH,
check_webhook_requirements,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_config(
    routes=None,
    secret="",
    rate_limit=30,
    max_body_bytes=1_048_576,
    host="0.0.0.0",
    port=0,  # let OS pick a free port in tests
):
    """Return an enabled PlatformConfig whose ``extra`` carries the webhook settings.

    ``routes`` falls back to an empty dict when falsy, and the ``secret`` key
    is only present in ``extra`` when a non-empty secret is supplied.
    """
    settings = dict(
        host=host,
        port=port,
        routes=routes if routes else {},
        rate_limit=rate_limit,
        max_body_bytes=max_body_bytes,
    )
    if secret:
        settings.update(secret=secret)
    return PlatformConfig(enabled=True, extra=settings)
def _make_adapter(routes=None, **kwargs):
    """Build a WebhookAdapter from ``_make_config(routes=routes, **kwargs)``."""
    return WebhookAdapter(_make_config(routes=routes, **kwargs))
def _create_app(adapter: WebhookAdapter) -> web.Application:
    """Register the adapter's health and webhook handlers on a fresh aiohttp app.

    Nothing is started here; only the application object is returned.
    """
    application = web.Application()
    router = application.router
    router.add_get("/health", adapter._handle_health)
    router.add_post("/webhooks/{route_name}", adapter._handle_webhook)
    return application
def _mock_request(headers=None, body=b"", content_length=None, match_info=None):
    """Return a MagicMock standing in for an aiohttp POST request.

    ``content_length`` defaults to ``len(body)``; ``headers`` and
    ``match_info`` default to empty dicts; ``read()`` is a coroutine function
    that returns *body*.
    """
    async def _read():
        return body

    if content_length is None:
        content_length = len(body)

    request = MagicMock()
    request.headers = headers or {}
    request.content_length = content_length
    request.match_info = match_info or {}
    request.method = "POST"
    request.read = _read
    return request
def _github_signature(body: bytes, secret: str) -> str:
    """Return ``sha256=<hex>``: the HMAC-SHA256 of *body* keyed with *secret*.

    This is the value the tests place in the X-Hub-Signature-256 header.
    """
    digest = hmac.new(secret.encode(), body, hashlib.sha256).hexdigest()
    return f"sha256={digest}"
def _generic_signature(body: bytes, secret: str) -> str:
    """Return the bare hex HMAC-SHA256 of *body* keyed with *secret*.

    This is the value the tests place in the X-Webhook-Signature header.
    """
    mac = hmac.new(secret.encode(), body, hashlib.sha256)
    return mac.hexdigest()
# ===================================================================
# Signature validation
# ===================================================================
class TestValidateSignature:
    """Tests for WebhookAdapter._validate_signature."""

    def test_validate_github_signature_valid(self):
        """A correctly computed X-Hub-Signature-256 header is accepted."""
        webhook = _make_adapter()
        payload = b'{"action": "opened"}'
        key = "webhook-secret-42"
        request = _mock_request(
            headers={"X-Hub-Signature-256": _github_signature(payload, key)}
        )
        assert webhook._validate_signature(request, payload, key) is True

    def test_validate_github_signature_invalid(self):
        """An X-Hub-Signature-256 header carrying the wrong digest is rejected."""
        webhook = _make_adapter()
        payload = b'{"action": "opened"}'
        key = "webhook-secret-42"
        request = _mock_request(headers={"X-Hub-Signature-256": "sha256=deadbeef"})
        verdict = webhook._validate_signature(request, payload, key)
        assert verdict is False

    def test_validate_gitlab_token(self):
        """An X-Gitlab-Token header equal to the secret is accepted."""
        webhook = _make_adapter()
        token = "gl-token-value"
        request = _mock_request(headers={"X-Gitlab-Token": token})
        assert webhook._validate_signature(request, b"{}", token) is True

    def test_validate_gitlab_token_wrong(self):
        """An X-Gitlab-Token header that differs from the secret is rejected."""
        webhook = _make_adapter()
        request = _mock_request(headers={"X-Gitlab-Token": "wrong"})
        verdict = webhook._validate_signature(request, b"{}", "correct")
        assert verdict is False

    def test_validate_no_signature_with_secret_rejects(self):
        """A secret is supplied but the request has no signature header: reject."""
        webhook = _make_adapter()
        bare_request = _mock_request(headers={})  # no sig headers at all
        verdict = webhook._validate_signature(bare_request, b"{}", "my-secret")
        assert verdict is False

    def test_validate_no_secret_allows_all(self):
        """With no route secret and an empty global secret, the effective secret is falsy.

        The handler is described as validating only when
        ``secret and secret != _INSECURE_NO_AUTH``, so a falsy secret means
        ``_validate_signature`` is never reached. This test does not call the
        handler; it only checks that the resolved secret is falsy.
        """
        webhook = _make_adapter(routes={"test": {"prompt": "hello"}}, secret="")
        # Route-level secret first, then the adapter-wide default.
        effective = webhook._routes["test"].get("secret", webhook._global_secret)
        assert not effective  # empty → validation is skipped in handler

    def test_validate_generic_signature_valid(self):
        """A correctly computed X-Webhook-Signature (plain hex HMAC) is accepted."""
        webhook = _make_adapter()
        payload = b'{"event": "push"}'
        key = "generic-secret"
        request = _mock_request(
            headers={"X-Webhook-Signature": _generic_signature(payload, key)}
        )
        assert webhook._validate_signature(request, payload, key) is True
# ===================================================================
# Prompt rendering
# ===================================================================
class TestRenderPrompt:
    """Tests for WebhookAdapter._render_prompt."""

    def test_render_prompt_dot_notation(self):
        """``{pull_request.title}``-style placeholders resolve nested payload keys."""
        template = "PR #{pull_request.number}: {pull_request.title}"
        event_body = {"pull_request": {"title": "Fix bug", "number": 42}}
        rendered = _make_adapter()._render_prompt(
            template, event_body, "pull_request", "github"
        )
        assert rendered == "PR #42: Fix bug"

    def test_render_prompt_missing_key_preserved(self):
        """A placeholder with no matching payload key stays in the output verbatim."""
        rendered = _make_adapter()._render_prompt(
            "Hello {nonexistent}!", {"action": "opened"}, "push", "test"
        )
        assert "{nonexistent}" in rendered

    def test_render_prompt_no_template_dumps_json(self):
        """An empty template yields text naming the event, the route and the payload keys."""
        rendered = _make_adapter()._render_prompt("", {"key": "value"}, "push", "my-route")
        for fragment in ("push", "my-route", "key"):
            assert fragment in rendered
# ===================================================================
# Delivery extra rendering
# ===================================================================
class TestRenderDeliveryExtra:
    """Tests for WebhookAdapter._render_delivery_extra."""

    def test_render_delivery_extra_templates(self):
        """String values are rendered from the payload; other values pass through."""
        webhook = _make_adapter()
        templates = {
            "repo": "{repository.full_name}",
            "pr_number": "{number}",
            "static": 42,
        }
        event_body = {"repository": {"full_name": "org/repo"}, "number": 7}

        rendered = webhook._render_delivery_extra(templates, event_body)

        assert rendered["repo"] == "org/repo"
        assert rendered["pr_number"] == "7"
        assert rendered["static"] == 42  # non-string left as-is
# ===================================================================
# Event filtering
# ===================================================================
class TestEventFilter:
    """Tests for event type filtering in _handle_webhook."""

    @pytest.mark.asyncio
    async def test_event_filter_accepts_matching(self):
        """An event named in the route's ``events`` list is answered with 202."""
        route_table = {
            "gh": {
                "secret": _INSECURE_NO_AUTH,
                "events": ["pull_request"],
                "prompt": "PR: {action}",
            }
        }
        webhook = _make_adapter(routes=route_table)
        # Replace handle_message so no agent work runs.
        webhook.handle_message = AsyncMock()

        async with TestClient(TestServer(_create_app(webhook))) as client:
            response = await client.post(
                "/webhooks/gh",
                json={"action": "opened"},
                headers={"X-GitHub-Event": "pull_request"},
            )
            assert response.status == 202

    @pytest.mark.asyncio
    async def test_event_filter_rejects_non_matching(self):
        """An event missing from the ``events`` list gets 200 with status=ignored."""
        route_table = {
            "gh": {
                "secret": _INSECURE_NO_AUTH,
                "events": ["pull_request"],
                "prompt": "test",
            }
        }
        webhook = _make_adapter(routes=route_table)

        async with TestClient(TestServer(_create_app(webhook))) as client:
            response = await client.post(
                "/webhooks/gh",
                json={"action": "opened"},
                headers={"X-GitHub-Event": "push"},
            )
            assert response.status == 200
            reply = await response.json()
            assert reply["status"] == "ignored"

    @pytest.mark.asyncio
    async def test_event_filter_empty_allows_all(self):
        """A route without an ``events`` list accepts any event type."""
        route_table = {
            "all": {
                "secret": _INSECURE_NO_AUTH,
                "prompt": "got it",
            }
        }
        webhook = _make_adapter(routes=route_table)
        webhook.handle_message = AsyncMock()

        async with TestClient(TestServer(_create_app(webhook))) as client:
            response = await client.post(
                "/webhooks/all",
                json={"action": "any"},
                headers={"X-GitHub-Event": "whatever"},
            )
            assert response.status == 202
# ===================================================================
# HTTP handling
# ===================================================================
class TestHTTPHandling:
    """HTTP-level behaviour: routing, health endpoint, connect/disconnect."""

    @pytest.mark.asyncio
    async def test_unknown_route_returns_404(self):
        """Posting to a route name that is not configured yields 404."""
        webhook = _make_adapter(
            routes={"real": {"secret": _INSECURE_NO_AUTH, "prompt": "x"}}
        )
        async with TestClient(TestServer(_create_app(webhook))) as client:
            response = await client.post("/webhooks/nonexistent", json={"a": 1})
            assert response.status == 404

    @pytest.mark.asyncio
    async def test_webhook_handler_returns_202(self):
        """A valid post is answered 202 with status=accepted and the route name."""
        webhook = _make_adapter(
            routes={"test": {"secret": _INSECURE_NO_AUTH, "prompt": "hi"}}
        )
        webhook.handle_message = AsyncMock()
        async with TestClient(TestServer(_create_app(webhook))) as client:
            response = await client.post("/webhooks/test", json={"data": "value"})
            assert response.status == 202
            reply = await response.json()
            assert reply["status"] == "accepted"
            assert reply["route"] == "test"

    @pytest.mark.asyncio
    async def test_health_endpoint(self):
        """GET /health answers 200 with status=ok and platform=webhook."""
        webhook = _make_adapter()
        async with TestClient(TestServer(_create_app(webhook))) as client:
            response = await client.get("/health")
            assert response.status == 200
            reply = await response.json()
            assert reply["status"] == "ok"
            assert reply["platform"] == "webhook"

    @pytest.mark.asyncio
    async def test_connect_starts_server(self):
        """connect() sets up the runner, starts the site and marks the adapter connected."""
        webhook = _make_adapter(
            routes={"r1": {"secret": _INSECURE_NO_AUTH, "prompt": "x"}},
            port=0,
        )
        # AppRunner and TCPSite are mocked so no socket is actually bound; only
        # the calls made by connect() and the connected flag are checked.
        runner_patch = patch("gateway.platforms.webhook.web.AppRunner")
        site_patch = patch("gateway.platforms.webhook.web.TCPSite")
        with runner_patch as runner_cls, site_patch as site_cls:
            runner_instance = AsyncMock()
            runner_cls.return_value = runner_instance
            site_instance = AsyncMock()
            site_cls.return_value = site_instance

            connected = await webhook.connect()

            assert connected is True
            assert webhook.is_connected
            runner_instance.setup.assert_awaited_once()
            site_instance.start.assert_awaited_once()

        await webhook.disconnect()

    @pytest.mark.asyncio
    async def test_disconnect_cleans_up(self):
        """disconnect() awaits runner cleanup, clears the runner and the connected flag."""
        webhook = _make_adapter()
        # Pretend a runner had been set up earlier.
        previous_runner = AsyncMock()
        webhook._runner = previous_runner
        webhook._running = True

        await webhook.disconnect()

        previous_runner.cleanup.assert_awaited_once()
        assert webhook._runner is None
        assert not webhook.is_connected
# ===================================================================
# Idempotency
# ===================================================================
class TestIdempotency:
    """Duplicate-delivery detection keyed on the X-GitHub-Delivery header."""

    @pytest.mark.asyncio
    async def test_duplicate_delivery_id_returns_200(self):
        """Repeating a delivery ID gets 200 with status=duplicate the second time."""
        webhook = _make_adapter(
            routes={"idem": {"secret": _INSECURE_NO_AUTH, "prompt": "test"}}
        )
        webhook.handle_message = AsyncMock()
        delivery = {"X-GitHub-Delivery": "delivery-123"}

        async with TestClient(TestServer(_create_app(webhook))) as client:
            first = await client.post("/webhooks/idem", json={"a": 1}, headers=delivery)
            assert first.status == 202
            second = await client.post("/webhooks/idem", json={"a": 1}, headers=delivery)
            assert second.status == 200
            reply = await second.json()
            assert reply["status"] == "duplicate"

    @pytest.mark.asyncio
    async def test_expired_delivery_id_allows_reprocess(self):
        """A delivery ID whose cache entry is older than the TTL is accepted again."""
        webhook = _make_adapter(
            routes={"idem": {"secret": _INSECURE_NO_AUTH, "prompt": "test"}}
        )
        webhook._idempotency_ttl = 1  # 1 second TTL for test speed
        webhook.handle_message = AsyncMock()
        delivery = {"X-GitHub-Delivery": "delivery-456"}

        async with TestClient(TestServer(_create_app(webhook))) as client:
            first = await client.post("/webhooks/idem", json={"x": 1}, headers=delivery)
            assert first.status == 202
            # Backdate the cache entry so it appears expired
            webhook._seen_deliveries["delivery-456"] = time.time() - 3700
            second = await client.post("/webhooks/idem", json={"x": 1}, headers=delivery)
            assert second.status == 202  # re-accepted
# ===================================================================
# Rate limiting
# ===================================================================
class TestRateLimiting:
    """Per-route rate limiting."""

    @pytest.mark.asyncio
    async def test_rate_limit_rejects_excess(self):
        """With rate_limit=2, the third request on the route is answered 429."""
        webhook = _make_adapter(
            routes={"limited": {"secret": _INSECURE_NO_AUTH, "prompt": "test"}},
            rate_limit=2,
        )
        webhook.handle_message = AsyncMock()

        async with TestClient(TestServer(_create_app(webhook))) as client:
            # Two requests within limit
            for i in range(2):
                accepted = await client.post(
                    "/webhooks/limited",
                    json={"n": i},
                    headers={"X-GitHub-Delivery": f"d-{i}"},
                )
                assert accepted.status == 202, f"Request {i} should be accepted"
            # Third request should be rate-limited
            rejected = await client.post(
                "/webhooks/limited",
                json={"n": 99},
                headers={"X-GitHub-Delivery": "d-99"},
            )
            assert rejected.status == 429

    @pytest.mark.asyncio
    async def test_rate_limit_window_resets(self):
        """Once recorded timestamps are old enough, the route accepts requests again."""
        webhook = _make_adapter(
            routes={"limited": {"secret": _INSECURE_NO_AUTH, "prompt": "test"}},
            rate_limit=1,
        )
        webhook.handle_message = AsyncMock()

        async with TestClient(TestServer(_create_app(webhook))) as client:
            first = await client.post(
                "/webhooks/limited",
                json={"n": 1},
                headers={"X-GitHub-Delivery": "d-a"},
            )
            assert first.status == 202
            # Backdate all rate-limit timestamps to > 60 seconds ago
            webhook._rate_counts["limited"] = [time.time() - 120]
            second = await client.post(
                "/webhooks/limited",
                json={"n": 2},
                headers={"X-GitHub-Delivery": "d-b"},
            )
            assert second.status == 202  # allowed again
# ===================================================================
# Body size limit
# ===================================================================
class TestBodySize:
    """Request body size limit."""

    @pytest.mark.asyncio
    async def test_oversized_payload_rejected(self):
        """A request declaring a Content-Length above max_body_bytes is answered 413."""
        webhook = _make_adapter(
            routes={"big": {"secret": _INSECURE_NO_AUTH, "prompt": "test"}},
            max_body_bytes=100,
        )
        oversized = {"data": "x" * 200}

        async with TestClient(TestServer(_create_app(webhook))) as client:
            response = await client.post(
                "/webhooks/big",
                json=oversized,
                headers={"Content-Length": "999999"},
            )
            assert response.status == 413
# ===================================================================
# INSECURE_NO_AUTH
# ===================================================================
class TestInsecureNoAuth:
    """Behaviour of the _INSECURE_NO_AUTH sentinel secret."""

    @pytest.mark.asyncio
    async def test_insecure_no_auth_skips_validation(self):
        """A route whose secret is _INSECURE_NO_AUTH accepts an unsigned request."""
        webhook = _make_adapter(
            routes={"open": {"secret": _INSECURE_NO_AUTH, "prompt": "hello"}}
        )
        webhook.handle_message = AsyncMock()

        async with TestClient(TestServer(_create_app(webhook))) as client:
            # No signature header at all — should still be accepted
            response = await client.post("/webhooks/open", json={"test": True})
            assert response.status == 202
# ===================================================================
# Session isolation
# ===================================================================
class TestSessionIsolation:
    """Every webhook delivery is handled under its own session key."""

    @pytest.mark.asyncio
    async def test_concurrent_webhooks_get_independent_sessions(self):
        """Two deliveries on one route produce two distinct chat_ids."""
        webhook = _make_adapter(
            routes={"ci": {"secret": _INSECURE_NO_AUTH, "prompt": "build"}}
        )
        seen = []

        async def _record(event):
            seen.append(event)

        webhook.handle_message = _record
        deliveries = (
            ({"ref": "main"}, "aaa-111"),
            ({"ref": "dev"}, "bbb-222"),
        )

        async with TestClient(TestServer(_create_app(webhook))) as client:
            for body, delivery_id in deliveries:
                response = await client.post(
                    "/webhooks/ci",
                    json=body,
                    headers={"X-GitHub-Delivery": delivery_id},
                )
                assert response.status == 202
            # Yield to the loop so the handler tasks get a chance to run.
            await asyncio.sleep(0.05)

        assert len(seen) == 2
        chat_ids = {event.source.chat_id for event in seen}
        assert len(chat_ids) == 2, "Each delivery must have a unique session chat_id"
# ===================================================================
# Delivery info cleanup
# ===================================================================
class TestDeliveryCleanup:
    """send() must not leave entries behind in ``_delivery_info``."""

    @pytest.mark.asyncio
    async def test_delivery_info_cleaned_after_send(self):
        """After a successful send() the chat's delivery_info entry is gone."""
        webhook = _make_adapter()
        session_id = "webhook:test:d-xyz"
        pending = {"deliver": "log", "deliver_extra": {}, "payload": {"x": 1}}
        webhook._delivery_info[session_id] = pending

        outcome = await webhook.send(session_id, "Agent response here")

        assert outcome.success is True
        assert session_id not in webhook._delivery_info
# ===================================================================
# check_webhook_requirements
# ===================================================================
class TestCheckRequirements:
    """check_webhook_requirements() follows the AIOHTTP_AVAILABLE flag."""

    def test_returns_true_when_aiohttp_available(self):
        """Unpatched, the requirements check reports True."""
        assert check_webhook_requirements() is True

    def test_returns_false_without_aiohttp(self):
        """With AIOHTTP_AVAILABLE forced to False the check reports False."""
        with patch("gateway.platforms.webhook.AIOHTTP_AVAILABLE", False):
            available = check_webhook_requirements()
        assert available is False

View file

@ -0,0 +1,337 @@
"""Integration tests for the generic webhook platform adapter.
These tests exercise end-to-end flows through the webhook adapter:
1. GitHub PR webhook → agent MessageEvent created
2. Skills config injects skill content into the prompt
3. Cross-platform delivery routes to a mock Telegram adapter
4. GitHub comment delivery invokes ``gh`` CLI (mocked subprocess)
"""
import asyncio
import hashlib
import hmac
import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from aiohttp import web
from aiohttp.test_utils import TestClient, TestServer
from gateway.config import (
GatewayConfig,
HomeChannel,
Platform,
PlatformConfig,
)
from gateway.platforms.base import MessageEvent, MessageType, SendResult
from gateway.platforms.webhook import WebhookAdapter, _INSECURE_NO_AUTH
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_adapter(routes, **extra_kw) -> WebhookAdapter:
    """Build a WebhookAdapter serving *routes*.

    Keyword arguments are merged into the platform ``extra`` mapping and win
    over the default ``host``/``port``/``routes`` entries on a key clash.
    """
    settings = {"host": "0.0.0.0", "port": 0, "routes": routes, **extra_kw}
    return WebhookAdapter(PlatformConfig(enabled=True, extra=settings))
def _create_app(adapter: WebhookAdapter) -> web.Application:
    """Return an aiohttp Application routed to the adapter's two handlers."""
    application = web.Application()
    router = application.router
    router.add_get("/health", adapter._handle_health)
    router.add_post("/webhooks/{route_name}", adapter._handle_webhook)
    return application
def _github_signature(body: bytes, secret: str) -> str:
"""Compute X-Hub-Signature-256 for *body* using *secret*."""
return "sha256=" + hmac.new(
secret.encode(), body, hashlib.sha256
).hexdigest()
# Trimmed-down sample of a GitHub ``pull_request`` ("opened") event body,
# shared by the tests below as the webhook POST payload.
GITHUB_PR_PAYLOAD = {
    "action": "opened",
    "number": 42,
    "pull_request": {
        "title": "Add webhook adapter",
        "body": "This PR adds a generic webhook platform adapter.",
        "html_url": "https://github.com/org/repo/pull/42",
        "user": {"login": "contributor"},
        "head": {"ref": "feature/webhooks"},
        "base": {"ref": "main"},
    },
    "repository": {
        "full_name": "org/repo",
        "html_url": "https://github.com/org/repo",
    },
    "sender": {"login": "contributor"},
}
# ===================================================================
# Test 1: GitHub PR webhook triggers agent
# ===================================================================
class TestGitHubPRWebhook:
    """A signed GitHub pull_request delivery is turned into a MessageEvent."""

    @pytest.mark.asyncio
    async def test_github_pr_webhook_triggers_agent(self):
        """A signed POST of the sample PR payload is accepted and dispatched.

        Verifies that:
        1. the response is 202 with the expected JSON fields,
        2. handle_message receives exactly one MessageEvent,
        3. the event text contains the rendered prompt,
        4. the event source carries chat_type 'webhook'.
        """
        shared_secret = "gh-webhook-test-secret"
        prompt_template = (
            "Review PR #{number} by {sender.login}: "
            "{pull_request.title}\n\n{pull_request.body}"
        )
        adapter = _make_adapter(
            {
                "github-pr": {
                    "secret": shared_secret,
                    "events": ["pull_request"],
                    "prompt": prompt_template,
                    "deliver": "log",
                }
            }
        )
        received: list[MessageEvent] = []

        async def _record(event: MessageEvent):
            received.append(event)

        adapter.handle_message = _record

        # Sign the exact bytes that are sent so the signature matches.
        raw_body = json.dumps(GITHUB_PR_PAYLOAD).encode()
        request_headers = {
            "Content-Type": "application/json",
            "X-GitHub-Event": "pull_request",
            "X-Hub-Signature-256": _github_signature(raw_body, shared_secret),
            "X-GitHub-Delivery": "gh-delivery-001",
        }

        async with TestClient(TestServer(_create_app(adapter))) as client:
            response = await client.post(
                "/webhooks/github-pr", data=raw_body, headers=request_headers
            )
            assert response.status == 202
            reply = await response.json()
            assert reply["status"] == "accepted"
            assert reply["route"] == "github-pr"
            assert reply["event"] == "pull_request"
            assert reply["delivery_id"] == "gh-delivery-001"
            # Yield to the loop so the scheduled handler task can run.
            await asyncio.sleep(0.05)

        assert len(received) == 1
        event = received[0]
        assert "Review PR #42 by contributor" in event.text
        assert "Add webhook adapter" in event.text
        source = event.source
        assert source.chat_type == "webhook"
        assert source.platform == Platform.WEBHOOK
        assert "github-pr" in source.chat_id
        assert event.message_id == "gh-delivery-001"
# ===================================================================
# Test 2: Skills injected into prompt
# ===================================================================
class TestSkillsInjection:
    """Routes that list skills use the skill invocation message as prompt."""

    @pytest.mark.asyncio
    async def test_skills_injected_into_prompt(self):
        """With skills: [code-review] the event text is the skill message.

        build_skill_invocation_message() is patched to return a known string;
        the test checks that string (not the raw template render) reaches the
        MessageEvent and that the builder was called exactly once.
        """
        adapter = _make_adapter(
            {
                "pr-review": {
                    "secret": _INSECURE_NO_AUTH,
                    "events": ["pull_request"],
                    "prompt": "Review this PR: {pull_request.title}",
                    "skills": ["code-review"],
                }
            }
        )
        received: list[MessageEvent] = []

        async def _record(event: MessageEvent):
            received.append(event)

        adapter.handle_message = _record

        skill_text = (
            "You are a code reviewer. Review the following:\n"
            "Review this PR: Add webhook adapter"
        )
        # The handler imports these lazily, so patch them on the source module.
        build_patch = patch(
            "agent.skill_commands.build_skill_invocation_message",
            return_value=skill_text,
        )
        commands_patch = patch(
            "agent.skill_commands.get_skill_commands",
            return_value={"/code-review": {"name": "code-review"}},
        )
        request_headers = {
            "X-GitHub-Event": "pull_request",
            "X-GitHub-Delivery": "skill-test-001",
        }

        with build_patch as build_mock, commands_patch:
            async with TestClient(TestServer(_create_app(adapter))) as client:
                response = await client.post(
                    "/webhooks/pr-review",
                    json=GITHUB_PR_PAYLOAD,
                    headers=request_headers,
                )
                assert response.status == 202
                await asyncio.sleep(0.05)

        assert len(received) == 1
        # The prompt must be the skill content rather than the raw template.
        assert "You are a code reviewer" in received[0].text
        build_mock.assert_called_once()
# ===================================================================
# Test 3: Cross-platform delivery (webhook → Telegram)
# ===================================================================
class TestCrossPlatformDelivery:
    """Webhook replies can be forwarded to another platform adapter."""

    @pytest.mark.asyncio
    async def test_cross_platform_delivery(self):
        """With deliver='telegram', send() forwards to the Telegram adapter.

        The Telegram adapter is a mock registered in the ``adapters`` mapping
        of a mock gateway runner attached to the webhook adapter.
        """
        adapter = _make_adapter(
            {
                "alerts": {
                    "secret": _INSECURE_NO_AUTH,
                    "prompt": "Alert: {message}",
                    "deliver": "telegram",
                    "deliver_extra": {"chat_id": "12345"},
                }
            }
        )
        adapter.handle_message = AsyncMock()

        # Mock gateway runner exposing a mock Telegram adapter.
        telegram = AsyncMock()
        telegram.send = AsyncMock(return_value=SendResult(success=True))
        runner = MagicMock()
        runner.adapters = {Platform.TELEGRAM: telegram}
        runner.config = GatewayConfig(
            platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="fake")}
        )
        adapter.gateway_runner = runner

        # A webhook POST populates the delivery info for this session.
        async with TestClient(TestServer(_create_app(adapter))) as client:
            response = await client.post(
                "/webhooks/alerts",
                json={"message": "Server is on fire!"},
                headers={"X-GitHub-Delivery": "alert-001"},
            )
            assert response.status == 202

        session_id = "webhook:alerts:alert-001"
        assert session_id in adapter._delivery_info

        # Call send() the way it would be called once the agent has finished.
        reply_text = "I've acknowledged the alert."
        outcome = await adapter.send(session_id, reply_text)
        assert outcome.success is True
        telegram.send.assert_awaited_once_with("12345", reply_text)

        # The entry must have been removed again.
        assert session_id not in adapter._delivery_info
# ===================================================================
# Test 4: GitHub comment delivery via gh CLI
# ===================================================================
class TestGitHubCommentDelivery:
    """Replies with deliver='github_comment' are posted through the gh CLI."""

    @pytest.mark.asyncio
    async def test_github_comment_delivery(self):
        """send() runs ``gh pr comment`` via a mocked subprocess.run.

        Also checks that the ``deliver_extra`` templates were rendered from
        the webhook payload before being stored.
        """
        adapter = _make_adapter(
            {
                "pr-bot": {
                    "secret": _INSECURE_NO_AUTH,
                    "prompt": "Review: {pull_request.title}",
                    "deliver": "github_comment",
                    "deliver_extra": {
                        "repo": "{repository.full_name}",
                        "pr_number": "{number}",
                    },
                }
            }
        )
        adapter.handle_message = AsyncMock()

        # A webhook POST populates the delivery info for this session.
        request_headers = {
            "X-GitHub-Event": "pull_request",
            "X-GitHub-Delivery": "gh-comment-001",
        }
        async with TestClient(TestServer(_create_app(adapter))) as client:
            response = await client.post(
                "/webhooks/pr-bot", json=GITHUB_PR_PAYLOAD, headers=request_headers
            )
            assert response.status == 202

        session_id = "webhook:pr-bot:gh-comment-001"
        assert session_id in adapter._delivery_info

        # deliver_extra must hold values rendered from the payload.
        rendered_extra = adapter._delivery_info[session_id]["deliver_extra"]
        assert rendered_extra["repo"] == "org/repo"
        assert rendered_extra["pr_number"] == "42"

        # Replace subprocess.run with a mock reporting success, then send.
        completed = MagicMock(returncode=0, stdout="Comment posted", stderr="")
        comment_body = "LGTM! The code looks great."
        with patch(
            "gateway.platforms.webhook.subprocess.run",
            return_value=completed,
        ) as run_mock:
            outcome = await adapter.send(session_id, comment_body)

        assert outcome.success is True
        expected_command = [
            "gh", "pr", "comment", "42",
            "--repo", "org/repo",
            "--body", comment_body,
        ]
        run_mock.assert_called_once_with(
            expected_command,
            capture_output=True,
            text=True,
            timeout=30,
        )
        # The entry must have been removed again.
        assert session_id not in adapter._delivery_info

View file

@ -0,0 +1,419 @@
"""Tests for WhatsApp connect() error handling.
Regression tests for two bugs in WhatsAppAdapter.connect():
1. Uninitialized ``data`` variable: when ``resp.json()`` raised after the
health endpoint returned HTTP 200, ``http_ready`` was set to True but
``data`` was never assigned. The subsequent ``data.get("status")``
check raised ``NameError``.
2. Bridge log file handle leaked on error paths: the file was opened before
the health-check loop but never closed when ``connect()`` returned False.
Repeated connection failures accumulated open file descriptors.
"""
import asyncio
from contextlib import ExitStack
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from gateway.config import Platform
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
class _AsyncCM:
"""Minimal async context manager returning a fixed value."""
def __init__(self, value):
self.value = value
async def __aenter__(self):
return self.value
async def __aexit__(self, *exc):
return False
def _make_adapter():
    """Return a WhatsAppAdapter built via ``__new__`` with test attributes.

    ``__init__`` is bypassed; the attributes below are set by hand instead.
    """
    from gateway.platforms.whatsapp import WhatsAppAdapter

    initial_state = {
        "platform": Platform.WHATSAPP,
        "config": MagicMock(),
        "_bridge_port": 19876,
        "_bridge_script": "/tmp/test-bridge.js",
        "_session_path": Path("/tmp/test-wa-session"),
        "_bridge_log_fh": None,
        "_bridge_log": None,
        "_bridge_process": None,
        "_reply_prefix": None,
        "_running": False,
        "_message_handler": None,
        "_fatal_error_code": None,
        "_fatal_error_message": None,
        "_fatal_error_retryable": True,
        "_fatal_error_handler": None,
        "_active_sessions": {},
        "_pending_messages": {},
        "_background_tasks": set(),
        "_auto_tts_disabled_chats": set(),
        "_message_queue": asyncio.Queue(),
    }
    adapter = WhatsAppAdapter.__new__(WhatsAppAdapter)
    for attr_name, attr_value in initial_state.items():
        setattr(adapter, attr_name, attr_value)
    return adapter
def _mock_aiohttp(status=200, json_data=None, json_side_effect=None):
    """Return a stand-in for ``aiohttp.ClientSession`` with a canned response.

    Calling the returned mock yields an async context manager whose session's
    ``get()`` yields another async context manager wrapping the response.
    The response has ``status`` set to *status*; its ``json()`` either raises
    *json_side_effect* or returns *json_data* (``{}`` when falsy).
    """
    if json_side_effect:
        json_mock = AsyncMock(side_effect=json_side_effect)
    else:
        json_mock = AsyncMock(return_value=json_data or {})

    response = MagicMock()
    response.status = status
    response.json = json_mock

    session = MagicMock()
    session.get = MagicMock(return_value=_AsyncCM(response))
    return MagicMock(return_value=_AsyncCM(session))
def _connect_patches(mock_proc, mock_fh, mock_client_cls=None):
"""Return a dict of common patches needed to reach the health-check loop."""
patches = {
"gateway.platforms.whatsapp.check_whatsapp_requirements": True,
"gateway.platforms.whatsapp.asyncio.create_task": MagicMock(),
}
base = [
patch("gateway.platforms.whatsapp.check_whatsapp_requirements", return_value=True),
patch.object(Path, "exists", return_value=True),
patch.object(Path, "mkdir", return_value=None),
patch("subprocess.run", return_value=MagicMock(returncode=0)),
patch("subprocess.Popen", return_value=mock_proc),
patch("builtins.open", return_value=mock_fh),
patch("gateway.platforms.whatsapp.asyncio.sleep", new_callable=AsyncMock),
patch("gateway.platforms.whatsapp.asyncio.create_task"),
]
if mock_client_cls is not None:
base.append(patch("aiohttp.ClientSession", mock_client_cls))
return base
# ---------------------------------------------------------------------------
# _close_bridge_log() unit tests
# ---------------------------------------------------------------------------
class TestCloseBridgeLog:
    """Unit tests that call _close_bridge_log() directly."""

    @staticmethod
    def _bare_adapter():
        """Return an adapter made via ``__new__`` with no log handle set."""
        from gateway.platforms.whatsapp import WhatsAppAdapter

        bare = WhatsAppAdapter.__new__(WhatsAppAdapter)
        bare._bridge_log_fh = None
        return bare

    def test_closes_open_handle(self):
        """An attached handle is closed once and the attribute reset to None."""
        log_handle = MagicMock()
        adapter = self._bare_adapter()
        adapter._bridge_log_fh = log_handle

        adapter._close_bridge_log()

        log_handle.close.assert_called_once()
        assert adapter._bridge_log_fh is None

    def test_noop_when_no_handle(self):
        """Calling with no handle attached must not raise."""
        adapter = self._bare_adapter()

        adapter._close_bridge_log()  # must not raise

        assert adapter._bridge_log_fh is None

    def test_suppresses_close_exception(self):
        """An OSError from close() is swallowed and the attribute still reset."""
        log_handle = MagicMock()
        log_handle.close.side_effect = OSError("already closed")
        adapter = self._bare_adapter()
        adapter._bridge_log_fh = log_handle

        adapter._close_bridge_log()  # must not raise

        assert adapter._bridge_log_fh is None
# ---------------------------------------------------------------------------
# data variable initialization
# ---------------------------------------------------------------------------
class TestDataInitialized:
    """Verify ``data = {}`` prevents NameError when resp.json() fails."""

    @pytest.mark.asyncio
    async def test_no_name_error_when_json_always_fails(self):
        """HTTP 200 sets http_ready but json() always raises.

        Without the fix, ``data`` was never assigned and the Phase 2 check
        ``data.get("status")`` raised NameError. With ``data = {}``, the
        check evaluates to ``None != "connected"`` and Phase 2 runs normally.
        """
        adapter = _make_adapter()
        mock_proc = MagicMock()
        mock_proc.poll.return_value = None  # bridge stays alive
        mock_client_cls = _mock_aiohttp(
            status=200, json_side_effect=ValueError("bad json"),
        )
        mock_fh = MagicMock()

        # ExitStack enters however many patchers the helper returns, so this
        # test does not depend on the exact length of that list.
        with ExitStack() as stack:
            for patcher in _connect_patches(mock_proc, mock_fh, mock_client_cls):
                stack.enter_context(patcher)
            stack.enter_context(
                patch.object(type(adapter), "_poll_messages", return_value=MagicMock())
            )
            # Must NOT raise NameError
            result = await adapter.connect()

        # connect() returns True (warn-and-proceed path)
        assert result is True
        assert adapter._running is True
# ---------------------------------------------------------------------------
# File handle cleanup on error paths
# ---------------------------------------------------------------------------
class TestFileHandleClosedOnError:
    """Verify the bridge log file handle is closed on every failure path.

    All connect() failure-path tests live here; the runtime bridge-death
    tests are grouped separately in ``TestBridgeRuntimeFailure``.
    """

    @staticmethod
    async def _connect_under(adapter, patchers):
        """Await ``adapter.connect()`` with every patcher in *patchers* active.

        Patchers are entered in list order and all exited before returning.
        """
        with ExitStack() as stack:
            for patcher in patchers:
                stack.enter_context(patcher)
            return await adapter.connect()

    @pytest.mark.asyncio
    async def test_closed_when_bridge_dies_phase1(self):
        """Bridge process exits during Phase 1 health-check loop."""
        adapter = _make_adapter()
        mock_proc = MagicMock()
        mock_proc.poll.return_value = 1  # dead immediately
        mock_proc.returncode = 1
        mock_fh = MagicMock()

        result = await self._connect_under(
            adapter, _connect_patches(mock_proc, mock_fh)
        )

        assert result is False
        mock_fh.close.assert_called_once()
        assert adapter._bridge_log_fh is None

    @pytest.mark.asyncio
    async def test_closed_when_http_not_ready(self):
        """Health endpoint never returns 200 within 15 attempts."""
        adapter = _make_adapter()
        mock_proc = MagicMock()
        mock_proc.poll.return_value = None  # bridge alive
        mock_client_cls = _mock_aiohttp(status=503)
        mock_fh = MagicMock()

        result = await self._connect_under(
            adapter, _connect_patches(mock_proc, mock_fh, mock_client_cls)
        )

        assert result is False
        mock_fh.close.assert_called_once()
        assert adapter._bridge_log_fh is None

    @pytest.mark.asyncio
    async def test_closed_when_bridge_dies_phase2(self):
        """Bridge alive during Phase 1 but dies during Phase 2."""
        adapter = _make_adapter()
        # Phase 1 (15 iterations): alive. Phase 2 (iteration 16): dead.
        poll_calls = 0

        def poll_side_effect():
            nonlocal poll_calls
            poll_calls += 1
            return None if poll_calls <= 15 else 1

        mock_proc = MagicMock()
        mock_proc.poll.side_effect = poll_side_effect
        mock_proc.returncode = 1
        # Health returns 200 with status != "connected" -> triggers Phase 2
        mock_client_cls = _mock_aiohttp(
            status=200, json_data={"status": "disconnected"},
        )
        mock_fh = MagicMock()

        result = await self._connect_under(
            adapter, _connect_patches(mock_proc, mock_fh, mock_client_cls)
        )

        assert result is False
        mock_fh.close.assert_called_once()
        assert adapter._bridge_log_fh is None

    @pytest.mark.asyncio
    async def test_closed_on_unexpected_exception(self):
        """Popen raises, outer except block must still close the handle."""
        adapter = _make_adapter()
        mock_fh = MagicMock()
        # Built by hand rather than via _connect_patches: Popen must raise
        # here, and the sleep/create_task patches are not needed.
        patchers = [
            patch("gateway.platforms.whatsapp.check_whatsapp_requirements", return_value=True),
            patch.object(Path, "exists", return_value=True),
            patch.object(Path, "mkdir", return_value=None),
            patch("subprocess.run", return_value=MagicMock(returncode=0)),
            patch("subprocess.Popen", side_effect=OSError("spawn failed")),
            patch("builtins.open", return_value=mock_fh),
        ]

        result = await self._connect_under(adapter, patchers)

        assert result is False
        mock_fh.close.assert_called_once()
        assert adapter._bridge_log_fh is None


class TestBridgeRuntimeFailure:
    """Verify runtime bridge death is surfaced as a fatal adapter error."""

    @pytest.mark.asyncio
    async def test_send_marks_retryable_fatal_when_managed_bridge_exits(self):
        """send() on a dead bridge fails and records a retryable fatal error."""
        adapter = _make_adapter()
        fatal_handler = AsyncMock()
        adapter.set_fatal_error_handler(fatal_handler)
        adapter._running = True
        mock_fh = MagicMock()
        adapter._bridge_log_fh = mock_fh
        mock_proc = MagicMock()
        mock_proc.poll.return_value = 7  # non-None: process has exited
        adapter._bridge_process = mock_proc

        result = await adapter.send("chat-123", "hello")

        assert result.success is False
        assert "exited unexpectedly" in result.error
        assert adapter.fatal_error_code == "whatsapp_bridge_exited"
        assert adapter.fatal_error_retryable is True
        fatal_handler.assert_awaited_once()
        mock_fh.close.assert_called_once()
        assert adapter._bridge_log_fh is None

    @pytest.mark.asyncio
    async def test_poll_messages_marks_retryable_fatal_when_managed_bridge_exits(self):
        """_poll_messages() on a dead bridge records a retryable fatal error."""
        adapter = _make_adapter()
        fatal_handler = AsyncMock()
        adapter.set_fatal_error_handler(fatal_handler)
        adapter._running = True
        mock_fh = MagicMock()
        adapter._bridge_log_fh = mock_fh
        mock_proc = MagicMock()
        mock_proc.poll.return_value = 23  # non-None: process has exited
        adapter._bridge_process = mock_proc

        await adapter._poll_messages()

        assert adapter.fatal_error_code == "whatsapp_bridge_exited"
        assert adapter.fatal_error_retryable is True
        fatal_handler.assert_awaited_once()
        mock_fh.close.assert_called_once()
        assert adapter._bridge_log_fh is None
# ---------------------------------------------------------------------------
# _kill_port_process() cross-platform tests
# ---------------------------------------------------------------------------
class TestKillPortProcess:
    """_kill_port_process picks its commands based on the platform flag."""

    def test_uses_netstat_and_taskkill_on_windows(self):
        """On Windows, netstat is run and the matching PID is taskkill'ed."""
        from gateway.platforms.whatsapp import _kill_port_process

        netstat_output = (
            " Proto Local Address Foreign Address State PID\n"
            " TCP 0.0.0.0:3000 0.0.0.0:0 LISTENING 12345\n"
            " TCP 0.0.0.0:3001 0.0.0.0:0 LISTENING 99999\n"
        )
        canned = {
            "netstat": MagicMock(stdout=netstat_output),
            "taskkill": MagicMock(),
        }

        def run_side_effect(cmd, **kwargs):
            program = cmd[0]
            if program in canned:
                return canned[program]
            return MagicMock()

        with patch("gateway.platforms.whatsapp._IS_WINDOWS", True), \
                patch("gateway.platforms.whatsapp.subprocess.run", side_effect=run_side_effect) as mock_run:
            _kill_port_process(3000)

        commands = [call.args[0] for call in mock_run.call_args_list]
        # netstat called
        assert any(command[0] == "netstat" for command in commands)
        # taskkill called with correct PID
        assert ["taskkill", "/PID", "12345", "/F"] in commands

    def test_does_not_kill_wrong_port_on_windows(self):
        """A listener on port 30000 must not match a request for port 3000."""
        from gateway.platforms.whatsapp import _kill_port_process

        netstat_output = (
            " TCP 0.0.0.0:30000 0.0.0.0:0 LISTENING 55555\n"
        )
        netstat_result = MagicMock(stdout=netstat_output)

        with patch("gateway.platforms.whatsapp._IS_WINDOWS", True), \
                patch("gateway.platforms.whatsapp.subprocess.run", return_value=netstat_result) as mock_run:
            _kill_port_process(3000)

        # Should NOT call taskkill because port 30000 != 3000
        programs = [call.args[0][0] for call in mock_run.call_args_list]
        assert "taskkill" not in programs

    def test_uses_fuser_on_linux(self):
        """Off Windows, fuser checks the port and then kills with -k."""
        from gateway.platforms.whatsapp import _kill_port_process

        port_in_use = MagicMock(returncode=0)

        with patch("gateway.platforms.whatsapp._IS_WINDOWS", False), \
                patch("gateway.platforms.whatsapp.subprocess.run", return_value=port_in_use) as mock_run:
            _kill_port_process(3000)

        commands = [call.args[0] for call in mock_run.call_args_list]
        assert ["fuser", "3000/tcp"] in commands
        assert ["fuser", "-k", "3000/tcp"] in commands

    def test_skips_fuser_kill_when_port_free(self):
        """When the fuser check returns 1, no ``fuser -k`` is issued."""
        from gateway.platforms.whatsapp import _kill_port_process

        port_free = MagicMock(returncode=1)  # port not in use

        with patch("gateway.platforms.whatsapp._IS_WINDOWS", False), \
                patch("gateway.platforms.whatsapp.subprocess.run", return_value=port_free) as mock_run:
            _kill_port_process(3000)

        commands = [call.args[0] for call in mock_run.call_args_list]
        assert ["fuser", "3000/tcp"] in commands
        assert ["fuser", "-k", "3000/tcp"] not in commands

    def test_suppresses_exceptions(self):
        """An OSError from subprocess.run must not escape the helper."""
        from gateway.platforms.whatsapp import _kill_port_process

        failing_run = patch(
            "gateway.platforms.whatsapp.subprocess.run",
            side_effect=OSError("no netstat"),
        )
        with patch("gateway.platforms.whatsapp._IS_WINDOWS", True), failing_run:
            _kill_port_process(3000)  # must not raise

View file

@ -0,0 +1,121 @@
"""Tests for WhatsApp reply_prefix config.yaml support.
Covers:
- config.yaml whatsapp.reply_prefix bridging into PlatformConfig.extra
- WhatsAppAdapter reading reply_prefix from config.extra
- Bridge subprocess receiving WHATSAPP_REPLY_PREFIX env var
- Config version covers all ENV_VARS_BY_VERSION keys (regression guard)
"""
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from gateway.config import GatewayConfig, Platform, PlatformConfig
# ---------------------------------------------------------------------------
# Config bridging from config.yaml
# ---------------------------------------------------------------------------
class TestConfigYamlBridging:
    """Test that whatsapp.reply_prefix in config.yaml flows into PlatformConfig."""

    @staticmethod
    def _load_whatsapp_config(tmp_path, yaml_text):
        """Write *yaml_text* to ``config.yaml`` in *tmp_path* and load it.

        Returns the ``Platform.WHATSAPP`` entry of the loaded gateway config's
        ``platforms`` mapping, or ``None`` when that entry is absent.
        """
        (tmp_path / "config.yaml").write_text(yaml_text)
        with patch("gateway.config.get_hermes_home", return_value=tmp_path):
            from gateway.config import load_gateway_config

            # Need to also patch WHATSAPP_ENABLED so the platform exists
            with patch.dict("os.environ", {"WHATSAPP_ENABLED": "true"}, clear=False):
                config = load_gateway_config()
        return config.platforms.get(Platform.WHATSAPP)

    def test_reply_prefix_bridged_from_yaml(self, tmp_path):
        """whatsapp.reply_prefix in config.yaml sets PlatformConfig.extra."""
        wa_config = self._load_whatsapp_config(
            tmp_path, 'whatsapp:\n reply_prefix: "Custom Bot"\n'
        )
        assert wa_config is not None
        assert wa_config.extra.get("reply_prefix") == "Custom Bot"

    def test_empty_reply_prefix_bridged(self, tmp_path):
        """Empty string reply_prefix disables the header."""
        wa_config = self._load_whatsapp_config(
            tmp_path, 'whatsapp:\n reply_prefix: ""\n'
        )
        assert wa_config is not None
        assert wa_config.extra.get("reply_prefix") == ""

    def test_no_whatsapp_section_no_extra(self, tmp_path):
        """Without whatsapp section, no reply_prefix is set."""
        wa_config = self._load_whatsapp_config(tmp_path, "timezone: UTC\n")
        assert wa_config is not None
        assert "reply_prefix" not in wa_config.extra

    def test_whatsapp_section_without_reply_prefix(self, tmp_path):
        """whatsapp section present but without reply_prefix key."""
        wa_config = self._load_whatsapp_config(
            tmp_path, "whatsapp:\n other_setting: true\n"
        )
        # Guard like the sibling tests so a missing platform fails with a
        # clear assertion instead of an AttributeError on ``.extra``.
        assert wa_config is not None
        assert "reply_prefix" not in wa_config.extra
# ---------------------------------------------------------------------------
# WhatsAppAdapter __init__
# ---------------------------------------------------------------------------
class TestAdapterInit:
    """WhatsAppAdapter takes ``_reply_prefix`` from ``config.extra``."""

    def test_reply_prefix_from_extra(self):
        """A prefix given in extra is stored unchanged on the adapter."""
        from gateway.platforms.whatsapp import WhatsAppAdapter

        platform_config = PlatformConfig(enabled=True, extra={"reply_prefix": "Bot\\n"})
        assert WhatsAppAdapter(platform_config)._reply_prefix == "Bot\\n"

    def test_reply_prefix_default_none(self):
        """Without a reply_prefix in the config the attribute is None."""
        from gateway.platforms.whatsapp import WhatsAppAdapter

        platform_config = PlatformConfig(enabled=True)
        assert WhatsAppAdapter(platform_config)._reply_prefix is None

    def test_reply_prefix_empty_string(self):
        """An empty-string prefix is kept as "" rather than turned into None."""
        from gateway.platforms.whatsapp import WhatsAppAdapter

        platform_config = PlatformConfig(enabled=True, extra={"reply_prefix": ""})
        assert WhatsAppAdapter(platform_config)._reply_prefix == ""
# ---------------------------------------------------------------------------
# Config version regression guard
# ---------------------------------------------------------------------------
class TestConfigVersionCoverage:
    """Regression guard tying _config_version to ENV_VARS_BY_VERSION."""

    def test_default_config_version_covers_env_var_versions(self):
        """DEFAULT_CONFIG's _config_version is >= the largest ENV_VARS_BY_VERSION key."""
        from hermes_cli.config import DEFAULT_CONFIG, ENV_VARS_BY_VERSION

        newest_env_version = max(ENV_VARS_BY_VERSION)
        assert DEFAULT_CONFIG["_config_version"] >= newest_env_version