The architecture has been updated
This commit is contained in:
parent
805f7a017e
commit
a01257ead9
1119 changed files with 226 additions and 352 deletions
0
hermes_code/tests/gateway/__init__.py
Normal file
0
hermes_code/tests/gateway/__init__.py
Normal file
238
hermes_code/tests/gateway/test_agent_cache.py
Normal file
238
hermes_code/tests/gateway/test_agent_cache.py
Normal file
|
|
@ -0,0 +1,238 @@
|
|||
"""Integration tests for gateway AIAgent caching.
|
||||
|
||||
Verifies that the agent cache correctly:
|
||||
- Reuses agents across messages (same config → same instance)
|
||||
- Rebuilds agents when config changes (model, provider, toolsets)
|
||||
- Updates reasoning_config in-place without rebuilding
|
||||
- Evicts on session reset
|
||||
- Evicts on fallback activation
|
||||
- Preserves frozen system prompt across turns
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import threading
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _make_runner():
|
||||
"""Create a minimal GatewayRunner with just the cache infrastructure."""
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = GatewayRunner.__new__(GatewayRunner)
|
||||
runner._agent_cache = {}
|
||||
runner._agent_cache_lock = threading.Lock()
|
||||
return runner
|
||||
|
||||
|
||||
class TestAgentConfigSignature:
|
||||
"""Config signature produces stable, distinct keys."""
|
||||
|
||||
def test_same_config_same_signature(self):
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runtime = {"api_key": "sk-test12345678", "base_url": "https://openrouter.ai/api/v1",
|
||||
"provider": "openrouter", "api_mode": "chat_completions"}
|
||||
sig1 = GatewayRunner._agent_config_signature("claude-sonnet-4", runtime, ["hermes-telegram"], "")
|
||||
sig2 = GatewayRunner._agent_config_signature("claude-sonnet-4", runtime, ["hermes-telegram"], "")
|
||||
assert sig1 == sig2
|
||||
|
||||
def test_model_change_different_signature(self):
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runtime = {"api_key": "sk-test12345678", "base_url": "https://openrouter.ai/api/v1",
|
||||
"provider": "openrouter"}
|
||||
sig1 = GatewayRunner._agent_config_signature("claude-sonnet-4", runtime, ["hermes-telegram"], "")
|
||||
sig2 = GatewayRunner._agent_config_signature("claude-opus-4.6", runtime, ["hermes-telegram"], "")
|
||||
assert sig1 != sig2
|
||||
|
||||
def test_provider_change_different_signature(self):
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
rt1 = {"api_key": "sk-test12345678", "base_url": "https://openrouter.ai/api/v1", "provider": "openrouter"}
|
||||
rt2 = {"api_key": "sk-test12345678", "base_url": "https://api.anthropic.com", "provider": "anthropic"}
|
||||
sig1 = GatewayRunner._agent_config_signature("claude-sonnet-4", rt1, ["hermes-telegram"], "")
|
||||
sig2 = GatewayRunner._agent_config_signature("claude-sonnet-4", rt2, ["hermes-telegram"], "")
|
||||
assert sig1 != sig2
|
||||
|
||||
def test_toolset_change_different_signature(self):
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runtime = {"api_key": "sk-test12345678", "base_url": "https://openrouter.ai/api/v1", "provider": "openrouter"}
|
||||
sig1 = GatewayRunner._agent_config_signature("claude-sonnet-4", runtime, ["hermes-telegram"], "")
|
||||
sig2 = GatewayRunner._agent_config_signature("claude-sonnet-4", runtime, ["hermes-discord"], "")
|
||||
assert sig1 != sig2
|
||||
|
||||
def test_reasoning_not_in_signature(self):
|
||||
"""Reasoning config is set per-message, not part of the signature."""
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runtime = {"api_key": "sk-test12345678", "base_url": "https://openrouter.ai/api/v1", "provider": "openrouter"}
|
||||
# Same config — signature should be identical regardless of what
|
||||
# reasoning_config the caller might have (it's not passed in)
|
||||
sig1 = GatewayRunner._agent_config_signature("claude-sonnet-4", runtime, ["hermes-telegram"], "")
|
||||
sig2 = GatewayRunner._agent_config_signature("claude-sonnet-4", runtime, ["hermes-telegram"], "")
|
||||
assert sig1 == sig2
|
||||
|
||||
|
||||
class TestAgentCacheLifecycle:
|
||||
"""End-to-end cache behavior with real AIAgent construction."""
|
||||
|
||||
def test_cache_hit_returns_same_agent(self):
|
||||
"""Second message with same config reuses the cached agent instance."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
runner = _make_runner()
|
||||
session_key = "telegram:12345"
|
||||
runtime = {"api_key": "test", "base_url": "https://openrouter.ai/api/v1",
|
||||
"provider": "openrouter", "api_mode": "chat_completions"}
|
||||
sig = runner._agent_config_signature("anthropic/claude-sonnet-4", runtime, ["hermes-telegram"], "")
|
||||
|
||||
# First message — create and cache
|
||||
agent1 = AIAgent(
|
||||
model="anthropic/claude-sonnet-4", api_key="test",
|
||||
base_url="https://openrouter.ai/api/v1", provider="openrouter",
|
||||
max_iterations=5, quiet_mode=True, skip_context_files=True,
|
||||
skip_memory=True, platform="telegram",
|
||||
)
|
||||
with runner._agent_cache_lock:
|
||||
runner._agent_cache[session_key] = (agent1, sig)
|
||||
|
||||
# Second message — cache hit
|
||||
with runner._agent_cache_lock:
|
||||
cached = runner._agent_cache.get(session_key)
|
||||
assert cached is not None
|
||||
assert cached[1] == sig
|
||||
assert cached[0] is agent1 # same instance
|
||||
|
||||
def test_cache_miss_on_model_change(self):
|
||||
"""Model change produces different signature → cache miss."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
runner = _make_runner()
|
||||
session_key = "telegram:12345"
|
||||
runtime = {"api_key": "test", "base_url": "https://openrouter.ai/api/v1",
|
||||
"provider": "openrouter", "api_mode": "chat_completions"}
|
||||
|
||||
old_sig = runner._agent_config_signature("anthropic/claude-sonnet-4", runtime, ["hermes-telegram"], "")
|
||||
agent1 = AIAgent(
|
||||
model="anthropic/claude-sonnet-4", api_key="test",
|
||||
base_url="https://openrouter.ai/api/v1", provider="openrouter",
|
||||
max_iterations=5, quiet_mode=True, skip_context_files=True,
|
||||
skip_memory=True, platform="telegram",
|
||||
)
|
||||
with runner._agent_cache_lock:
|
||||
runner._agent_cache[session_key] = (agent1, old_sig)
|
||||
|
||||
# New model → different signature
|
||||
new_sig = runner._agent_config_signature("anthropic/claude-opus-4.6", runtime, ["hermes-telegram"], "")
|
||||
assert new_sig != old_sig
|
||||
|
||||
with runner._agent_cache_lock:
|
||||
cached = runner._agent_cache.get(session_key)
|
||||
assert cached[1] != new_sig # signature mismatch → would create new agent
|
||||
|
||||
def test_evict_on_session_reset(self):
|
||||
"""_evict_cached_agent removes the entry."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
runner = _make_runner()
|
||||
session_key = "telegram:12345"
|
||||
|
||||
agent = AIAgent(
|
||||
model="anthropic/claude-sonnet-4", api_key="test",
|
||||
base_url="https://openrouter.ai/api/v1", provider="openrouter",
|
||||
max_iterations=5, quiet_mode=True, skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
with runner._agent_cache_lock:
|
||||
runner._agent_cache[session_key] = (agent, "sig123")
|
||||
|
||||
runner._evict_cached_agent(session_key)
|
||||
|
||||
with runner._agent_cache_lock:
|
||||
assert session_key not in runner._agent_cache
|
||||
|
||||
def test_evict_does_not_affect_other_sessions(self):
|
||||
"""Evicting one session leaves other sessions cached."""
|
||||
runner = _make_runner()
|
||||
with runner._agent_cache_lock:
|
||||
runner._agent_cache["session-A"] = ("agent-A", "sig-A")
|
||||
runner._agent_cache["session-B"] = ("agent-B", "sig-B")
|
||||
|
||||
runner._evict_cached_agent("session-A")
|
||||
|
||||
with runner._agent_cache_lock:
|
||||
assert "session-A" not in runner._agent_cache
|
||||
assert "session-B" in runner._agent_cache
|
||||
|
||||
def test_reasoning_config_updates_in_place(self):
|
||||
"""Reasoning config can be set on a cached agent without eviction."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
agent = AIAgent(
|
||||
model="anthropic/claude-sonnet-4", api_key="test",
|
||||
base_url="https://openrouter.ai/api/v1", provider="openrouter",
|
||||
max_iterations=5, quiet_mode=True, skip_context_files=True,
|
||||
skip_memory=True,
|
||||
reasoning_config={"enabled": True, "effort": "medium"},
|
||||
)
|
||||
|
||||
# Simulate per-message reasoning update
|
||||
agent.reasoning_config = {"enabled": True, "effort": "high"}
|
||||
assert agent.reasoning_config["effort"] == "high"
|
||||
|
||||
# System prompt should not be affected by reasoning change
|
||||
prompt1 = agent._build_system_prompt()
|
||||
agent._cached_system_prompt = prompt1 # simulate run_conversation caching
|
||||
agent.reasoning_config = {"enabled": True, "effort": "low"}
|
||||
prompt2 = agent._cached_system_prompt
|
||||
assert prompt1 is prompt2 # same object — not invalidated by reasoning change
|
||||
|
||||
def test_system_prompt_frozen_across_cache_reuse(self):
|
||||
"""The cached agent's system prompt stays identical across turns."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
agent = AIAgent(
|
||||
model="anthropic/claude-sonnet-4", api_key="test",
|
||||
base_url="https://openrouter.ai/api/v1", provider="openrouter",
|
||||
max_iterations=5, quiet_mode=True, skip_context_files=True,
|
||||
skip_memory=True, platform="telegram",
|
||||
)
|
||||
|
||||
# Build system prompt (simulates first run_conversation)
|
||||
prompt1 = agent._build_system_prompt()
|
||||
agent._cached_system_prompt = prompt1
|
||||
|
||||
# Simulate second turn — prompt should be frozen
|
||||
prompt2 = agent._cached_system_prompt
|
||||
assert prompt1 is prompt2 # same object, not rebuilt
|
||||
|
||||
def test_callbacks_update_without_cache_eviction(self):
|
||||
"""Per-message callbacks can be set on cached agent."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
agent = AIAgent(
|
||||
model="anthropic/claude-sonnet-4", api_key="test",
|
||||
base_url="https://openrouter.ai/api/v1", provider="openrouter",
|
||||
max_iterations=5, quiet_mode=True, skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
|
||||
# Set callbacks like the gateway does per-message
|
||||
cb1 = lambda *a: None
|
||||
cb2 = lambda *a: None
|
||||
agent.tool_progress_callback = cb1
|
||||
agent.step_callback = cb2
|
||||
agent.stream_delta_callback = None
|
||||
agent.status_callback = None
|
||||
|
||||
assert agent.tool_progress_callback is cb1
|
||||
assert agent.step_callback is cb2
|
||||
|
||||
# Update for next message
|
||||
cb3 = lambda *a: None
|
||||
agent.tool_progress_callback = cb3
|
||||
assert agent.tool_progress_callback is cb3
|
||||
1391
hermes_code/tests/gateway/test_api_server.py
Normal file
1391
hermes_code/tests/gateway/test_api_server.py
Normal file
File diff suppressed because it is too large
Load diff
597
hermes_code/tests/gateway/test_api_server_jobs.py
Normal file
597
hermes_code/tests/gateway/test_api_server_jobs.py
Normal file
|
|
@ -0,0 +1,597 @@
|
|||
"""
|
||||
Tests for the Cron Jobs API endpoints on the API server adapter.
|
||||
|
||||
Covers:
|
||||
- CRUD operations for cron jobs (list, create, get, update, delete)
|
||||
- Pause / resume / run (trigger) actions
|
||||
- Input validation (missing name, name too long, prompt too long, invalid repeat)
|
||||
- Job ID validation (invalid hex)
|
||||
- Auth enforcement (401 when API_SERVER_KEY is set)
|
||||
- Cron module unavailability (501 when _CRON_AVAILABLE is False)
|
||||
"""
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
from aiohttp.test_utils import TestClient, TestServer
|
||||
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.platforms.api_server import APIServerAdapter, cors_middleware
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SAMPLE_JOB = {
|
||||
"id": "aabbccddeeff",
|
||||
"name": "test-job",
|
||||
"schedule": "*/5 * * * *",
|
||||
"prompt": "do something",
|
||||
"deliver": "local",
|
||||
"enabled": True,
|
||||
}
|
||||
|
||||
VALID_JOB_ID = "aabbccddeeff"
|
||||
|
||||
|
||||
def _make_adapter(api_key: str = "") -> APIServerAdapter:
|
||||
"""Create an adapter with optional API key."""
|
||||
extra = {}
|
||||
if api_key:
|
||||
extra["key"] = api_key
|
||||
config = PlatformConfig(enabled=True, extra=extra)
|
||||
return APIServerAdapter(config)
|
||||
|
||||
|
||||
def _create_app(adapter: APIServerAdapter) -> web.Application:
|
||||
"""Create the aiohttp app with jobs routes registered."""
|
||||
app = web.Application(middlewares=[cors_middleware])
|
||||
app["api_server_adapter"] = adapter
|
||||
# Register only job routes (plus health for sanity)
|
||||
app.router.add_get("/health", adapter._handle_health)
|
||||
app.router.add_get("/api/jobs", adapter._handle_list_jobs)
|
||||
app.router.add_post("/api/jobs", adapter._handle_create_job)
|
||||
app.router.add_get("/api/jobs/{job_id}", adapter._handle_get_job)
|
||||
app.router.add_patch("/api/jobs/{job_id}", adapter._handle_update_job)
|
||||
app.router.add_delete("/api/jobs/{job_id}", adapter._handle_delete_job)
|
||||
app.router.add_post("/api/jobs/{job_id}/pause", adapter._handle_pause_job)
|
||||
app.router.add_post("/api/jobs/{job_id}/resume", adapter._handle_resume_job)
|
||||
app.router.add_post("/api/jobs/{job_id}/run", adapter._handle_run_job)
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def adapter():
|
||||
return _make_adapter()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_adapter():
|
||||
return _make_adapter(api_key="sk-secret")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. test_list_jobs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestListJobs:
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_jobs(self, adapter):
|
||||
"""GET /api/jobs returns job list."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(
|
||||
APIServerAdapter, "_CRON_AVAILABLE", True
|
||||
), patch.object(
|
||||
APIServerAdapter, "_cron_list", return_value=[SAMPLE_JOB]
|
||||
):
|
||||
resp = await cli.get("/api/jobs")
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert "jobs" in data
|
||||
assert data["jobs"] == [SAMPLE_JOB]
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# 2. test_list_jobs_include_disabled
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_jobs_include_disabled(self, adapter):
|
||||
"""GET /api/jobs?include_disabled=true passes the flag."""
|
||||
app = _create_app(adapter)
|
||||
mock_list = MagicMock(return_value=[SAMPLE_JOB])
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(
|
||||
APIServerAdapter, "_CRON_AVAILABLE", True
|
||||
), patch.object(
|
||||
APIServerAdapter, "_cron_list", mock_list
|
||||
):
|
||||
resp = await cli.get("/api/jobs?include_disabled=true")
|
||||
assert resp.status == 200
|
||||
mock_list.assert_called_once_with(include_disabled=True)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_jobs_default_excludes_disabled(self, adapter):
|
||||
"""GET /api/jobs without flag passes include_disabled=False."""
|
||||
app = _create_app(adapter)
|
||||
mock_list = MagicMock(return_value=[])
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(
|
||||
APIServerAdapter, "_CRON_AVAILABLE", True
|
||||
), patch.object(
|
||||
APIServerAdapter, "_cron_list", mock_list
|
||||
):
|
||||
resp = await cli.get("/api/jobs")
|
||||
assert resp.status == 200
|
||||
mock_list.assert_called_once_with(include_disabled=False)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3-7. test_create_job and validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCreateJob:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_job(self, adapter):
|
||||
"""POST /api/jobs with valid body returns created job."""
|
||||
app = _create_app(adapter)
|
||||
mock_create = MagicMock(return_value=SAMPLE_JOB)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(
|
||||
APIServerAdapter, "_CRON_AVAILABLE", True
|
||||
), patch.object(
|
||||
APIServerAdapter, "_cron_create", mock_create
|
||||
):
|
||||
resp = await cli.post("/api/jobs", json={
|
||||
"name": "test-job",
|
||||
"schedule": "*/5 * * * *",
|
||||
"prompt": "do something",
|
||||
})
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert data["job"] == SAMPLE_JOB
|
||||
mock_create.assert_called_once()
|
||||
call_kwargs = mock_create.call_args[1]
|
||||
assert call_kwargs["name"] == "test-job"
|
||||
assert call_kwargs["schedule"] == "*/5 * * * *"
|
||||
assert call_kwargs["prompt"] == "do something"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_job_missing_name(self, adapter):
|
||||
"""POST /api/jobs without name returns 400."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
|
||||
resp = await cli.post("/api/jobs", json={
|
||||
"schedule": "*/5 * * * *",
|
||||
"prompt": "do something",
|
||||
})
|
||||
assert resp.status == 400
|
||||
data = await resp.json()
|
||||
assert "name" in data["error"].lower() or "Name" in data["error"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_job_name_too_long(self, adapter):
|
||||
"""POST /api/jobs with name > 200 chars returns 400."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
|
||||
resp = await cli.post("/api/jobs", json={
|
||||
"name": "x" * 201,
|
||||
"schedule": "*/5 * * * *",
|
||||
})
|
||||
assert resp.status == 400
|
||||
data = await resp.json()
|
||||
assert "200" in data["error"] or "Name" in data["error"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_job_prompt_too_long(self, adapter):
|
||||
"""POST /api/jobs with prompt > 5000 chars returns 400."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
|
||||
resp = await cli.post("/api/jobs", json={
|
||||
"name": "test-job",
|
||||
"schedule": "*/5 * * * *",
|
||||
"prompt": "x" * 5001,
|
||||
})
|
||||
assert resp.status == 400
|
||||
data = await resp.json()
|
||||
assert "5000" in data["error"] or "Prompt" in data["error"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_job_invalid_repeat(self, adapter):
|
||||
"""POST /api/jobs with repeat=0 returns 400."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
|
||||
resp = await cli.post("/api/jobs", json={
|
||||
"name": "test-job",
|
||||
"schedule": "*/5 * * * *",
|
||||
"repeat": 0,
|
||||
})
|
||||
assert resp.status == 400
|
||||
data = await resp.json()
|
||||
assert "repeat" in data["error"].lower() or "Repeat" in data["error"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_job_missing_schedule(self, adapter):
|
||||
"""POST /api/jobs without schedule returns 400."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
|
||||
resp = await cli.post("/api/jobs", json={
|
||||
"name": "test-job",
|
||||
})
|
||||
assert resp.status == 400
|
||||
data = await resp.json()
|
||||
assert "schedule" in data["error"].lower() or "Schedule" in data["error"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 8-10. test_get_job
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGetJob:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_job(self, adapter):
|
||||
"""GET /api/jobs/{id} returns job."""
|
||||
app = _create_app(adapter)
|
||||
mock_get = MagicMock(return_value=SAMPLE_JOB)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(
|
||||
APIServerAdapter, "_CRON_AVAILABLE", True
|
||||
), patch.object(
|
||||
APIServerAdapter, "_cron_get", mock_get
|
||||
):
|
||||
resp = await cli.get(f"/api/jobs/{VALID_JOB_ID}")
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert data["job"] == SAMPLE_JOB
|
||||
mock_get.assert_called_once_with(VALID_JOB_ID)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_job_not_found(self, adapter):
|
||||
"""GET /api/jobs/{id} returns 404 when job doesn't exist."""
|
||||
app = _create_app(adapter)
|
||||
mock_get = MagicMock(return_value=None)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(
|
||||
APIServerAdapter, "_CRON_AVAILABLE", True
|
||||
), patch.object(
|
||||
APIServerAdapter, "_cron_get", mock_get
|
||||
):
|
||||
resp = await cli.get(f"/api/jobs/{VALID_JOB_ID}")
|
||||
assert resp.status == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_job_invalid_id(self, adapter):
|
||||
"""GET /api/jobs/{id} with non-hex id returns 400."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
|
||||
resp = await cli.get("/api/jobs/not-a-valid-hex!")
|
||||
assert resp.status == 400
|
||||
data = await resp.json()
|
||||
assert "Invalid" in data["error"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 11-12. test_update_job
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestUpdateJob:
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_job(self, adapter):
|
||||
"""PATCH /api/jobs/{id} updates with whitelisted fields."""
|
||||
app = _create_app(adapter)
|
||||
updated_job = {**SAMPLE_JOB, "name": "updated-name"}
|
||||
mock_update = MagicMock(return_value=updated_job)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(
|
||||
APIServerAdapter, "_CRON_AVAILABLE", True
|
||||
), patch.object(
|
||||
APIServerAdapter, "_cron_update", mock_update
|
||||
):
|
||||
resp = await cli.patch(
|
||||
f"/api/jobs/{VALID_JOB_ID}",
|
||||
json={"name": "updated-name", "schedule": "0 * * * *"},
|
||||
)
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert data["job"] == updated_job
|
||||
mock_update.assert_called_once()
|
||||
call_args = mock_update.call_args
|
||||
assert call_args[0][0] == VALID_JOB_ID
|
||||
sanitized = call_args[0][1]
|
||||
assert "name" in sanitized
|
||||
assert "schedule" in sanitized
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_job_rejects_unknown_fields(self, adapter):
|
||||
"""PATCH /api/jobs/{id} — only allowed fields pass through."""
|
||||
app = _create_app(adapter)
|
||||
updated_job = {**SAMPLE_JOB, "name": "new-name"}
|
||||
mock_update = MagicMock(return_value=updated_job)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(
|
||||
APIServerAdapter, "_CRON_AVAILABLE", True
|
||||
), patch.object(
|
||||
APIServerAdapter, "_cron_update", mock_update
|
||||
):
|
||||
resp = await cli.patch(
|
||||
f"/api/jobs/{VALID_JOB_ID}",
|
||||
json={
|
||||
"name": "new-name",
|
||||
"evil_field": "malicious",
|
||||
"__proto__": "hack",
|
||||
},
|
||||
)
|
||||
assert resp.status == 200
|
||||
call_args = mock_update.call_args
|
||||
sanitized = call_args[0][1]
|
||||
assert "name" in sanitized
|
||||
assert "evil_field" not in sanitized
|
||||
assert "__proto__" not in sanitized
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_job_no_valid_fields(self, adapter):
|
||||
"""PATCH /api/jobs/{id} with only unknown fields returns 400."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
|
||||
resp = await cli.patch(
|
||||
f"/api/jobs/{VALID_JOB_ID}",
|
||||
json={"evil_field": "malicious"},
|
||||
)
|
||||
assert resp.status == 400
|
||||
data = await resp.json()
|
||||
assert "No valid fields" in data["error"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 13. test_delete_job
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDeleteJob:
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_job(self, adapter):
|
||||
"""DELETE /api/jobs/{id} returns ok."""
|
||||
app = _create_app(adapter)
|
||||
mock_remove = MagicMock(return_value=True)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(
|
||||
APIServerAdapter, "_CRON_AVAILABLE", True
|
||||
), patch.object(
|
||||
APIServerAdapter, "_cron_remove", mock_remove
|
||||
):
|
||||
resp = await cli.delete(f"/api/jobs/{VALID_JOB_ID}")
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert data["ok"] is True
|
||||
mock_remove.assert_called_once_with(VALID_JOB_ID)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_job_not_found(self, adapter):
|
||||
"""DELETE /api/jobs/{id} returns 404 when job doesn't exist."""
|
||||
app = _create_app(adapter)
|
||||
mock_remove = MagicMock(return_value=False)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(
|
||||
APIServerAdapter, "_CRON_AVAILABLE", True
|
||||
), patch.object(
|
||||
APIServerAdapter, "_cron_remove", mock_remove
|
||||
):
|
||||
resp = await cli.delete(f"/api/jobs/{VALID_JOB_ID}")
|
||||
assert resp.status == 404
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 14. test_pause_job
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPauseJob:
|
||||
@pytest.mark.asyncio
|
||||
async def test_pause_job(self, adapter):
|
||||
"""POST /api/jobs/{id}/pause returns updated job."""
|
||||
app = _create_app(adapter)
|
||||
paused_job = {**SAMPLE_JOB, "enabled": False}
|
||||
mock_pause = MagicMock(return_value=paused_job)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(
|
||||
APIServerAdapter, "_CRON_AVAILABLE", True
|
||||
), patch.object(
|
||||
APIServerAdapter, "_cron_pause", mock_pause
|
||||
):
|
||||
resp = await cli.post(f"/api/jobs/{VALID_JOB_ID}/pause")
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert data["job"] == paused_job
|
||||
assert data["job"]["enabled"] is False
|
||||
mock_pause.assert_called_once_with(VALID_JOB_ID)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 15. test_resume_job
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestResumeJob:
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_job(self, adapter):
|
||||
"""POST /api/jobs/{id}/resume returns updated job."""
|
||||
app = _create_app(adapter)
|
||||
resumed_job = {**SAMPLE_JOB, "enabled": True}
|
||||
mock_resume = MagicMock(return_value=resumed_job)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(
|
||||
APIServerAdapter, "_CRON_AVAILABLE", True
|
||||
), patch.object(
|
||||
APIServerAdapter, "_cron_resume", mock_resume
|
||||
):
|
||||
resp = await cli.post(f"/api/jobs/{VALID_JOB_ID}/resume")
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert data["job"] == resumed_job
|
||||
assert data["job"]["enabled"] is True
|
||||
mock_resume.assert_called_once_with(VALID_JOB_ID)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 16. test_run_job
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRunJob:
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_job(self, adapter):
|
||||
"""POST /api/jobs/{id}/run returns triggered job."""
|
||||
app = _create_app(adapter)
|
||||
triggered_job = {**SAMPLE_JOB, "last_run": "2025-01-01T00:00:00Z"}
|
||||
mock_trigger = MagicMock(return_value=triggered_job)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(
|
||||
APIServerAdapter, "_CRON_AVAILABLE", True
|
||||
), patch.object(
|
||||
APIServerAdapter, "_cron_trigger", mock_trigger
|
||||
):
|
||||
resp = await cli.post(f"/api/jobs/{VALID_JOB_ID}/run")
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert data["job"] == triggered_job
|
||||
mock_trigger.assert_called_once_with(VALID_JOB_ID)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 17. test_auth_required
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestAuthRequired:
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_required_list_jobs(self, auth_adapter):
|
||||
"""GET /api/jobs without API key returns 401 when key is set."""
|
||||
app = _create_app(auth_adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
|
||||
resp = await cli.get("/api/jobs")
|
||||
assert resp.status == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_required_create_job(self, auth_adapter):
|
||||
"""POST /api/jobs without API key returns 401 when key is set."""
|
||||
app = _create_app(auth_adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
|
||||
resp = await cli.post("/api/jobs", json={
|
||||
"name": "test", "schedule": "* * * * *",
|
||||
})
|
||||
assert resp.status == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_required_get_job(self, auth_adapter):
|
||||
"""GET /api/jobs/{id} without API key returns 401 when key is set."""
|
||||
app = _create_app(auth_adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
|
||||
resp = await cli.get(f"/api/jobs/{VALID_JOB_ID}")
|
||||
assert resp.status == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_required_delete_job(self, auth_adapter):
|
||||
"""DELETE /api/jobs/{id} without API key returns 401."""
|
||||
app = _create_app(auth_adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
|
||||
resp = await cli.delete(f"/api/jobs/{VALID_JOB_ID}")
|
||||
assert resp.status == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_passes_with_valid_key(self, auth_adapter):
|
||||
"""GET /api/jobs with correct API key succeeds."""
|
||||
app = _create_app(auth_adapter)
|
||||
mock_list = MagicMock(return_value=[])
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(
|
||||
APIServerAdapter, "_CRON_AVAILABLE", True
|
||||
), patch.object(
|
||||
APIServerAdapter, "_cron_list", mock_list
|
||||
):
|
||||
resp = await cli.get(
|
||||
"/api/jobs",
|
||||
headers={"Authorization": "Bearer sk-secret"},
|
||||
)
|
||||
assert resp.status == 200
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 18. test_cron_unavailable
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCronUnavailable:
|
||||
@pytest.mark.asyncio
|
||||
async def test_cron_unavailable_list(self, adapter):
|
||||
"""GET /api/jobs returns 501 when _CRON_AVAILABLE is False."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", False):
|
||||
resp = await cli.get("/api/jobs")
|
||||
assert resp.status == 501
|
||||
data = await resp.json()
|
||||
assert "not available" in data["error"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cron_unavailable_create(self, adapter):
|
||||
"""POST /api/jobs returns 501 when _CRON_AVAILABLE is False."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", False):
|
||||
resp = await cli.post("/api/jobs", json={
|
||||
"name": "test", "schedule": "* * * * *",
|
||||
})
|
||||
assert resp.status == 501
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cron_unavailable_get(self, adapter):
|
||||
"""GET /api/jobs/{id} returns 501 when _CRON_AVAILABLE is False."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", False):
|
||||
resp = await cli.get(f"/api/jobs/{VALID_JOB_ID}")
|
||||
assert resp.status == 501
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cron_unavailable_delete(self, adapter):
|
||||
"""DELETE /api/jobs/{id} returns 501 when _CRON_AVAILABLE is False."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", False):
|
||||
resp = await cli.delete(f"/api/jobs/{VALID_JOB_ID}")
|
||||
assert resp.status == 501
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cron_unavailable_pause(self, adapter):
|
||||
"""POST /api/jobs/{id}/pause returns 501 when _CRON_AVAILABLE is False."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", False):
|
||||
resp = await cli.post(f"/api/jobs/{VALID_JOB_ID}/pause")
|
||||
assert resp.status == 501
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cron_unavailable_resume(self, adapter):
|
||||
"""POST /api/jobs/{id}/resume returns 501 when _CRON_AVAILABLE is False."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", False):
|
||||
resp = await cli.post(f"/api/jobs/{VALID_JOB_ID}/resume")
|
||||
assert resp.status == 501
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cron_unavailable_run(self, adapter):
|
||||
"""POST /api/jobs/{id}/run returns 501 when _CRON_AVAILABLE is False."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", False):
|
||||
resp = await cli.post(f"/api/jobs/{VALID_JOB_ID}/run")
|
||||
assert resp.status == 501
|
||||
240
hermes_code/tests/gateway/test_approve_deny_commands.py
Normal file
240
hermes_code/tests/gateway/test_approve_deny_commands.py
Normal file
|
|
@ -0,0 +1,240 @@
|
|||
"""Tests for /approve and /deny gateway commands.
|
||||
|
||||
Verifies that dangerous command approvals require explicit /approve or /deny
|
||||
slash commands, not bare "yes"/"no" text matching.
|
||||
"""
|
||||
|
||||
import time
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import MessageEvent
|
||||
from gateway.session import SessionEntry, SessionSource, build_session_key
|
||||
|
||||
|
||||
def _make_source() -> SessionSource:
|
||||
return SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
user_id="u1",
|
||||
chat_id="c1",
|
||||
user_name="tester",
|
||||
chat_type="dm",
|
||||
)
|
||||
|
||||
|
||||
def _make_event(text: str) -> MessageEvent:
|
||||
return MessageEvent(
|
||||
text=text,
|
||||
source=_make_source(),
|
||||
message_id="m1",
|
||||
)
|
||||
|
||||
|
||||
def _make_runner():
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.config = GatewayConfig(
|
||||
platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")}
|
||||
)
|
||||
adapter = MagicMock()
|
||||
adapter.send = AsyncMock()
|
||||
runner.adapters = {Platform.TELEGRAM: adapter}
|
||||
runner._voice_mode = {}
|
||||
runner.hooks = SimpleNamespace(emit=AsyncMock(), loaded_hooks=False)
|
||||
runner.session_store = MagicMock()
|
||||
runner._running_agents = {}
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._session_db = None
|
||||
runner._reasoning_config = None
|
||||
runner._provider_routing = {}
|
||||
runner._fallback_model = None
|
||||
runner._show_reasoning = False
|
||||
runner._is_user_authorized = lambda _source: True
|
||||
runner._set_session_env = lambda _context: None
|
||||
return runner
|
||||
|
||||
|
||||
def _make_pending_approval(command="sudo rm -rf /tmp/test", pattern_key="sudo"):
|
||||
return {
|
||||
"command": command,
|
||||
"pattern_key": pattern_key,
|
||||
"pattern_keys": [pattern_key],
|
||||
"description": "sudo command",
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# /approve command
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestApproveCommand:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_approve_executes_pending_command(self):
|
||||
"""Basic /approve executes the pending command."""
|
||||
runner = _make_runner()
|
||||
source = _make_source()
|
||||
session_key = runner._session_key_for_source(source)
|
||||
runner._pending_approvals[session_key] = _make_pending_approval()
|
||||
|
||||
event = _make_event("/approve")
|
||||
with patch("tools.terminal_tool.terminal_tool", return_value="done") as mock_term:
|
||||
result = await runner._handle_approve_command(event)
|
||||
|
||||
assert "✅ Command approved and executed" in result
|
||||
mock_term.assert_called_once_with(command="sudo rm -rf /tmp/test", force=True)
|
||||
assert session_key not in runner._pending_approvals
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_approve_session_remembers_pattern(self):
|
||||
"""/approve session approves the pattern for the session."""
|
||||
runner = _make_runner()
|
||||
source = _make_source()
|
||||
session_key = runner._session_key_for_source(source)
|
||||
runner._pending_approvals[session_key] = _make_pending_approval()
|
||||
|
||||
event = _make_event("/approve session")
|
||||
with (
|
||||
patch("tools.terminal_tool.terminal_tool", return_value="done"),
|
||||
patch("tools.approval.approve_session") as mock_session,
|
||||
):
|
||||
result = await runner._handle_approve_command(event)
|
||||
|
||||
assert "pattern approved for this session" in result
|
||||
mock_session.assert_called_once_with(session_key, "sudo")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_approve_always_approves_permanently(self):
|
||||
"""/approve always approves the pattern permanently."""
|
||||
runner = _make_runner()
|
||||
source = _make_source()
|
||||
session_key = runner._session_key_for_source(source)
|
||||
runner._pending_approvals[session_key] = _make_pending_approval()
|
||||
|
||||
event = _make_event("/approve always")
|
||||
with (
|
||||
patch("tools.terminal_tool.terminal_tool", return_value="done"),
|
||||
patch("tools.approval.approve_permanent") as mock_perm,
|
||||
):
|
||||
result = await runner._handle_approve_command(event)
|
||||
|
||||
assert "pattern approved permanently" in result
|
||||
mock_perm.assert_called_once_with("sudo")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_approve_no_pending(self):
|
||||
"""/approve with no pending approval returns helpful message."""
|
||||
runner = _make_runner()
|
||||
event = _make_event("/approve")
|
||||
result = await runner._handle_approve_command(event)
|
||||
assert "No pending command" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_approve_expired(self):
|
||||
"""/approve on a timed-out approval rejects it."""
|
||||
runner = _make_runner()
|
||||
source = _make_source()
|
||||
session_key = runner._session_key_for_source(source)
|
||||
approval = _make_pending_approval()
|
||||
approval["timestamp"] = time.time() - 600 # 10 minutes ago
|
||||
runner._pending_approvals[session_key] = approval
|
||||
|
||||
event = _make_event("/approve")
|
||||
result = await runner._handle_approve_command(event)
|
||||
|
||||
assert "expired" in result
|
||||
assert session_key not in runner._pending_approvals
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# /deny command
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDenyCommand:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deny_clears_pending(self):
|
||||
"""/deny clears the pending approval."""
|
||||
runner = _make_runner()
|
||||
source = _make_source()
|
||||
session_key = runner._session_key_for_source(source)
|
||||
runner._pending_approvals[session_key] = _make_pending_approval()
|
||||
|
||||
event = _make_event("/deny")
|
||||
result = await runner._handle_deny_command(event)
|
||||
|
||||
assert "❌ Command denied" in result
|
||||
assert session_key not in runner._pending_approvals
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deny_no_pending(self):
|
||||
"""/deny with no pending approval returns helpful message."""
|
||||
runner = _make_runner()
|
||||
event = _make_event("/deny")
|
||||
result = await runner._handle_deny_command(event)
|
||||
assert "No pending command" in result
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Bare "yes" must NOT trigger approval
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBareTextNoLongerApproves:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_yes_does_not_execute_pending_command(self):
|
||||
"""Saying 'yes' in normal conversation must not execute a pending command.
|
||||
|
||||
This is the core bug from issue #1888: bare text matching against
|
||||
'yes'/'no' could intercept unrelated user messages.
|
||||
"""
|
||||
runner = _make_runner()
|
||||
source = _make_source()
|
||||
session_key = runner._session_key_for_source(source)
|
||||
runner._pending_approvals[session_key] = _make_pending_approval()
|
||||
|
||||
# Simulate the user saying "yes" as a normal message.
|
||||
# The old code would have executed the pending command.
|
||||
# Now it should fall through to normal processing (agent handles it).
|
||||
event = _make_event("yes")
|
||||
|
||||
# The approval should still be pending — "yes" is not /approve
|
||||
# We can't easily run _handle_message end-to-end, but we CAN verify
|
||||
# the old text-matching block no longer exists by confirming the
|
||||
# approval is untouched after the command dispatch section.
|
||||
# The key assertion is that _pending_approvals is NOT consumed.
|
||||
assert session_key in runner._pending_approvals
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Approval hint appended to response
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestApprovalHint:
|
||||
|
||||
def test_approval_hint_appended_to_response(self):
|
||||
"""When a pending approval is collected, structured instructions
|
||||
should be appended to the agent response."""
|
||||
# This tests the approval collection logic at the end of _handle_message.
|
||||
# We verify the hint format directly.
|
||||
cmd = "sudo rm -rf /tmp/dangerous"
|
||||
cmd_preview = cmd
|
||||
hint = (
|
||||
f"\n\n⚠️ **Dangerous command requires approval:**\n"
|
||||
f"```\n{cmd_preview}\n```\n"
|
||||
f"Reply `/approve` to execute, `/approve session` to approve this pattern "
|
||||
f"for the session, or `/deny` to cancel."
|
||||
)
|
||||
assert "/approve" in hint
|
||||
assert "/deny" in hint
|
||||
assert cmd in hint
|
||||
180
hermes_code/tests/gateway/test_async_memory_flush.py
Normal file
180
hermes_code/tests/gateway/test_async_memory_flush.py
Normal file
|
|
@ -0,0 +1,180 @@
|
|||
"""Tests for proactive memory flush on session expiry.
|
||||
|
||||
Verifies that:
|
||||
1. _is_session_expired() works from a SessionEntry alone (no source needed)
|
||||
2. The sync callback is no longer called in get_or_create_session
|
||||
3. _pre_flushed_sessions tracking works correctly
|
||||
4. The background watcher can detect expired sessions
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from gateway.config import Platform, GatewayConfig, SessionResetPolicy
|
||||
from gateway.session import SessionSource, SessionStore, SessionEntry
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def idle_store(tmp_path):
|
||||
"""SessionStore with a 60-minute idle reset policy."""
|
||||
config = GatewayConfig(
|
||||
default_reset_policy=SessionResetPolicy(mode="idle", idle_minutes=60),
|
||||
)
|
||||
with patch("gateway.session.SessionStore._ensure_loaded"):
|
||||
s = SessionStore(sessions_dir=tmp_path, config=config)
|
||||
s._db = None
|
||||
s._loaded = True
|
||||
return s
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def no_reset_store(tmp_path):
|
||||
"""SessionStore with no reset policy (mode=none)."""
|
||||
config = GatewayConfig(
|
||||
default_reset_policy=SessionResetPolicy(mode="none"),
|
||||
)
|
||||
with patch("gateway.session.SessionStore._ensure_loaded"):
|
||||
s = SessionStore(sessions_dir=tmp_path, config=config)
|
||||
s._db = None
|
||||
s._loaded = True
|
||||
return s
|
||||
|
||||
|
||||
class TestIsSessionExpired:
|
||||
"""_is_session_expired should detect expiry from entry alone."""
|
||||
|
||||
def test_idle_session_expired(self, idle_store):
|
||||
entry = SessionEntry(
|
||||
session_key="agent:main:telegram:dm",
|
||||
session_id="sid_1",
|
||||
created_at=datetime.now() - timedelta(hours=3),
|
||||
updated_at=datetime.now() - timedelta(minutes=120),
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_type="dm",
|
||||
)
|
||||
assert idle_store._is_session_expired(entry) is True
|
||||
|
||||
def test_active_session_not_expired(self, idle_store):
|
||||
entry = SessionEntry(
|
||||
session_key="agent:main:telegram:dm",
|
||||
session_id="sid_2",
|
||||
created_at=datetime.now() - timedelta(hours=1),
|
||||
updated_at=datetime.now() - timedelta(minutes=10),
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_type="dm",
|
||||
)
|
||||
assert idle_store._is_session_expired(entry) is False
|
||||
|
||||
def test_none_mode_never_expires(self, no_reset_store):
|
||||
entry = SessionEntry(
|
||||
session_key="agent:main:telegram:dm",
|
||||
session_id="sid_3",
|
||||
created_at=datetime.now() - timedelta(days=30),
|
||||
updated_at=datetime.now() - timedelta(days=30),
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_type="dm",
|
||||
)
|
||||
assert no_reset_store._is_session_expired(entry) is False
|
||||
|
||||
def test_active_processes_prevent_expiry(self, idle_store):
|
||||
"""Sessions with active background processes should never expire."""
|
||||
idle_store._has_active_processes_fn = lambda key: True
|
||||
entry = SessionEntry(
|
||||
session_key="agent:main:telegram:dm",
|
||||
session_id="sid_4",
|
||||
created_at=datetime.now() - timedelta(hours=5),
|
||||
updated_at=datetime.now() - timedelta(hours=5),
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_type="dm",
|
||||
)
|
||||
assert idle_store._is_session_expired(entry) is False
|
||||
|
||||
def test_daily_mode_expired(self, tmp_path):
|
||||
"""Daily mode should expire sessions from before today's reset hour."""
|
||||
config = GatewayConfig(
|
||||
default_reset_policy=SessionResetPolicy(mode="daily", at_hour=4),
|
||||
)
|
||||
with patch("gateway.session.SessionStore._ensure_loaded"):
|
||||
store = SessionStore(sessions_dir=tmp_path, config=config)
|
||||
store._db = None
|
||||
store._loaded = True
|
||||
|
||||
entry = SessionEntry(
|
||||
session_key="agent:main:telegram:dm",
|
||||
session_id="sid_5",
|
||||
created_at=datetime.now() - timedelta(days=2),
|
||||
updated_at=datetime.now() - timedelta(days=2),
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_type="dm",
|
||||
)
|
||||
assert store._is_session_expired(entry) is True
|
||||
|
||||
|
||||
class TestGetOrCreateSessionNoCallback:
|
||||
"""get_or_create_session should NOT call a sync flush callback."""
|
||||
|
||||
def test_auto_reset_cleans_pre_flushed_marker(self, idle_store):
|
||||
"""When a session auto-resets, the pre_flushed marker should be discarded."""
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="123",
|
||||
chat_type="dm",
|
||||
)
|
||||
# Create initial session
|
||||
entry1 = idle_store.get_or_create_session(source)
|
||||
old_sid = entry1.session_id
|
||||
|
||||
# Simulate the watcher having flushed it
|
||||
idle_store._pre_flushed_sessions.add(old_sid)
|
||||
|
||||
# Simulate the session going idle
|
||||
entry1.updated_at = datetime.now() - timedelta(minutes=120)
|
||||
idle_store._save()
|
||||
|
||||
# Next call should auto-reset
|
||||
entry2 = idle_store.get_or_create_session(source)
|
||||
assert entry2.session_id != old_sid
|
||||
assert entry2.was_auto_reset is True
|
||||
|
||||
# The old session_id should be removed from pre_flushed
|
||||
assert old_sid not in idle_store._pre_flushed_sessions
|
||||
|
||||
def test_no_sync_callback_invoked(self, idle_store):
|
||||
"""No synchronous callback should block during auto-reset."""
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="123",
|
||||
chat_type="dm",
|
||||
)
|
||||
entry1 = idle_store.get_or_create_session(source)
|
||||
entry1.updated_at = datetime.now() - timedelta(minutes=120)
|
||||
idle_store._save()
|
||||
|
||||
# Verify no _on_auto_reset attribute
|
||||
assert not hasattr(idle_store, '_on_auto_reset')
|
||||
|
||||
# This should NOT block (no sync LLM call)
|
||||
entry2 = idle_store.get_or_create_session(source)
|
||||
assert entry2.was_auto_reset is True
|
||||
|
||||
|
||||
class TestPreFlushedSessionsTracking:
|
||||
"""The _pre_flushed_sessions set should prevent double-flushing."""
|
||||
|
||||
def test_starts_empty(self, idle_store):
|
||||
assert len(idle_store._pre_flushed_sessions) == 0
|
||||
|
||||
def test_add_and_check(self, idle_store):
|
||||
idle_store._pre_flushed_sessions.add("sid_old")
|
||||
assert "sid_old" in idle_store._pre_flushed_sessions
|
||||
assert "sid_other" not in idle_store._pre_flushed_sessions
|
||||
|
||||
def test_discard_on_reset(self, idle_store):
|
||||
"""discard should remove without raising if not present."""
|
||||
idle_store._pre_flushed_sessions.add("sid_a")
|
||||
idle_store._pre_flushed_sessions.discard("sid_a")
|
||||
assert "sid_a" not in idle_store._pre_flushed_sessions
|
||||
# discard on non-existent should not raise
|
||||
idle_store._pre_flushed_sessions.discard("sid_nonexistent")
|
||||
322
hermes_code/tests/gateway/test_background_command.py
Normal file
322
hermes_code/tests/gateway/test_background_command.py
Normal file
|
|
@ -0,0 +1,322 @@
|
|||
"""Tests for /background gateway slash command.
|
||||
|
||||
Tests the _handle_background_command handler (run a prompt in a separate
|
||||
background session) across gateway messenger platforms.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform
|
||||
from gateway.platforms.base import MessageEvent
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
def _make_event(text="/background", platform=Platform.TELEGRAM,
|
||||
user_id="12345", chat_id="67890"):
|
||||
"""Build a MessageEvent for testing."""
|
||||
source = SessionSource(
|
||||
platform=platform,
|
||||
user_id=user_id,
|
||||
chat_id=chat_id,
|
||||
user_name="testuser",
|
||||
)
|
||||
return MessageEvent(text=text, source=source)
|
||||
|
||||
|
||||
def _make_runner():
|
||||
"""Create a bare GatewayRunner with minimal mocks."""
|
||||
from gateway.run import GatewayRunner
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.adapters = {}
|
||||
runner._voice_mode = {}
|
||||
runner._session_db = None
|
||||
runner._reasoning_config = None
|
||||
runner._provider_routing = {}
|
||||
runner._fallback_model = None
|
||||
runner._running_agents = {}
|
||||
|
||||
mock_store = MagicMock()
|
||||
runner.session_store = mock_store
|
||||
|
||||
from gateway.hooks import HookRegistry
|
||||
runner.hooks = HookRegistry()
|
||||
|
||||
return runner
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _handle_background_command
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHandleBackgroundCommand:
|
||||
"""Tests for GatewayRunner._handle_background_command."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_prompt_shows_usage(self):
|
||||
"""Running /background with no prompt shows usage."""
|
||||
runner = _make_runner()
|
||||
event = _make_event(text="/background")
|
||||
result = await runner._handle_background_command(event)
|
||||
assert "Usage:" in result
|
||||
assert "/background" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bg_alias_no_prompt_shows_usage(self):
|
||||
"""Running /bg with no prompt shows usage."""
|
||||
runner = _make_runner()
|
||||
event = _make_event(text="/bg")
|
||||
result = await runner._handle_background_command(event)
|
||||
assert "Usage:" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_prompt_shows_usage(self):
|
||||
"""Running /background with only whitespace shows usage."""
|
||||
runner = _make_runner()
|
||||
event = _make_event(text="/background ")
|
||||
result = await runner._handle_background_command(event)
|
||||
assert "Usage:" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_prompt_starts_task(self):
|
||||
"""Running /background with a prompt returns confirmation and starts task."""
|
||||
runner = _make_runner()
|
||||
|
||||
# Patch asyncio.create_task to capture the coroutine
|
||||
created_tasks = []
|
||||
original_create_task = asyncio.create_task
|
||||
|
||||
def capture_task(coro, *args, **kwargs):
|
||||
# Close the coroutine to avoid warnings
|
||||
coro.close()
|
||||
mock_task = MagicMock()
|
||||
created_tasks.append(mock_task)
|
||||
return mock_task
|
||||
|
||||
with patch("gateway.run.asyncio.create_task", side_effect=capture_task):
|
||||
event = _make_event(text="/background Summarize the top HN stories")
|
||||
result = await runner._handle_background_command(event)
|
||||
|
||||
assert "🔄" in result
|
||||
assert "Background task started" in result
|
||||
assert "bg_" in result # task ID starts with bg_
|
||||
assert "Summarize the top HN stories" in result
|
||||
assert len(created_tasks) == 1 # background task was created
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_truncated_in_preview(self):
|
||||
"""Long prompts are truncated to 60 chars in the confirmation message."""
|
||||
runner = _make_runner()
|
||||
long_prompt = "A" * 100
|
||||
|
||||
with patch("gateway.run.asyncio.create_task", side_effect=lambda c, **kw: (c.close(), MagicMock())[1]):
|
||||
event = _make_event(text=f"/background {long_prompt}")
|
||||
result = await runner._handle_background_command(event)
|
||||
|
||||
assert "..." in result
|
||||
# Should not contain the full prompt
|
||||
assert long_prompt not in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_id_is_unique(self):
|
||||
"""Each background task gets a unique task ID."""
|
||||
runner = _make_runner()
|
||||
task_ids = set()
|
||||
|
||||
with patch("gateway.run.asyncio.create_task", side_effect=lambda c, **kw: (c.close(), MagicMock())[1]):
|
||||
for i in range(5):
|
||||
event = _make_event(text=f"/background task {i}")
|
||||
result = await runner._handle_background_command(event)
|
||||
# Extract task ID from result (format: "Task ID: bg_HHMMSS_hex")
|
||||
for line in result.split("\n"):
|
||||
if "Task ID:" in line:
|
||||
tid = line.split("Task ID:")[1].strip()
|
||||
task_ids.add(tid)
|
||||
|
||||
assert len(task_ids) == 5 # all unique
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_works_across_platforms(self):
|
||||
"""The /background command works for all platforms."""
|
||||
for platform in [Platform.TELEGRAM, Platform.DISCORD, Platform.SLACK]:
|
||||
runner = _make_runner()
|
||||
with patch("gateway.run.asyncio.create_task", side_effect=lambda c, **kw: (c.close(), MagicMock())[1]):
|
||||
event = _make_event(
|
||||
text="/background test task",
|
||||
platform=platform,
|
||||
)
|
||||
result = await runner._handle_background_command(event)
|
||||
assert "Background task started" in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _run_background_task
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunBackgroundTask:
|
||||
"""Tests for GatewayRunner._run_background_task (the actual execution)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_adapter_returns_silently(self):
|
||||
"""When no adapter is available, the task returns without error."""
|
||||
runner = _make_runner()
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
user_id="12345",
|
||||
chat_id="67890",
|
||||
user_name="testuser",
|
||||
)
|
||||
# No adapters set — should not raise
|
||||
await runner._run_background_task("test prompt", source, "bg_test")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_credentials_sends_error(self):
|
||||
"""When provider credentials are missing, an error is sent."""
|
||||
runner = _make_runner()
|
||||
mock_adapter = AsyncMock()
|
||||
mock_adapter.send = AsyncMock()
|
||||
runner.adapters[Platform.TELEGRAM] = mock_adapter
|
||||
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
user_id="12345",
|
||||
chat_id="67890",
|
||||
user_name="testuser",
|
||||
)
|
||||
|
||||
with patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": None}):
|
||||
await runner._run_background_task("test prompt", source, "bg_test")
|
||||
|
||||
# Should have sent an error message
|
||||
mock_adapter.send.assert_called_once()
|
||||
call_args = mock_adapter.send.call_args
|
||||
assert "failed" in call_args[1].get("content", call_args[0][1] if len(call_args[0]) > 1 else "").lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_task_sends_result(self):
|
||||
"""When the agent completes successfully, the result is sent."""
|
||||
runner = _make_runner()
|
||||
mock_adapter = AsyncMock()
|
||||
mock_adapter.send = AsyncMock()
|
||||
mock_adapter.extract_media = MagicMock(return_value=([], "Hello from background!"))
|
||||
mock_adapter.extract_images = MagicMock(return_value=([], "Hello from background!"))
|
||||
runner.adapters[Platform.TELEGRAM] = mock_adapter
|
||||
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
user_id="12345",
|
||||
chat_id="67890",
|
||||
user_name="testuser",
|
||||
)
|
||||
|
||||
mock_result = {"final_response": "Hello from background!", "messages": []}
|
||||
|
||||
with patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "test-key"}), \
|
||||
patch("run_agent.AIAgent") as MockAgent:
|
||||
mock_agent_instance = MagicMock()
|
||||
mock_agent_instance.run_conversation.return_value = mock_result
|
||||
MockAgent.return_value = mock_agent_instance
|
||||
|
||||
await runner._run_background_task("say hello", source, "bg_test")
|
||||
|
||||
# Should have sent the result
|
||||
mock_adapter.send.assert_called_once()
|
||||
call_args = mock_adapter.send.call_args
|
||||
content = call_args[1].get("content", call_args[0][1] if len(call_args[0]) > 1 else "")
|
||||
assert "Background task complete" in content
|
||||
assert "Hello from background!" in content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exception_sends_error_message(self):
|
||||
"""When the agent raises an exception, an error message is sent."""
|
||||
runner = _make_runner()
|
||||
mock_adapter = AsyncMock()
|
||||
mock_adapter.send = AsyncMock()
|
||||
runner.adapters[Platform.TELEGRAM] = mock_adapter
|
||||
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
user_id="12345",
|
||||
chat_id="67890",
|
||||
user_name="testuser",
|
||||
)
|
||||
|
||||
with patch("gateway.run._resolve_runtime_agent_kwargs", side_effect=RuntimeError("boom")):
|
||||
await runner._run_background_task("test prompt", source, "bg_test")
|
||||
|
||||
mock_adapter.send.assert_called_once()
|
||||
call_args = mock_adapter.send.call_args
|
||||
content = call_args[1].get("content", call_args[0][1] if len(call_args[0]) > 1 else "")
|
||||
assert "failed" in content.lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# /background in help and known_commands
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBackgroundInHelp:
|
||||
"""Verify /background appears in help text and known commands."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_background_in_help_output(self):
|
||||
"""The /help output includes /background."""
|
||||
runner = _make_runner()
|
||||
event = _make_event(text="/help")
|
||||
result = await runner._handle_help_command(event)
|
||||
assert "/background" in result
|
||||
|
||||
def test_background_is_known_command(self):
|
||||
"""The /background command is in GATEWAY_KNOWN_COMMANDS."""
|
||||
from hermes_cli.commands import GATEWAY_KNOWN_COMMANDS
|
||||
assert "background" in GATEWAY_KNOWN_COMMANDS
|
||||
|
||||
def test_bg_alias_is_known_command(self):
|
||||
"""The /bg alias is in GATEWAY_KNOWN_COMMANDS."""
|
||||
from hermes_cli.commands import GATEWAY_KNOWN_COMMANDS
|
||||
assert "bg" in GATEWAY_KNOWN_COMMANDS
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI /background command definition
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBackgroundInCLICommands:
|
||||
"""Verify /background is registered in the CLI command system."""
|
||||
|
||||
def test_background_in_commands_dict(self):
|
||||
"""The /background command is in the COMMANDS dict."""
|
||||
from hermes_cli.commands import COMMANDS
|
||||
assert "/background" in COMMANDS
|
||||
|
||||
def test_bg_alias_in_commands_dict(self):
|
||||
"""The /bg alias is in the COMMANDS dict."""
|
||||
from hermes_cli.commands import COMMANDS
|
||||
assert "/bg" in COMMANDS
|
||||
|
||||
def test_background_in_session_category(self):
|
||||
"""The /background command is in the Session category."""
|
||||
from hermes_cli.commands import COMMANDS_BY_CATEGORY
|
||||
assert "/background" in COMMANDS_BY_CATEGORY["Session"]
|
||||
|
||||
def test_background_autocompletes(self):
|
||||
"""The /background command appears in autocomplete results."""
|
||||
from hermes_cli.commands import SlashCommandCompleter
|
||||
from prompt_toolkit.document import Document
|
||||
|
||||
completer = SlashCommandCompleter()
|
||||
doc = Document("backgro") # Partial match
|
||||
completions = list(completer.get_completions(doc, None))
|
||||
# Text doesn't start with / so no completions
|
||||
assert len(completions) == 0
|
||||
|
||||
doc = Document("/backgro") # With slash prefix
|
||||
completions = list(completer.get_completions(doc, None))
|
||||
cmd_displays = [str(c.display) for c in completions]
|
||||
assert any("/background" in d for d in cmd_displays)
|
||||
|
|
@ -0,0 +1,245 @@
|
|||
"""Tests for configurable background process notification modes.
|
||||
|
||||
The gateway process watcher pushes status updates to users' chats when
|
||||
background terminal commands run. ``display.background_process_notifications``
|
||||
controls verbosity: off | result | error | all (default).
|
||||
|
||||
Contributed by @PeterFile (PR #593), reimplemented on current main.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _FakeRegistry:
|
||||
"""Return pre-canned sessions, then None once exhausted."""
|
||||
|
||||
def __init__(self, sessions):
|
||||
self._sessions = list(sessions)
|
||||
|
||||
def get(self, session_id):
|
||||
if self._sessions:
|
||||
return self._sessions.pop(0)
|
||||
return None
|
||||
|
||||
|
||||
def _build_runner(monkeypatch, tmp_path, mode: str) -> GatewayRunner:
|
||||
"""Create a GatewayRunner with a fake config for the given mode."""
|
||||
(tmp_path / "config.yaml").write_text(
|
||||
f"display:\n background_process_notifications: {mode}\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
import gateway.run as gateway_run
|
||||
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||
|
||||
runner = GatewayRunner(GatewayConfig())
|
||||
adapter = SimpleNamespace(send=AsyncMock())
|
||||
runner.adapters[Platform.TELEGRAM] = adapter
|
||||
return runner
|
||||
|
||||
|
||||
def _watcher_dict(session_id="proc_test", thread_id=""):
|
||||
d = {
|
||||
"session_id": session_id,
|
||||
"check_interval": 0,
|
||||
"platform": "telegram",
|
||||
"chat_id": "123",
|
||||
}
|
||||
if thread_id:
|
||||
d["thread_id"] = thread_id
|
||||
return d
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _load_background_notifications_mode unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestLoadBackgroundNotificationsMode:
|
||||
|
||||
def test_defaults_to_all(self, monkeypatch, tmp_path):
|
||||
import gateway.run as gw
|
||||
monkeypatch.setattr(gw, "_hermes_home", tmp_path)
|
||||
monkeypatch.delenv("HERMES_BACKGROUND_NOTIFICATIONS", raising=False)
|
||||
assert GatewayRunner._load_background_notifications_mode() == "all"
|
||||
|
||||
def test_reads_config_yaml(self, monkeypatch, tmp_path):
|
||||
(tmp_path / "config.yaml").write_text(
|
||||
"display:\n background_process_notifications: error\n"
|
||||
)
|
||||
import gateway.run as gw
|
||||
monkeypatch.setattr(gw, "_hermes_home", tmp_path)
|
||||
monkeypatch.delenv("HERMES_BACKGROUND_NOTIFICATIONS", raising=False)
|
||||
assert GatewayRunner._load_background_notifications_mode() == "error"
|
||||
|
||||
def test_env_var_overrides_config(self, monkeypatch, tmp_path):
|
||||
(tmp_path / "config.yaml").write_text(
|
||||
"display:\n background_process_notifications: error\n"
|
||||
)
|
||||
import gateway.run as gw
|
||||
monkeypatch.setattr(gw, "_hermes_home", tmp_path)
|
||||
monkeypatch.setenv("HERMES_BACKGROUND_NOTIFICATIONS", "off")
|
||||
assert GatewayRunner._load_background_notifications_mode() == "off"
|
||||
|
||||
def test_false_value_maps_to_off(self, monkeypatch, tmp_path):
|
||||
(tmp_path / "config.yaml").write_text(
|
||||
"display:\n background_process_notifications: false\n"
|
||||
)
|
||||
import gateway.run as gw
|
||||
monkeypatch.setattr(gw, "_hermes_home", tmp_path)
|
||||
monkeypatch.delenv("HERMES_BACKGROUND_NOTIFICATIONS", raising=False)
|
||||
assert GatewayRunner._load_background_notifications_mode() == "off"
|
||||
|
||||
def test_invalid_value_defaults_to_all(self, monkeypatch, tmp_path):
|
||||
(tmp_path / "config.yaml").write_text(
|
||||
"display:\n background_process_notifications: banana\n"
|
||||
)
|
||||
import gateway.run as gw
|
||||
monkeypatch.setattr(gw, "_hermes_home", tmp_path)
|
||||
monkeypatch.delenv("HERMES_BACKGROUND_NOTIFICATIONS", raising=False)
|
||||
assert GatewayRunner._load_background_notifications_mode() == "all"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _run_process_watcher integration tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("mode", "sessions", "expected_calls", "expected_fragment"),
|
||||
[
|
||||
# all mode: running output → sends update
|
||||
(
|
||||
"all",
|
||||
[
|
||||
SimpleNamespace(output_buffer="building...\n", exited=False, exit_code=None),
|
||||
None, # process disappears → watcher exits
|
||||
],
|
||||
1,
|
||||
"is still running",
|
||||
),
|
||||
# result mode: running output → no update
|
||||
(
|
||||
"result",
|
||||
[
|
||||
SimpleNamespace(output_buffer="building...\n", exited=False, exit_code=None),
|
||||
None,
|
||||
],
|
||||
0,
|
||||
None,
|
||||
),
|
||||
# off mode: exited process → no notification
|
||||
(
|
||||
"off",
|
||||
[SimpleNamespace(output_buffer="done\n", exited=True, exit_code=0)],
|
||||
0,
|
||||
None,
|
||||
),
|
||||
# result mode: exited → notifies
|
||||
(
|
||||
"result",
|
||||
[SimpleNamespace(output_buffer="done\n", exited=True, exit_code=0)],
|
||||
1,
|
||||
"finished with exit code 0",
|
||||
),
|
||||
# error mode: exit 0 → no notification
|
||||
(
|
||||
"error",
|
||||
[SimpleNamespace(output_buffer="done\n", exited=True, exit_code=0)],
|
||||
0,
|
||||
None,
|
||||
),
|
||||
# error mode: exit 1 → notifies
|
||||
(
|
||||
"error",
|
||||
[SimpleNamespace(output_buffer="traceback\n", exited=True, exit_code=1)],
|
||||
1,
|
||||
"finished with exit code 1",
|
||||
),
|
||||
# all mode: exited → notifies
|
||||
(
|
||||
"all",
|
||||
[SimpleNamespace(output_buffer="ok\n", exited=True, exit_code=0)],
|
||||
1,
|
||||
"finished with exit code 0",
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_run_process_watcher_respects_notification_mode(
|
||||
monkeypatch, tmp_path, mode, sessions, expected_calls, expected_fragment
|
||||
):
|
||||
import tools.process_registry as pr_module
|
||||
|
||||
monkeypatch.setattr(pr_module, "process_registry", _FakeRegistry(sessions))
|
||||
|
||||
# Patch asyncio.sleep to avoid real delays
|
||||
async def _instant_sleep(*_a, **_kw):
|
||||
pass
|
||||
monkeypatch.setattr(asyncio, "sleep", _instant_sleep)
|
||||
|
||||
runner = _build_runner(monkeypatch, tmp_path, mode)
|
||||
adapter = runner.adapters[Platform.TELEGRAM]
|
||||
|
||||
await runner._run_process_watcher(_watcher_dict())
|
||||
|
||||
assert adapter.send.await_count == expected_calls, (
|
||||
f"mode={mode}: expected {expected_calls} sends, got {adapter.send.await_count}"
|
||||
)
|
||||
if expected_fragment is not None:
|
||||
sent_message = adapter.send.await_args.args[1]
|
||||
assert expected_fragment in sent_message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_thread_id_passed_to_send(monkeypatch, tmp_path):
|
||||
"""thread_id from watcher dict is forwarded as metadata to adapter.send()."""
|
||||
import tools.process_registry as pr_module
|
||||
|
||||
sessions = [SimpleNamespace(output_buffer="done\n", exited=True, exit_code=0)]
|
||||
monkeypatch.setattr(pr_module, "process_registry", _FakeRegistry(sessions))
|
||||
|
||||
async def _instant_sleep(*_a, **_kw):
|
||||
pass
|
||||
monkeypatch.setattr(asyncio, "sleep", _instant_sleep)
|
||||
|
||||
runner = _build_runner(monkeypatch, tmp_path, "all")
|
||||
adapter = runner.adapters[Platform.TELEGRAM]
|
||||
|
||||
await runner._run_process_watcher(_watcher_dict(thread_id="42"))
|
||||
|
||||
assert adapter.send.await_count == 1
|
||||
_, kwargs = adapter.send.call_args
|
||||
assert kwargs["metadata"] == {"thread_id": "42"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_thread_id_sends_no_metadata(monkeypatch, tmp_path):
|
||||
"""When thread_id is empty, metadata should be None (general topic)."""
|
||||
import tools.process_registry as pr_module
|
||||
|
||||
sessions = [SimpleNamespace(output_buffer="done\n", exited=True, exit_code=0)]
|
||||
monkeypatch.setattr(pr_module, "process_registry", _FakeRegistry(sessions))
|
||||
|
||||
async def _instant_sleep(*_a, **_kw):
|
||||
pass
|
||||
monkeypatch.setattr(asyncio, "sleep", _instant_sleep)
|
||||
|
||||
runner = _build_runner(monkeypatch, tmp_path, "all")
|
||||
adapter = runner.adapters[Platform.TELEGRAM]
|
||||
|
||||
await runner._run_process_watcher(_watcher_dict())
|
||||
|
||||
assert adapter.send.await_count == 1
|
||||
_, kwargs = adapter.send.call_args
|
||||
assert kwargs["metadata"] is None
|
||||
135
hermes_code/tests/gateway/test_base_topic_sessions.py
Normal file
135
hermes_code/tests/gateway/test_base_topic_sessions.py
Normal file
|
|
@ -0,0 +1,135 @@
|
|||
"""Tests for BasePlatformAdapter topic-aware session handling."""
|
||||
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, SendResult
|
||||
from gateway.session import SessionSource, build_session_key
|
||||
|
||||
|
||||
class DummyTelegramAdapter(BasePlatformAdapter):
|
||||
def __init__(self):
|
||||
super().__init__(PlatformConfig(enabled=True, token="fake-token"), Platform.TELEGRAM)
|
||||
self.sent = []
|
||||
self.typing = []
|
||||
|
||||
async def connect(self) -> bool:
|
||||
return True
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
return None
|
||||
|
||||
async def send(self, chat_id, content, reply_to=None, metadata=None) -> SendResult:
|
||||
self.sent.append(
|
||||
{
|
||||
"chat_id": chat_id,
|
||||
"content": content,
|
||||
"reply_to": reply_to,
|
||||
"metadata": metadata,
|
||||
}
|
||||
)
|
||||
return SendResult(success=True, message_id="1")
|
||||
|
||||
async def send_typing(self, chat_id: str, metadata=None) -> None:
|
||||
self.typing.append({"chat_id": chat_id, "metadata": metadata})
|
||||
return None
|
||||
|
||||
async def get_chat_info(self, chat_id: str):
|
||||
return {"id": chat_id}
|
||||
|
||||
|
||||
def _make_event(chat_id: str, thread_id: str, message_id: str = "1") -> MessageEvent:
|
||||
return MessageEvent(
|
||||
text="hello",
|
||||
source=SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id=chat_id,
|
||||
chat_type="group",
|
||||
thread_id=thread_id,
|
||||
),
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
|
||||
class TestBasePlatformTopicSessions:
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_does_not_interrupt_different_topic(self, monkeypatch):
|
||||
adapter = DummyTelegramAdapter()
|
||||
adapter.set_message_handler(lambda event: asyncio.sleep(0, result=None))
|
||||
|
||||
active_event = _make_event("-1001", "10")
|
||||
adapter._active_sessions[build_session_key(active_event.source)] = asyncio.Event()
|
||||
|
||||
scheduled = []
|
||||
|
||||
def fake_create_task(coro):
|
||||
scheduled.append(coro)
|
||||
coro.close()
|
||||
return SimpleNamespace()
|
||||
|
||||
monkeypatch.setattr(asyncio, "create_task", fake_create_task)
|
||||
|
||||
await adapter.handle_message(_make_event("-1001", "11"))
|
||||
|
||||
assert len(scheduled) == 1
|
||||
assert adapter._pending_messages == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_interrupts_same_topic(self, monkeypatch):
|
||||
adapter = DummyTelegramAdapter()
|
||||
adapter.set_message_handler(lambda event: asyncio.sleep(0, result=None))
|
||||
|
||||
active_event = _make_event("-1001", "10")
|
||||
adapter._active_sessions[build_session_key(active_event.source)] = asyncio.Event()
|
||||
|
||||
scheduled = []
|
||||
|
||||
def fake_create_task(coro):
|
||||
scheduled.append(coro)
|
||||
coro.close()
|
||||
return SimpleNamespace()
|
||||
|
||||
monkeypatch.setattr(asyncio, "create_task", fake_create_task)
|
||||
|
||||
pending_event = _make_event("-1001", "10", message_id="2")
|
||||
await adapter.handle_message(pending_event)
|
||||
|
||||
assert scheduled == []
|
||||
assert adapter.get_pending_message(build_session_key(pending_event.source)) == pending_event
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_message_background_replies_in_same_topic(self):
|
||||
adapter = DummyTelegramAdapter()
|
||||
typing_calls = []
|
||||
|
||||
async def handler(_event):
|
||||
await asyncio.sleep(0)
|
||||
return "ack"
|
||||
|
||||
async def hold_typing(_chat_id, interval=2.0, metadata=None):
|
||||
typing_calls.append({"chat_id": _chat_id, "metadata": metadata})
|
||||
await asyncio.Event().wait()
|
||||
|
||||
adapter.set_message_handler(handler)
|
||||
adapter._keep_typing = hold_typing
|
||||
|
||||
event = _make_event("-1001", "17585")
|
||||
await adapter._process_message_background(event, build_session_key(event.source))
|
||||
|
||||
assert adapter.sent == [
|
||||
{
|
||||
"chat_id": "-1001",
|
||||
"content": "ack",
|
||||
"reply_to": "1",
|
||||
"metadata": {"thread_id": "17585"},
|
||||
}
|
||||
]
|
||||
assert typing_calls == [
|
||||
{
|
||||
"chat_id": "-1001",
|
||||
"metadata": {"thread_id": "17585"},
|
||||
}
|
||||
]
|
||||
252
hermes_code/tests/gateway/test_channel_directory.py
Normal file
252
hermes_code/tests/gateway/test_channel_directory.py
Normal file
|
|
@ -0,0 +1,252 @@
|
|||
"""Tests for gateway/channel_directory.py — channel resolution and display."""
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from gateway.channel_directory import (
|
||||
resolve_channel_name,
|
||||
format_directory_for_display,
|
||||
load_directory,
|
||||
_build_from_sessions,
|
||||
DIRECTORY_PATH,
|
||||
)
|
||||
|
||||
|
||||
def _write_directory(tmp_path, platforms):
|
||||
"""Helper to write a fake channel directory."""
|
||||
data = {"updated_at": "2026-01-01T00:00:00", "platforms": platforms}
|
||||
cache_file = tmp_path / "channel_directory.json"
|
||||
cache_file.write_text(json.dumps(data))
|
||||
return cache_file
|
||||
|
||||
|
||||
class TestLoadDirectory:
|
||||
def test_missing_file(self, tmp_path):
|
||||
with patch("gateway.channel_directory.DIRECTORY_PATH", tmp_path / "nope.json"):
|
||||
result = load_directory()
|
||||
assert result["updated_at"] is None
|
||||
assert result["platforms"] == {}
|
||||
|
||||
def test_valid_file(self, tmp_path):
|
||||
cache_file = _write_directory(tmp_path, {
|
||||
"telegram": [{"id": "123", "name": "John", "type": "dm"}]
|
||||
})
|
||||
with patch("gateway.channel_directory.DIRECTORY_PATH", cache_file):
|
||||
result = load_directory()
|
||||
assert result["platforms"]["telegram"][0]["name"] == "John"
|
||||
|
||||
def test_corrupt_file(self, tmp_path):
|
||||
cache_file = tmp_path / "channel_directory.json"
|
||||
cache_file.write_text("{bad json")
|
||||
with patch("gateway.channel_directory.DIRECTORY_PATH", cache_file):
|
||||
result = load_directory()
|
||||
assert result["updated_at"] is None
|
||||
|
||||
|
||||
class TestResolveChannelName:
|
||||
def _setup(self, tmp_path, platforms):
|
||||
cache_file = _write_directory(tmp_path, platforms)
|
||||
return patch("gateway.channel_directory.DIRECTORY_PATH", cache_file)
|
||||
|
||||
def test_exact_match(self, tmp_path):
|
||||
platforms = {
|
||||
"discord": [
|
||||
{"id": "111", "name": "bot-home", "guild": "MyServer", "type": "channel"},
|
||||
{"id": "222", "name": "general", "guild": "MyServer", "type": "channel"},
|
||||
]
|
||||
}
|
||||
with self._setup(tmp_path, platforms):
|
||||
assert resolve_channel_name("discord", "bot-home") == "111"
|
||||
assert resolve_channel_name("discord", "#bot-home") == "111"
|
||||
|
||||
def test_case_insensitive(self, tmp_path):
|
||||
platforms = {
|
||||
"slack": [{"id": "C01", "name": "Engineering", "type": "channel"}]
|
||||
}
|
||||
with self._setup(tmp_path, platforms):
|
||||
assert resolve_channel_name("slack", "engineering") == "C01"
|
||||
assert resolve_channel_name("slack", "ENGINEERING") == "C01"
|
||||
|
||||
def test_guild_qualified_match(self, tmp_path):
|
||||
platforms = {
|
||||
"discord": [
|
||||
{"id": "111", "name": "general", "guild": "ServerA", "type": "channel"},
|
||||
{"id": "222", "name": "general", "guild": "ServerB", "type": "channel"},
|
||||
]
|
||||
}
|
||||
with self._setup(tmp_path, platforms):
|
||||
assert resolve_channel_name("discord", "ServerA/general") == "111"
|
||||
assert resolve_channel_name("discord", "ServerB/general") == "222"
|
||||
|
||||
def test_prefix_match_unambiguous(self, tmp_path):
|
||||
platforms = {
|
||||
"slack": [
|
||||
{"id": "C01", "name": "engineering-backend", "type": "channel"},
|
||||
{"id": "C02", "name": "design-team", "type": "channel"},
|
||||
]
|
||||
}
|
||||
with self._setup(tmp_path, platforms):
|
||||
# "engineering" prefix matches only one channel
|
||||
assert resolve_channel_name("slack", "engineering") == "C01"
|
||||
|
||||
def test_prefix_match_ambiguous_returns_none(self, tmp_path):
|
||||
platforms = {
|
||||
"slack": [
|
||||
{"id": "C01", "name": "eng-backend", "type": "channel"},
|
||||
{"id": "C02", "name": "eng-frontend", "type": "channel"},
|
||||
]
|
||||
}
|
||||
with self._setup(tmp_path, platforms):
|
||||
assert resolve_channel_name("slack", "eng") is None
|
||||
|
||||
def test_no_channels_returns_none(self, tmp_path):
|
||||
with self._setup(tmp_path, {}):
|
||||
assert resolve_channel_name("telegram", "someone") is None
|
||||
|
||||
def test_no_match_returns_none(self, tmp_path):
|
||||
platforms = {
|
||||
"telegram": [{"id": "123", "name": "John", "type": "dm"}]
|
||||
}
|
||||
with self._setup(tmp_path, platforms):
|
||||
assert resolve_channel_name("telegram", "nonexistent") is None
|
||||
|
||||
def test_topic_name_resolves_to_composite_id(self, tmp_path):
|
||||
platforms = {
|
||||
"telegram": [{"id": "-1001:17585", "name": "Coaching Chat / topic 17585", "type": "group"}]
|
||||
}
|
||||
with self._setup(tmp_path, platforms):
|
||||
assert resolve_channel_name("telegram", "Coaching Chat / topic 17585") == "-1001:17585"
|
||||
|
||||
|
||||
class TestBuildFromSessions:
|
||||
def _write_sessions(self, tmp_path, sessions_data):
|
||||
"""Write sessions.json at the path _build_from_sessions expects."""
|
||||
sessions_path = tmp_path / "sessions" / "sessions.json"
|
||||
sessions_path.parent.mkdir(parents=True)
|
||||
sessions_path.write_text(json.dumps(sessions_data))
|
||||
|
||||
def test_builds_from_sessions_json(self, tmp_path):
|
||||
self._write_sessions(tmp_path, {
|
||||
"session_1": {
|
||||
"origin": {
|
||||
"platform": "telegram",
|
||||
"chat_id": "12345",
|
||||
"chat_name": "Alice",
|
||||
},
|
||||
"chat_type": "dm",
|
||||
},
|
||||
"session_2": {
|
||||
"origin": {
|
||||
"platform": "telegram",
|
||||
"chat_id": "67890",
|
||||
"user_name": "Bob",
|
||||
},
|
||||
"chat_type": "group",
|
||||
},
|
||||
"session_3": {
|
||||
"origin": {
|
||||
"platform": "discord",
|
||||
"chat_id": "99999",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
|
||||
entries = _build_from_sessions("telegram")
|
||||
|
||||
assert len(entries) == 2
|
||||
names = {e["name"] for e in entries}
|
||||
assert "Alice" in names
|
||||
assert "Bob" in names
|
||||
|
||||
def test_missing_sessions_file(self, tmp_path):
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
|
||||
entries = _build_from_sessions("telegram")
|
||||
assert entries == []
|
||||
|
||||
def test_deduplication_by_chat_id(self, tmp_path):
|
||||
self._write_sessions(tmp_path, {
|
||||
"s1": {"origin": {"platform": "telegram", "chat_id": "123", "chat_name": "X"}},
|
||||
"s2": {"origin": {"platform": "telegram", "chat_id": "123", "chat_name": "X"}},
|
||||
})
|
||||
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
|
||||
entries = _build_from_sessions("telegram")
|
||||
|
||||
assert len(entries) == 1
|
||||
|
||||
def test_keeps_distinct_topics_with_same_chat_id(self, tmp_path):
|
||||
self._write_sessions(tmp_path, {
|
||||
"group_root": {
|
||||
"origin": {"platform": "telegram", "chat_id": "-1001", "chat_name": "Coaching Chat"},
|
||||
"chat_type": "group",
|
||||
},
|
||||
"topic_a": {
|
||||
"origin": {
|
||||
"platform": "telegram",
|
||||
"chat_id": "-1001",
|
||||
"chat_name": "Coaching Chat",
|
||||
"thread_id": "17585",
|
||||
},
|
||||
"chat_type": "group",
|
||||
},
|
||||
"topic_b": {
|
||||
"origin": {
|
||||
"platform": "telegram",
|
||||
"chat_id": "-1001",
|
||||
"chat_name": "Coaching Chat",
|
||||
"thread_id": "17587",
|
||||
},
|
||||
"chat_type": "group",
|
||||
},
|
||||
})
|
||||
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
|
||||
entries = _build_from_sessions("telegram")
|
||||
|
||||
ids = {entry["id"] for entry in entries}
|
||||
names = {entry["name"] for entry in entries}
|
||||
assert ids == {"-1001", "-1001:17585", "-1001:17587"}
|
||||
assert "Coaching Chat" in names
|
||||
assert "Coaching Chat / topic 17585" in names
|
||||
assert "Coaching Chat / topic 17587" in names
|
||||
|
||||
|
||||
class TestFormatDirectoryForDisplay:
|
||||
def test_empty_directory(self, tmp_path):
|
||||
with patch("gateway.channel_directory.DIRECTORY_PATH", tmp_path / "nope.json"):
|
||||
result = format_directory_for_display()
|
||||
assert "No messaging platforms" in result
|
||||
|
||||
def test_telegram_display(self, tmp_path):
|
||||
cache_file = _write_directory(tmp_path, {
|
||||
"telegram": [
|
||||
{"id": "123", "name": "Alice", "type": "dm"},
|
||||
{"id": "456", "name": "Dev Group", "type": "group"},
|
||||
{"id": "-1001:17585", "name": "Coaching Chat / topic 17585", "type": "group"},
|
||||
]
|
||||
})
|
||||
with patch("gateway.channel_directory.DIRECTORY_PATH", cache_file):
|
||||
result = format_directory_for_display()
|
||||
|
||||
assert "Telegram:" in result
|
||||
assert "telegram:Alice" in result
|
||||
assert "telegram:Dev Group" in result
|
||||
assert "telegram:Coaching Chat / topic 17585" in result
|
||||
|
||||
def test_discord_grouped_by_guild(self, tmp_path):
|
||||
cache_file = _write_directory(tmp_path, {
|
||||
"discord": [
|
||||
{"id": "1", "name": "general", "guild": "Server1", "type": "channel"},
|
||||
{"id": "2", "name": "bot-home", "guild": "Server1", "type": "channel"},
|
||||
{"id": "3", "name": "chat", "guild": "Server2", "type": "channel"},
|
||||
]
|
||||
})
|
||||
with patch("gateway.channel_directory.DIRECTORY_PATH", cache_file):
|
||||
result = format_directory_for_display()
|
||||
|
||||
assert "Discord (Server1):" in result
|
||||
assert "Discord (Server2):" in result
|
||||
assert "discord:#general" in result
|
||||
194
hermes_code/tests/gateway/test_config.py
Normal file
194
hermes_code/tests/gateway/test_config.py
Normal file
|
|
@ -0,0 +1,194 @@
|
|||
"""Tests for gateway configuration management."""
|
||||
|
||||
from gateway.config import (
|
||||
GatewayConfig,
|
||||
HomeChannel,
|
||||
Platform,
|
||||
PlatformConfig,
|
||||
SessionResetPolicy,
|
||||
load_gateway_config,
|
||||
)
|
||||
|
||||
|
||||
class TestHomeChannelRoundtrip:
|
||||
def test_to_dict_from_dict(self):
|
||||
hc = HomeChannel(platform=Platform.DISCORD, chat_id="999", name="general")
|
||||
d = hc.to_dict()
|
||||
restored = HomeChannel.from_dict(d)
|
||||
|
||||
assert restored.platform == Platform.DISCORD
|
||||
assert restored.chat_id == "999"
|
||||
assert restored.name == "general"
|
||||
|
||||
|
||||
class TestPlatformConfigRoundtrip:
|
||||
def test_to_dict_from_dict(self):
|
||||
pc = PlatformConfig(
|
||||
enabled=True,
|
||||
token="tok_123",
|
||||
home_channel=HomeChannel(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="555",
|
||||
name="Home",
|
||||
),
|
||||
extra={"foo": "bar"},
|
||||
)
|
||||
d = pc.to_dict()
|
||||
restored = PlatformConfig.from_dict(d)
|
||||
|
||||
assert restored.enabled is True
|
||||
assert restored.token == "tok_123"
|
||||
assert restored.home_channel.chat_id == "555"
|
||||
assert restored.extra == {"foo": "bar"}
|
||||
|
||||
def test_disabled_no_token(self):
|
||||
pc = PlatformConfig()
|
||||
d = pc.to_dict()
|
||||
restored = PlatformConfig.from_dict(d)
|
||||
assert restored.enabled is False
|
||||
assert restored.token is None
|
||||
|
||||
|
||||
class TestGetConnectedPlatforms:
|
||||
def test_returns_enabled_with_token(self):
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.TELEGRAM: PlatformConfig(enabled=True, token="t"),
|
||||
Platform.DISCORD: PlatformConfig(enabled=False, token="d"),
|
||||
Platform.SLACK: PlatformConfig(enabled=True), # no token
|
||||
},
|
||||
)
|
||||
connected = config.get_connected_platforms()
|
||||
assert Platform.TELEGRAM in connected
|
||||
assert Platform.DISCORD not in connected
|
||||
assert Platform.SLACK not in connected
|
||||
|
||||
def test_empty_platforms(self):
|
||||
config = GatewayConfig()
|
||||
assert config.get_connected_platforms() == []
|
||||
|
||||
|
||||
class TestSessionResetPolicy:
|
||||
def test_roundtrip(self):
|
||||
policy = SessionResetPolicy(mode="idle", at_hour=6, idle_minutes=120)
|
||||
d = policy.to_dict()
|
||||
restored = SessionResetPolicy.from_dict(d)
|
||||
assert restored.mode == "idle"
|
||||
assert restored.at_hour == 6
|
||||
assert restored.idle_minutes == 120
|
||||
|
||||
def test_defaults(self):
|
||||
policy = SessionResetPolicy()
|
||||
assert policy.mode == "both"
|
||||
assert policy.at_hour == 4
|
||||
assert policy.idle_minutes == 1440
|
||||
|
||||
def test_from_dict_treats_null_values_as_defaults(self):
|
||||
restored = SessionResetPolicy.from_dict(
|
||||
{"mode": None, "at_hour": None, "idle_minutes": None}
|
||||
)
|
||||
assert restored.mode == "both"
|
||||
assert restored.at_hour == 4
|
||||
assert restored.idle_minutes == 1440
|
||||
|
||||
|
||||
class TestGatewayConfigRoundtrip:
|
||||
def test_full_roundtrip(self):
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.TELEGRAM: PlatformConfig(
|
||||
enabled=True,
|
||||
token="tok_123",
|
||||
home_channel=HomeChannel(Platform.TELEGRAM, "123", "Home"),
|
||||
),
|
||||
},
|
||||
reset_triggers=["/new"],
|
||||
quick_commands={"limits": {"type": "exec", "command": "echo ok"}},
|
||||
group_sessions_per_user=False,
|
||||
)
|
||||
d = config.to_dict()
|
||||
restored = GatewayConfig.from_dict(d)
|
||||
|
||||
assert Platform.TELEGRAM in restored.platforms
|
||||
assert restored.platforms[Platform.TELEGRAM].token == "tok_123"
|
||||
assert restored.reset_triggers == ["/new"]
|
||||
assert restored.quick_commands == {"limits": {"type": "exec", "command": "echo ok"}}
|
||||
assert restored.group_sessions_per_user is False
|
||||
|
||||
def test_roundtrip_preserves_unauthorized_dm_behavior(self):
|
||||
config = GatewayConfig(
|
||||
unauthorized_dm_behavior="ignore",
|
||||
platforms={
|
||||
Platform.WHATSAPP: PlatformConfig(
|
||||
enabled=True,
|
||||
extra={"unauthorized_dm_behavior": "pair"},
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
restored = GatewayConfig.from_dict(config.to_dict())
|
||||
|
||||
assert restored.unauthorized_dm_behavior == "ignore"
|
||||
assert restored.platforms[Platform.WHATSAPP].extra["unauthorized_dm_behavior"] == "pair"
|
||||
|
||||
|
||||
class TestLoadGatewayConfig:
|
||||
def test_bridges_quick_commands_from_config_yaml(self, tmp_path, monkeypatch):
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
config_path = hermes_home / "config.yaml"
|
||||
config_path.write_text(
|
||||
"quick_commands:\n"
|
||||
" limits:\n"
|
||||
" type: exec\n"
|
||||
" command: echo ok\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
|
||||
config = load_gateway_config()
|
||||
|
||||
assert config.quick_commands == {"limits": {"type": "exec", "command": "echo ok"}}
|
||||
|
||||
def test_bridges_group_sessions_per_user_from_config_yaml(self, tmp_path, monkeypatch):
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
config_path = hermes_home / "config.yaml"
|
||||
config_path.write_text("group_sessions_per_user: false\n", encoding="utf-8")
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
|
||||
config = load_gateway_config()
|
||||
|
||||
assert config.group_sessions_per_user is False
|
||||
|
||||
def test_invalid_quick_commands_in_config_yaml_are_ignored(self, tmp_path, monkeypatch):
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
config_path = hermes_home / "config.yaml"
|
||||
config_path.write_text("quick_commands: not-a-mapping\n", encoding="utf-8")
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
|
||||
config = load_gateway_config()
|
||||
|
||||
assert config.quick_commands == {}
|
||||
|
||||
def test_bridges_unauthorized_dm_behavior_from_config_yaml(self, tmp_path, monkeypatch):
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
config_path = hermes_home / "config.yaml"
|
||||
config_path.write_text(
|
||||
"unauthorized_dm_behavior: ignore\n"
|
||||
"whatsapp:\n"
|
||||
" unauthorized_dm_behavior: pair\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
|
||||
config = load_gateway_config()
|
||||
|
||||
assert config.unauthorized_dm_behavior == "ignore"
|
||||
assert config.platforms[Platform.WHATSAPP].extra["unauthorized_dm_behavior"] == "pair"
|
||||
148
hermes_code/tests/gateway/test_config_cwd_bridge.py
Normal file
148
hermes_code/tests/gateway/test_config_cwd_bridge.py
Normal file
|
|
@ -0,0 +1,148 @@
|
|||
"""Tests for the config.yaml → env var bridge logic in gateway/run.py.
|
||||
|
||||
Specifically tests that top-level `cwd:` and `backend:` in config.yaml
|
||||
are correctly bridged to TERMINAL_CWD / TERMINAL_ENV env vars as
|
||||
convenience aliases for `terminal.cwd` / `terminal.backend`.
|
||||
|
||||
The bridge logic is module-level code in gateway/run.py, so we test
|
||||
the semantics by reimplementing the relevant config bridge snippet and
|
||||
asserting the expected env var outcomes.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import pytest
|
||||
|
||||
|
||||
def _simulate_config_bridge(cfg: dict, initial_env: dict | None = None):
|
||||
"""Simulate the gateway config bridge logic from gateway/run.py.
|
||||
|
||||
Returns the resulting env dict (only TERMINAL_* and MESSAGING_CWD keys).
|
||||
"""
|
||||
env = dict(initial_env or {})
|
||||
|
||||
# --- Replicate lines 54-56: generic top-level bridge (for context) ---
|
||||
for key, val in cfg.items():
|
||||
if isinstance(val, (str, int, float, bool)) and key not in env:
|
||||
env[key] = str(val)
|
||||
|
||||
# --- Replicate lines 59-87: terminal config bridge ---
|
||||
terminal_cfg = cfg.get("terminal", {})
|
||||
if terminal_cfg and isinstance(terminal_cfg, dict):
|
||||
terminal_env_map = {
|
||||
"backend": "TERMINAL_ENV",
|
||||
"cwd": "TERMINAL_CWD",
|
||||
"timeout": "TERMINAL_TIMEOUT",
|
||||
}
|
||||
for cfg_key, env_var in terminal_env_map.items():
|
||||
if cfg_key in terminal_cfg:
|
||||
val = terminal_cfg[cfg_key]
|
||||
if isinstance(val, list):
|
||||
env[env_var] = json.dumps(val)
|
||||
else:
|
||||
env[env_var] = str(val)
|
||||
|
||||
# --- NEW: top-level aliases (the fix being tested) ---
|
||||
top_level_aliases = {
|
||||
"cwd": "TERMINAL_CWD",
|
||||
"backend": "TERMINAL_ENV",
|
||||
}
|
||||
for alias_key, alias_env in top_level_aliases.items():
|
||||
if alias_env not in env:
|
||||
alias_val = cfg.get(alias_key)
|
||||
if isinstance(alias_val, str) and alias_val.strip():
|
||||
env[alias_env] = alias_val.strip()
|
||||
|
||||
# --- Replicate lines 144-147: MESSAGING_CWD fallback ---
|
||||
configured_cwd = env.get("TERMINAL_CWD", "")
|
||||
if not configured_cwd or configured_cwd in (".", "auto", "cwd"):
|
||||
messaging_cwd = env.get("MESSAGING_CWD") or "/root" # Path.home() for root
|
||||
env["TERMINAL_CWD"] = messaging_cwd
|
||||
|
||||
return env
|
||||
|
||||
|
||||
class TestTopLevelCwdAlias:
|
||||
"""Top-level `cwd:` should be treated as `terminal.cwd`."""
|
||||
|
||||
def test_top_level_cwd_sets_terminal_cwd(self):
|
||||
cfg = {"cwd": "/home/hermes/projects"}
|
||||
result = _simulate_config_bridge(cfg)
|
||||
assert result["TERMINAL_CWD"] == "/home/hermes/projects"
|
||||
|
||||
def test_top_level_backend_sets_terminal_env(self):
|
||||
cfg = {"backend": "docker"}
|
||||
result = _simulate_config_bridge(cfg)
|
||||
assert result["TERMINAL_ENV"] == "docker"
|
||||
|
||||
def test_top_level_cwd_and_backend(self):
|
||||
cfg = {"backend": "local", "cwd": "/home/hermes/projects"}
|
||||
result = _simulate_config_bridge(cfg)
|
||||
assert result["TERMINAL_CWD"] == "/home/hermes/projects"
|
||||
assert result["TERMINAL_ENV"] == "local"
|
||||
|
||||
def test_nested_terminal_takes_precedence_over_top_level(self):
|
||||
"""terminal.cwd should win over top-level cwd."""
|
||||
cfg = {
|
||||
"cwd": "/should/not/use",
|
||||
"terminal": {"cwd": "/home/hermes/real"},
|
||||
}
|
||||
result = _simulate_config_bridge(cfg)
|
||||
assert result["TERMINAL_CWD"] == "/home/hermes/real"
|
||||
|
||||
def test_nested_terminal_backend_takes_precedence(self):
|
||||
cfg = {
|
||||
"backend": "should-not-use",
|
||||
"terminal": {"backend": "docker"},
|
||||
}
|
||||
result = _simulate_config_bridge(cfg)
|
||||
assert result["TERMINAL_ENV"] == "docker"
|
||||
|
||||
def test_no_cwd_falls_back_to_messaging_cwd(self):
|
||||
cfg = {}
|
||||
result = _simulate_config_bridge(cfg, {"MESSAGING_CWD": "/home/hermes/projects"})
|
||||
assert result["TERMINAL_CWD"] == "/home/hermes/projects"
|
||||
|
||||
def test_no_cwd_no_messaging_cwd_falls_back_to_home(self):
|
||||
cfg = {}
|
||||
result = _simulate_config_bridge(cfg)
|
||||
assert result["TERMINAL_CWD"] == "/root" # Path.home() for root user
|
||||
|
||||
def test_dot_cwd_triggers_messaging_fallback(self):
|
||||
"""cwd: '.' should trigger MESSAGING_CWD fallback."""
|
||||
cfg = {"cwd": "."}
|
||||
result = _simulate_config_bridge(cfg, {"MESSAGING_CWD": "/home/hermes"})
|
||||
# "." is stripped but truthy, so it gets set as TERMINAL_CWD
|
||||
# Then the MESSAGING_CWD fallback does NOT trigger since TERMINAL_CWD
|
||||
# is set and not in (".", "auto", "cwd").
|
||||
# Wait — "." IS in the fallback list! So this should fall through.
|
||||
# Actually the alias sets it to ".", then the messaging fallback
|
||||
# checks if it's in (".", "auto", "cwd") and overrides.
|
||||
assert result["TERMINAL_CWD"] == "/home/hermes"
|
||||
|
||||
def test_auto_cwd_triggers_messaging_fallback(self):
|
||||
cfg = {"cwd": "auto"}
|
||||
result = _simulate_config_bridge(cfg, {"MESSAGING_CWD": "/home/hermes"})
|
||||
assert result["TERMINAL_CWD"] == "/home/hermes"
|
||||
|
||||
def test_empty_cwd_ignored(self):
|
||||
cfg = {"cwd": ""}
|
||||
result = _simulate_config_bridge(cfg, {"MESSAGING_CWD": "/home/hermes"})
|
||||
assert result["TERMINAL_CWD"] == "/home/hermes"
|
||||
|
||||
def test_whitespace_only_cwd_ignored(self):
|
||||
cfg = {"cwd": " "}
|
||||
result = _simulate_config_bridge(cfg, {"MESSAGING_CWD": "/fallback"})
|
||||
assert result["TERMINAL_CWD"] == "/fallback"
|
||||
|
||||
def test_messaging_cwd_env_var_works(self):
|
||||
"""MESSAGING_CWD in initial env should be picked up as fallback."""
|
||||
cfg = {}
|
||||
result = _simulate_config_bridge(cfg, {"MESSAGING_CWD": "/home/hermes/projects"})
|
||||
assert result["TERMINAL_CWD"] == "/home/hermes/projects"
|
||||
|
||||
def test_top_level_cwd_beats_messaging_cwd(self):
|
||||
"""Explicit top-level cwd should take precedence over MESSAGING_CWD."""
|
||||
cfg = {"cwd": "/from/config"}
|
||||
result = _simulate_config_bridge(cfg, {"MESSAGING_CWD": "/from/env"})
|
||||
assert result["TERMINAL_CWD"] == "/from/config"
|
||||
96
hermes_code/tests/gateway/test_delivery.py
Normal file
96
hermes_code/tests/gateway/test_delivery.py
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
"""Tests for the delivery routing module."""
|
||||
|
||||
from gateway.config import Platform, GatewayConfig, PlatformConfig, HomeChannel
|
||||
from gateway.delivery import DeliveryRouter, DeliveryTarget, parse_deliver_spec
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
class TestParseTargetPlatformChat:
|
||||
def test_explicit_telegram_chat(self):
|
||||
target = DeliveryTarget.parse("telegram:12345")
|
||||
assert target.platform == Platform.TELEGRAM
|
||||
assert target.chat_id == "12345"
|
||||
assert target.is_explicit is True
|
||||
|
||||
def test_platform_only_no_chat_id(self):
|
||||
target = DeliveryTarget.parse("discord")
|
||||
assert target.platform == Platform.DISCORD
|
||||
assert target.chat_id is None
|
||||
assert target.is_explicit is False
|
||||
|
||||
def test_local_target(self):
|
||||
target = DeliveryTarget.parse("local")
|
||||
assert target.platform == Platform.LOCAL
|
||||
assert target.chat_id is None
|
||||
|
||||
def test_origin_with_source(self):
|
||||
origin = SessionSource(platform=Platform.TELEGRAM, chat_id="789", thread_id="42")
|
||||
target = DeliveryTarget.parse("origin", origin=origin)
|
||||
assert target.platform == Platform.TELEGRAM
|
||||
assert target.chat_id == "789"
|
||||
assert target.thread_id == "42"
|
||||
assert target.is_origin is True
|
||||
|
||||
def test_origin_without_source(self):
|
||||
target = DeliveryTarget.parse("origin")
|
||||
assert target.platform == Platform.LOCAL
|
||||
assert target.is_origin is True
|
||||
|
||||
def test_unknown_platform(self):
|
||||
target = DeliveryTarget.parse("unknown_platform")
|
||||
assert target.platform == Platform.LOCAL
|
||||
|
||||
|
||||
class TestParseDeliverSpec:
|
||||
def test_none_returns_default(self):
|
||||
result = parse_deliver_spec(None)
|
||||
assert result == "origin"
|
||||
|
||||
def test_empty_string_returns_default(self):
|
||||
result = parse_deliver_spec("")
|
||||
assert result == "origin"
|
||||
|
||||
def test_custom_default(self):
|
||||
result = parse_deliver_spec(None, default="local")
|
||||
assert result == "local"
|
||||
|
||||
def test_passthrough_string(self):
|
||||
result = parse_deliver_spec("telegram")
|
||||
assert result == "telegram"
|
||||
|
||||
def test_passthrough_list(self):
|
||||
result = parse_deliver_spec(["local", "telegram"])
|
||||
assert result == ["local", "telegram"]
|
||||
|
||||
|
||||
class TestTargetToStringRoundtrip:
|
||||
def test_origin_roundtrip(self):
|
||||
origin = SessionSource(platform=Platform.TELEGRAM, chat_id="111", thread_id="42")
|
||||
target = DeliveryTarget.parse("origin", origin=origin)
|
||||
assert target.to_string() == "origin"
|
||||
|
||||
def test_local_roundtrip(self):
|
||||
target = DeliveryTarget.parse("local")
|
||||
assert target.to_string() == "local"
|
||||
|
||||
def test_platform_only_roundtrip(self):
|
||||
target = DeliveryTarget.parse("discord")
|
||||
assert target.to_string() == "discord"
|
||||
|
||||
def test_explicit_chat_roundtrip(self):
|
||||
target = DeliveryTarget.parse("telegram:999")
|
||||
s = target.to_string()
|
||||
assert s == "telegram:999"
|
||||
|
||||
reparsed = DeliveryTarget.parse(s)
|
||||
assert reparsed.platform == Platform.TELEGRAM
|
||||
assert reparsed.chat_id == "999"
|
||||
|
||||
|
||||
class TestDeliveryRouter:
|
||||
def test_resolve_targets_does_not_duplicate_local_when_explicit(self):
|
||||
router = DeliveryRouter(GatewayConfig(always_log_local=True))
|
||||
|
||||
targets = router.resolve_targets(["local"])
|
||||
|
||||
assert [target.platform for target in targets] == [Platform.LOCAL]
|
||||
274
hermes_code/tests/gateway/test_dingtalk.py
Normal file
274
hermes_code/tests/gateway/test_dingtalk.py
Normal file
|
|
@ -0,0 +1,274 @@
|
|||
"""Tests for DingTalk platform adapter."""
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch, PropertyMock
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Requirements check
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDingTalkRequirements:
|
||||
|
||||
def test_returns_false_when_sdk_missing(self, monkeypatch):
|
||||
with patch.dict("sys.modules", {"dingtalk_stream": None}):
|
||||
monkeypatch.setattr(
|
||||
"gateway.platforms.dingtalk.DINGTALK_STREAM_AVAILABLE", False
|
||||
)
|
||||
from gateway.platforms.dingtalk import check_dingtalk_requirements
|
||||
assert check_dingtalk_requirements() is False
|
||||
|
||||
def test_returns_false_when_env_vars_missing(self, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"gateway.platforms.dingtalk.DINGTALK_STREAM_AVAILABLE", True
|
||||
)
|
||||
monkeypatch.setattr("gateway.platforms.dingtalk.HTTPX_AVAILABLE", True)
|
||||
monkeypatch.delenv("DINGTALK_CLIENT_ID", raising=False)
|
||||
monkeypatch.delenv("DINGTALK_CLIENT_SECRET", raising=False)
|
||||
from gateway.platforms.dingtalk import check_dingtalk_requirements
|
||||
assert check_dingtalk_requirements() is False
|
||||
|
||||
def test_returns_true_when_all_available(self, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"gateway.platforms.dingtalk.DINGTALK_STREAM_AVAILABLE", True
|
||||
)
|
||||
monkeypatch.setattr("gateway.platforms.dingtalk.HTTPX_AVAILABLE", True)
|
||||
monkeypatch.setenv("DINGTALK_CLIENT_ID", "test-id")
|
||||
monkeypatch.setenv("DINGTALK_CLIENT_SECRET", "test-secret")
|
||||
from gateway.platforms.dingtalk import check_dingtalk_requirements
|
||||
assert check_dingtalk_requirements() is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Adapter construction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDingTalkAdapterInit:
|
||||
|
||||
def test_reads_config_from_extra(self):
|
||||
from gateway.platforms.dingtalk import DingTalkAdapter
|
||||
config = PlatformConfig(
|
||||
enabled=True,
|
||||
extra={"client_id": "cfg-id", "client_secret": "cfg-secret"},
|
||||
)
|
||||
adapter = DingTalkAdapter(config)
|
||||
assert adapter._client_id == "cfg-id"
|
||||
assert adapter._client_secret == "cfg-secret"
|
||||
assert adapter.name == "Dingtalk" # base class uses .title()
|
||||
|
||||
def test_falls_back_to_env_vars(self, monkeypatch):
|
||||
monkeypatch.setenv("DINGTALK_CLIENT_ID", "env-id")
|
||||
monkeypatch.setenv("DINGTALK_CLIENT_SECRET", "env-secret")
|
||||
from gateway.platforms.dingtalk import DingTalkAdapter
|
||||
config = PlatformConfig(enabled=True)
|
||||
adapter = DingTalkAdapter(config)
|
||||
assert adapter._client_id == "env-id"
|
||||
assert adapter._client_secret == "env-secret"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Message text extraction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExtractText:
|
||||
|
||||
def test_extracts_dict_text(self):
|
||||
from gateway.platforms.dingtalk import DingTalkAdapter
|
||||
msg = MagicMock()
|
||||
msg.text = {"content": " hello world "}
|
||||
msg.rich_text = None
|
||||
assert DingTalkAdapter._extract_text(msg) == "hello world"
|
||||
|
||||
def test_extracts_string_text(self):
|
||||
from gateway.platforms.dingtalk import DingTalkAdapter
|
||||
msg = MagicMock()
|
||||
msg.text = "plain text"
|
||||
msg.rich_text = None
|
||||
assert DingTalkAdapter._extract_text(msg) == "plain text"
|
||||
|
||||
def test_falls_back_to_rich_text(self):
|
||||
from gateway.platforms.dingtalk import DingTalkAdapter
|
||||
msg = MagicMock()
|
||||
msg.text = ""
|
||||
msg.rich_text = [{"text": "part1"}, {"text": "part2"}, {"image": "url"}]
|
||||
assert DingTalkAdapter._extract_text(msg) == "part1 part2"
|
||||
|
||||
def test_returns_empty_for_no_content(self):
|
||||
from gateway.platforms.dingtalk import DingTalkAdapter
|
||||
msg = MagicMock()
|
||||
msg.text = ""
|
||||
msg.rich_text = None
|
||||
assert DingTalkAdapter._extract_text(msg) == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Deduplication
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDeduplication:
|
||||
|
||||
def test_first_message_not_duplicate(self):
|
||||
from gateway.platforms.dingtalk import DingTalkAdapter
|
||||
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
|
||||
assert adapter._is_duplicate("msg-1") is False
|
||||
|
||||
def test_second_same_message_is_duplicate(self):
|
||||
from gateway.platforms.dingtalk import DingTalkAdapter
|
||||
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
|
||||
adapter._is_duplicate("msg-1")
|
||||
assert adapter._is_duplicate("msg-1") is True
|
||||
|
||||
def test_different_messages_not_duplicate(self):
|
||||
from gateway.platforms.dingtalk import DingTalkAdapter
|
||||
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
|
||||
adapter._is_duplicate("msg-1")
|
||||
assert adapter._is_duplicate("msg-2") is False
|
||||
|
||||
def test_cache_cleanup_on_overflow(self):
|
||||
from gateway.platforms.dingtalk import DingTalkAdapter, DEDUP_MAX_SIZE
|
||||
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
|
||||
# Fill beyond max
|
||||
for i in range(DEDUP_MAX_SIZE + 10):
|
||||
adapter._is_duplicate(f"msg-{i}")
|
||||
# Cache should have been pruned
|
||||
assert len(adapter._seen_messages) <= DEDUP_MAX_SIZE + 10
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Send
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSend:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_posts_to_webhook(self):
|
||||
from gateway.platforms.dingtalk import DingTalkAdapter
|
||||
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.text = "OK"
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post = AsyncMock(return_value=mock_response)
|
||||
adapter._http_client = mock_client
|
||||
|
||||
result = await adapter.send(
|
||||
"chat-123", "Hello!",
|
||||
metadata={"session_webhook": "https://dingtalk.example/webhook"}
|
||||
)
|
||||
assert result.success is True
|
||||
mock_client.post.assert_called_once()
|
||||
call_args = mock_client.post.call_args
|
||||
assert call_args[0][0] == "https://dingtalk.example/webhook"
|
||||
payload = call_args[1]["json"]
|
||||
assert payload["msgtype"] == "markdown"
|
||||
assert payload["markdown"]["title"] == "Hermes"
|
||||
assert payload["markdown"]["text"] == "Hello!"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_fails_without_webhook(self):
|
||||
from gateway.platforms.dingtalk import DingTalkAdapter
|
||||
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
|
||||
adapter._http_client = AsyncMock()
|
||||
|
||||
result = await adapter.send("chat-123", "Hello!")
|
||||
assert result.success is False
|
||||
assert "session_webhook" in result.error
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_uses_cached_webhook(self):
|
||||
from gateway.platforms.dingtalk import DingTalkAdapter
|
||||
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post = AsyncMock(return_value=mock_response)
|
||||
adapter._http_client = mock_client
|
||||
adapter._session_webhooks["chat-123"] = "https://cached.example/webhook"
|
||||
|
||||
result = await adapter.send("chat-123", "Hello!")
|
||||
assert result.success is True
|
||||
assert mock_client.post.call_args[0][0] == "https://cached.example/webhook"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_handles_http_error(self):
|
||||
from gateway.platforms.dingtalk import DingTalkAdapter
|
||||
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 400
|
||||
mock_response.text = "Bad Request"
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post = AsyncMock(return_value=mock_response)
|
||||
adapter._http_client = mock_client
|
||||
|
||||
result = await adapter.send(
|
||||
"chat-123", "Hello!",
|
||||
metadata={"session_webhook": "https://example/webhook"}
|
||||
)
|
||||
assert result.success is False
|
||||
assert "400" in result.error
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Connect / disconnect
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestConnect:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_fails_without_sdk(self, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"gateway.platforms.dingtalk.DINGTALK_STREAM_AVAILABLE", False
|
||||
)
|
||||
from gateway.platforms.dingtalk import DingTalkAdapter
|
||||
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
|
||||
result = await adapter.connect()
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_fails_without_credentials(self):
|
||||
from gateway.platforms.dingtalk import DingTalkAdapter
|
||||
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
|
||||
adapter._client_id = ""
|
||||
adapter._client_secret = ""
|
||||
result = await adapter.connect()
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_cleans_up(self):
|
||||
from gateway.platforms.dingtalk import DingTalkAdapter
|
||||
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
|
||||
adapter._session_webhooks["a"] = "http://x"
|
||||
adapter._seen_messages["b"] = 1.0
|
||||
adapter._http_client = AsyncMock()
|
||||
adapter._stream_task = None
|
||||
|
||||
await adapter.disconnect()
|
||||
assert len(adapter._session_webhooks) == 0
|
||||
assert len(adapter._seen_messages) == 0
|
||||
assert adapter._http_client is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Platform enum
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPlatformEnum:
|
||||
|
||||
def test_dingtalk_in_platform_enum(self):
|
||||
assert Platform.DINGTALK.value == "dingtalk"
|
||||
117
hermes_code/tests/gateway/test_discord_bot_filter.py
Normal file
117
hermes_code/tests/gateway/test_discord_bot_filter.py
Normal file
|
|
@ -0,0 +1,117 @@
|
|||
"""Tests for Discord bot message filtering (DISCORD_ALLOW_BOTS)."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import unittest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
|
||||
def _make_author(*, bot: bool = False, is_self: bool = False):
|
||||
"""Create a mock Discord author."""
|
||||
author = MagicMock()
|
||||
author.bot = bot
|
||||
author.id = 99999 if is_self else 12345
|
||||
author.name = "TestBot" if bot else "TestUser"
|
||||
author.display_name = author.name
|
||||
return author
|
||||
|
||||
|
||||
def _make_message(*, author=None, content="hello", mentions=None, is_dm=False):
|
||||
"""Create a mock Discord message."""
|
||||
msg = MagicMock()
|
||||
msg.author = author or _make_author()
|
||||
msg.content = content
|
||||
msg.attachments = []
|
||||
msg.mentions = mentions or []
|
||||
if is_dm:
|
||||
import discord
|
||||
msg.channel = MagicMock(spec=discord.DMChannel)
|
||||
msg.channel.id = 111
|
||||
else:
|
||||
msg.channel = MagicMock()
|
||||
msg.channel.id = 222
|
||||
msg.channel.name = "test-channel"
|
||||
msg.channel.guild = MagicMock()
|
||||
msg.channel.guild.name = "TestServer"
|
||||
# Make isinstance checks fail for DMChannel and Thread
|
||||
type(msg.channel).__name__ = "TextChannel"
|
||||
return msg
|
||||
|
||||
|
||||
class TestDiscordBotFilter(unittest.TestCase):
|
||||
"""Test the DISCORD_ALLOW_BOTS filtering logic."""
|
||||
|
||||
def _run_filter(self, message, allow_bots="none", client_user=None):
|
||||
"""Simulate the on_message filter logic and return whether message was accepted."""
|
||||
# Replicate the exact filter logic from discord.py on_message
|
||||
if message.author == client_user:
|
||||
return False # own messages always ignored
|
||||
|
||||
if getattr(message.author, "bot", False):
|
||||
allow = allow_bots.lower().strip()
|
||||
if allow == "none":
|
||||
return False
|
||||
elif allow == "mentions":
|
||||
if not client_user or client_user not in message.mentions:
|
||||
return False
|
||||
# "all" falls through
|
||||
|
||||
return True # message accepted
|
||||
|
||||
def test_own_messages_always_ignored(self):
|
||||
"""Bot's own messages are always ignored regardless of allow_bots."""
|
||||
bot_user = _make_author(is_self=True)
|
||||
msg = _make_message(author=bot_user)
|
||||
self.assertFalse(self._run_filter(msg, "all", bot_user))
|
||||
|
||||
def test_human_messages_always_accepted(self):
|
||||
"""Human messages are always accepted regardless of allow_bots."""
|
||||
human = _make_author(bot=False)
|
||||
msg = _make_message(author=human)
|
||||
self.assertTrue(self._run_filter(msg, "none"))
|
||||
self.assertTrue(self._run_filter(msg, "mentions"))
|
||||
self.assertTrue(self._run_filter(msg, "all"))
|
||||
|
||||
def test_allow_bots_none_rejects_bots(self):
|
||||
"""With allow_bots=none, all other bot messages are rejected."""
|
||||
bot = _make_author(bot=True)
|
||||
msg = _make_message(author=bot)
|
||||
self.assertFalse(self._run_filter(msg, "none"))
|
||||
|
||||
def test_allow_bots_all_accepts_bots(self):
|
||||
"""With allow_bots=all, all bot messages are accepted."""
|
||||
bot = _make_author(bot=True)
|
||||
msg = _make_message(author=bot)
|
||||
self.assertTrue(self._run_filter(msg, "all"))
|
||||
|
||||
def test_allow_bots_mentions_rejects_without_mention(self):
|
||||
"""With allow_bots=mentions, bot messages without @mention are rejected."""
|
||||
our_user = _make_author(is_self=True)
|
||||
bot = _make_author(bot=True)
|
||||
msg = _make_message(author=bot, mentions=[])
|
||||
self.assertFalse(self._run_filter(msg, "mentions", our_user))
|
||||
|
||||
def test_allow_bots_mentions_accepts_with_mention(self):
|
||||
"""With allow_bots=mentions, bot messages with @mention are accepted."""
|
||||
our_user = _make_author(is_self=True)
|
||||
bot = _make_author(bot=True)
|
||||
msg = _make_message(author=bot, mentions=[our_user])
|
||||
self.assertTrue(self._run_filter(msg, "mentions", our_user))
|
||||
|
||||
def test_default_is_none(self):
|
||||
"""Default behavior (no env var) should be 'none'."""
|
||||
default = os.getenv("DISCORD_ALLOW_BOTS", "none")
|
||||
self.assertEqual(default, "none")
|
||||
|
||||
def test_case_insensitive(self):
|
||||
"""Allow_bots value should be case-insensitive."""
|
||||
bot = _make_author(bot=True)
|
||||
msg = _make_message(author=bot)
|
||||
self.assertTrue(self._run_filter(msg, "ALL"))
|
||||
self.assertTrue(self._run_filter(msg, "All"))
|
||||
self.assertFalse(self._run_filter(msg, "NONE"))
|
||||
self.assertFalse(self._run_filter(msg, "None"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
347
hermes_code/tests/gateway/test_discord_document_handling.py
Normal file
347
hermes_code/tests/gateway/test_discord_document_handling.py
Normal file
|
|
@ -0,0 +1,347 @@
|
|||
"""Tests for Discord incoming document/file attachment handling.
|
||||
|
||||
Covers the document branch in DiscordAdapter._handle_message() —
|
||||
the `else` clause of the attachment content-type loop that was added
|
||||
to download, cache, and optionally inject text from non-image/audio files.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime, timezone
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.platforms.base import MessageType
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Discord mock setup (copied from test_discord_free_response.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _ensure_discord_mock():
|
||||
"""Install a mock discord module when discord.py isn't available."""
|
||||
if "discord" in sys.modules and hasattr(sys.modules["discord"], "__file__"):
|
||||
return
|
||||
|
||||
discord_mod = MagicMock()
|
||||
discord_mod.Intents.default.return_value = MagicMock()
|
||||
discord_mod.Client = MagicMock
|
||||
discord_mod.File = MagicMock
|
||||
discord_mod.DMChannel = type("DMChannel", (), {})
|
||||
discord_mod.Thread = type("Thread", (), {})
|
||||
discord_mod.ForumChannel = type("ForumChannel", (), {})
|
||||
discord_mod.ui = SimpleNamespace(View=object, button=lambda *a, **k: (lambda fn: fn), Button=object)
|
||||
discord_mod.ButtonStyle = SimpleNamespace(success=1, primary=2, danger=3, green=1, blurple=2, red=3)
|
||||
discord_mod.Color = SimpleNamespace(orange=lambda: 1, green=lambda: 2, blue=lambda: 3, red=lambda: 4)
|
||||
discord_mod.Interaction = object
|
||||
discord_mod.Embed = MagicMock
|
||||
discord_mod.app_commands = SimpleNamespace(
|
||||
describe=lambda **kwargs: (lambda fn: fn),
|
||||
choices=lambda **kwargs: (lambda fn: fn),
|
||||
Choice=lambda **kwargs: SimpleNamespace(**kwargs),
|
||||
)
|
||||
|
||||
ext_mod = MagicMock()
|
||||
commands_mod = MagicMock()
|
||||
commands_mod.Bot = MagicMock
|
||||
ext_mod.commands = commands_mod
|
||||
|
||||
sys.modules.setdefault("discord", discord_mod)
|
||||
sys.modules.setdefault("discord.ext", ext_mod)
|
||||
sys.modules.setdefault("discord.ext.commands", commands_mod)
|
||||
|
||||
|
||||
_ensure_discord_mock()
|
||||
|
||||
import gateway.platforms.discord as discord_platform # noqa: E402
|
||||
from gateway.platforms.discord import DiscordAdapter # noqa: E402
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fake channel / thread types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class FakeDMChannel:
|
||||
def __init__(self, channel_id: int = 1):
|
||||
self.id = channel_id
|
||||
self.name = "dm"
|
||||
|
||||
|
||||
class FakeThread:
|
||||
def __init__(self, channel_id: int = 10):
|
||||
self.id = channel_id
|
||||
self.name = "thread"
|
||||
self.parent = None
|
||||
self.parent_id = None
|
||||
self.guild = SimpleNamespace(name="TestServer")
|
||||
self.topic = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _redirect_cache(tmp_path, monkeypatch):
|
||||
"""Point document cache to tmp_path so tests never write to ~/.hermes."""
|
||||
monkeypatch.setattr(
|
||||
"gateway.platforms.base.DOCUMENT_CACHE_DIR", tmp_path / "doc_cache"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def adapter(monkeypatch):
|
||||
monkeypatch.setattr(discord_platform.discord, "DMChannel", FakeDMChannel, raising=False)
|
||||
monkeypatch.setattr(discord_platform.discord, "Thread", FakeThread, raising=False)
|
||||
|
||||
config = PlatformConfig(enabled=True, token="fake-token")
|
||||
a = DiscordAdapter(config)
|
||||
a._client = SimpleNamespace(user=SimpleNamespace(id=999))
|
||||
a.handle_message = AsyncMock()
|
||||
return a
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def make_attachment(
|
||||
*,
|
||||
filename: str,
|
||||
content_type: str,
|
||||
size: int = 1024,
|
||||
url: str = "https://cdn.discordapp.com/attachments/fake/file",
|
||||
) -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
filename=filename,
|
||||
content_type=content_type,
|
||||
size=size,
|
||||
url=url,
|
||||
)
|
||||
|
||||
|
||||
def make_message(attachments: list, content: str = "") -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
id=123,
|
||||
content=content,
|
||||
attachments=attachments,
|
||||
mentions=[],
|
||||
reference=None,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
channel=FakeDMChannel(),
|
||||
author=SimpleNamespace(id=42, display_name="Tester", name="Tester"),
|
||||
)
|
||||
|
||||
|
||||
def _mock_aiohttp_download(raw_bytes: bytes):
|
||||
"""Return a patch context manager that makes aiohttp return raw_bytes."""
|
||||
resp = AsyncMock()
|
||||
resp.status = 200
|
||||
resp.read = AsyncMock(return_value=raw_bytes)
|
||||
resp.__aenter__ = AsyncMock(return_value=resp)
|
||||
resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
session = AsyncMock()
|
||||
session.get = MagicMock(return_value=resp)
|
||||
session.__aenter__ = AsyncMock(return_value=session)
|
||||
session.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
return patch("aiohttp.ClientSession", return_value=session)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestIncomingDocumentHandling:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pdf_document_cached(self, adapter):
|
||||
"""A PDF attachment should be downloaded, cached, typed as DOCUMENT."""
|
||||
pdf_bytes = b"%PDF-1.4 fake content"
|
||||
|
||||
with _mock_aiohttp_download(pdf_bytes):
|
||||
msg = make_message([make_attachment(filename="report.pdf", content_type="application/pdf")])
|
||||
await adapter._handle_message(msg)
|
||||
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert event.message_type == MessageType.DOCUMENT
|
||||
assert len(event.media_urls) == 1
|
||||
assert os.path.exists(event.media_urls[0])
|
||||
assert event.media_types == ["application/pdf"]
|
||||
assert "[Content of" not in (event.text or "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_txt_content_injected(self, adapter):
|
||||
""".txt file under 100KB should have its content injected into event.text."""
|
||||
file_content = b"Hello from a text file"
|
||||
|
||||
with _mock_aiohttp_download(file_content):
|
||||
msg = make_message(
|
||||
attachments=[make_attachment(filename="notes.txt", content_type="text/plain")],
|
||||
content="summarize this",
|
||||
)
|
||||
await adapter._handle_message(msg)
|
||||
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert "[Content of notes.txt]:" in event.text
|
||||
assert "Hello from a text file" in event.text
|
||||
assert "summarize this" in event.text
|
||||
# injection prepended before caption
|
||||
assert event.text.index("[Content of") < event.text.index("summarize this")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_md_content_injected(self, adapter):
|
||||
""".md file under 100KB should have its content injected."""
|
||||
file_content = b"# Title\nSome markdown content"
|
||||
|
||||
with _mock_aiohttp_download(file_content):
|
||||
msg = make_message(
|
||||
attachments=[make_attachment(filename="readme.md", content_type="text/markdown")],
|
||||
content="",
|
||||
)
|
||||
await adapter._handle_message(msg)
|
||||
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert "[Content of readme.md]:" in event.text
|
||||
assert "# Title" in event.text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oversized_document_skipped(self, adapter):
|
||||
"""A document over 20MB should be skipped — media_urls stays empty."""
|
||||
msg = make_message([
|
||||
make_attachment(
|
||||
filename="huge.pdf",
|
||||
content_type="application/pdf",
|
||||
size=25 * 1024 * 1024,
|
||||
)
|
||||
])
|
||||
await adapter._handle_message(msg)
|
||||
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert event.media_urls == []
|
||||
# handler must still be called
|
||||
adapter.handle_message.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unsupported_type_skipped(self, adapter):
|
||||
"""An unsupported file type (.zip) should be skipped silently."""
|
||||
msg = make_message([
|
||||
make_attachment(filename="archive.zip", content_type="application/zip")
|
||||
])
|
||||
await adapter._handle_message(msg)
|
||||
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert event.media_urls == []
|
||||
assert event.message_type == MessageType.TEXT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_error_handled(self, adapter):
|
||||
"""If the HTTP download raises, the handler should not crash."""
|
||||
resp = AsyncMock()
|
||||
resp.__aenter__ = AsyncMock(side_effect=RuntimeError("connection reset"))
|
||||
resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
session = AsyncMock()
|
||||
session.get = MagicMock(return_value=resp)
|
||||
session.__aenter__ = AsyncMock(return_value=session)
|
||||
session.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=session):
|
||||
msg = make_message([
|
||||
make_attachment(filename="report.pdf", content_type="application/pdf")
|
||||
])
|
||||
await adapter._handle_message(msg)
|
||||
|
||||
# Must still deliver an event
|
||||
adapter.handle_message.assert_called_once()
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert event.media_urls == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_large_txt_cached_not_injected(self, adapter):
|
||||
""".txt over 100KB should be cached but NOT injected into event.text."""
|
||||
large_content = b"x" * (200 * 1024)
|
||||
|
||||
with _mock_aiohttp_download(large_content):
|
||||
msg = make_message(
|
||||
attachments=[make_attachment(filename="big.txt", content_type="text/plain", size=len(large_content))],
|
||||
content="",
|
||||
)
|
||||
await adapter._handle_message(msg)
|
||||
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert len(event.media_urls) == 1
|
||||
assert os.path.exists(event.media_urls[0])
|
||||
assert "[Content of" not in (event.text or "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_text_files_both_injected(self, adapter):
|
||||
"""Two text file attachments should both be injected into event.text in order."""
|
||||
content1 = b"First file content"
|
||||
content2 = b"Second file content"
|
||||
|
||||
call_count = 0
|
||||
responses = [content1, content2]
|
||||
|
||||
def make_session(_responses):
|
||||
idx = 0
|
||||
|
||||
class FakeSession:
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *_):
|
||||
pass
|
||||
|
||||
def get(self, url, **kwargs):
|
||||
nonlocal idx
|
||||
data = _responses[idx % len(_responses)]
|
||||
idx += 1
|
||||
|
||||
resp = AsyncMock()
|
||||
resp.status = 200
|
||||
resp.read = AsyncMock(return_value=data)
|
||||
resp.__aenter__ = AsyncMock(return_value=resp)
|
||||
resp.__aexit__ = AsyncMock(return_value=False)
|
||||
return resp
|
||||
|
||||
return FakeSession()
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=make_session([content1, content2])):
|
||||
msg = make_message(
|
||||
attachments=[
|
||||
make_attachment(filename="file1.txt", content_type="text/plain"),
|
||||
make_attachment(filename="file2.txt", content_type="text/plain"),
|
||||
],
|
||||
content="",
|
||||
)
|
||||
await adapter._handle_message(msg)
|
||||
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert "[Content of file1.txt]:" in event.text
|
||||
assert "First file content" in event.text
|
||||
assert "[Content of file2.txt]:" in event.text
|
||||
assert "Second file content" in event.text
|
||||
assert event.text.index("file1") < event.text.index("file2")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_attachment_unaffected(self, adapter):
|
||||
"""Image attachments should still go through the image path, not the document path."""
|
||||
with patch(
|
||||
"gateway.platforms.discord.cache_image_from_url",
|
||||
new_callable=AsyncMock,
|
||||
return_value="/tmp/cached_image.png",
|
||||
):
|
||||
msg = make_message([
|
||||
make_attachment(filename="photo.png", content_type="image/png")
|
||||
])
|
||||
await adapter._handle_message(msg)
|
||||
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert event.message_type == MessageType.PHOTO
|
||||
assert event.media_urls == ["/tmp/cached_image.png"]
|
||||
assert event.media_types == ["image/png"]
|
||||
360
hermes_code/tests/gateway/test_discord_free_response.py
Normal file
360
hermes_code/tests/gateway/test_discord_free_response.py
Normal file
|
|
@ -0,0 +1,360 @@
|
|||
"""Tests for Discord free-response defaults and mention gating."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import PlatformConfig
|
||||
|
||||
|
||||
def _ensure_discord_mock():
|
||||
"""Install a mock discord module when discord.py isn't available."""
|
||||
if "discord" in sys.modules and hasattr(sys.modules["discord"], "__file__"):
|
||||
return
|
||||
|
||||
discord_mod = MagicMock()
|
||||
discord_mod.Intents.default.return_value = MagicMock()
|
||||
discord_mod.Client = MagicMock
|
||||
discord_mod.File = MagicMock
|
||||
discord_mod.DMChannel = type("DMChannel", (), {})
|
||||
discord_mod.Thread = type("Thread", (), {})
|
||||
discord_mod.ForumChannel = type("ForumChannel", (), {})
|
||||
discord_mod.ui = SimpleNamespace(View=object, button=lambda *a, **k: (lambda fn: fn), Button=object)
|
||||
discord_mod.ButtonStyle = SimpleNamespace(success=1, primary=2, danger=3, green=1, blurple=2, red=3)
|
||||
discord_mod.Color = SimpleNamespace(orange=lambda: 1, green=lambda: 2, blue=lambda: 3, red=lambda: 4)
|
||||
discord_mod.Interaction = object
|
||||
discord_mod.Embed = MagicMock
|
||||
discord_mod.app_commands = SimpleNamespace(
|
||||
describe=lambda **kwargs: (lambda fn: fn),
|
||||
choices=lambda **kwargs: (lambda fn: fn),
|
||||
Choice=lambda **kwargs: SimpleNamespace(**kwargs),
|
||||
)
|
||||
|
||||
ext_mod = MagicMock()
|
||||
commands_mod = MagicMock()
|
||||
commands_mod.Bot = MagicMock
|
||||
ext_mod.commands = commands_mod
|
||||
|
||||
sys.modules.setdefault("discord", discord_mod)
|
||||
sys.modules.setdefault("discord.ext", ext_mod)
|
||||
sys.modules.setdefault("discord.ext.commands", commands_mod)
|
||||
|
||||
|
||||
_ensure_discord_mock()
|
||||
|
||||
import gateway.platforms.discord as discord_platform # noqa: E402
|
||||
from gateway.platforms.discord import DiscordAdapter # noqa: E402
|
||||
|
||||
|
||||
class FakeDMChannel:
|
||||
def __init__(self, channel_id: int = 1, name: str = "dm"):
|
||||
self.id = channel_id
|
||||
self.name = name
|
||||
|
||||
|
||||
class FakeTextChannel:
|
||||
def __init__(self, channel_id: int = 1, name: str = "general", guild_name: str = "Hermes Server"):
|
||||
self.id = channel_id
|
||||
self.name = name
|
||||
self.guild = SimpleNamespace(name=guild_name)
|
||||
self.topic = None
|
||||
|
||||
|
||||
class FakeForumChannel:
|
||||
def __init__(self, channel_id: int = 1, name: str = "support-forum", guild_name: str = "Hermes Server"):
|
||||
self.id = channel_id
|
||||
self.name = name
|
||||
self.guild = SimpleNamespace(name=guild_name)
|
||||
self.type = 15
|
||||
self.topic = None
|
||||
|
||||
|
||||
class FakeThread:
|
||||
def __init__(self, channel_id: int = 1, name: str = "thread", parent=None, guild_name: str = "Hermes Server"):
|
||||
self.id = channel_id
|
||||
self.name = name
|
||||
self.parent = parent
|
||||
self.parent_id = getattr(parent, "id", None)
|
||||
self.guild = getattr(parent, "guild", None) or SimpleNamespace(name=guild_name)
|
||||
self.topic = None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def adapter(monkeypatch):
|
||||
monkeypatch.setattr(discord_platform.discord, "DMChannel", FakeDMChannel, raising=False)
|
||||
monkeypatch.setattr(discord_platform.discord, "Thread", FakeThread, raising=False)
|
||||
monkeypatch.setattr(discord_platform.discord, "ForumChannel", FakeForumChannel, raising=False)
|
||||
|
||||
config = PlatformConfig(enabled=True, token="fake-token")
|
||||
adapter = DiscordAdapter(config)
|
||||
adapter._client = SimpleNamespace(user=SimpleNamespace(id=999))
|
||||
adapter.handle_message = AsyncMock()
|
||||
return adapter
|
||||
|
||||
|
||||
def make_message(*, channel, content: str, mentions=None):
|
||||
author = SimpleNamespace(id=42, display_name="Jezza", name="Jezza")
|
||||
return SimpleNamespace(
|
||||
id=123,
|
||||
content=content,
|
||||
mentions=list(mentions or []),
|
||||
attachments=[],
|
||||
reference=None,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
channel=channel,
|
||||
author=author,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_defaults_to_require_mention(adapter, monkeypatch):
|
||||
"""Default behavior: require @mention in server channels."""
|
||||
monkeypatch.delenv("DISCORD_REQUIRE_MENTION", raising=False)
|
||||
monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)
|
||||
|
||||
message = make_message(channel=FakeTextChannel(channel_id=123), content="hello from channel")
|
||||
|
||||
await adapter._handle_message(message)
|
||||
|
||||
# Should be ignored — no mention, require_mention defaults to true
|
||||
adapter.handle_message.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_free_response_in_server_channels(adapter, monkeypatch):
|
||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
|
||||
monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)
|
||||
|
||||
message = make_message(channel=FakeTextChannel(channel_id=123), content="hello from channel")
|
||||
|
||||
await adapter._handle_message(message)
|
||||
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
event = adapter.handle_message.await_args.args[0]
|
||||
assert event.text == "hello from channel"
|
||||
assert event.source.chat_id == "123"
|
||||
assert event.source.chat_type == "group"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_free_response_in_threads(adapter, monkeypatch):
|
||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
|
||||
monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)
|
||||
|
||||
thread = FakeThread(channel_id=456, name="Ghost reader skill")
|
||||
message = make_message(channel=thread, content="hello from thread")
|
||||
|
||||
await adapter._handle_message(message)
|
||||
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
event = adapter.handle_message.await_args.args[0]
|
||||
assert event.text == "hello from thread"
|
||||
assert event.source.chat_id == "456"
|
||||
assert event.source.thread_id == "456"
|
||||
assert event.source.chat_type == "thread"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_forum_threads_are_handled_as_threads(adapter, monkeypatch):
|
||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
|
||||
monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)
|
||||
|
||||
forum = FakeForumChannel(channel_id=222, name="support-forum")
|
||||
thread = FakeThread(channel_id=456, name="Can Hermes reply here?", parent=forum)
|
||||
message = make_message(channel=thread, content="hello from forum post")
|
||||
|
||||
await adapter._handle_message(message)
|
||||
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
event = adapter.handle_message.await_args.args[0]
|
||||
assert event.text == "hello from forum post"
|
||||
assert event.source.chat_id == "456"
|
||||
assert event.source.thread_id == "456"
|
||||
assert event.source.chat_type == "thread"
|
||||
assert event.source.chat_name == "Hermes Server / support-forum / Can Hermes reply here?"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_can_still_require_mentions_when_enabled(adapter, monkeypatch):
|
||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true")
|
||||
monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)
|
||||
|
||||
message = make_message(channel=FakeTextChannel(channel_id=789), content="ignored without mention")
|
||||
|
||||
await adapter._handle_message(message)
|
||||
|
||||
adapter.handle_message.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_free_response_channel_overrides_mention_requirement(adapter, monkeypatch):
|
||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true")
|
||||
monkeypatch.setenv("DISCORD_FREE_RESPONSE_CHANNELS", "789,999")
|
||||
|
||||
message = make_message(channel=FakeTextChannel(channel_id=789), content="allowed without mention")
|
||||
|
||||
await adapter._handle_message(message)
|
||||
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
event = adapter.handle_message.await_args.args[0]
|
||||
assert event.text == "allowed without mention"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_forum_parent_in_free_response_list_allows_forum_thread(adapter, monkeypatch):
|
||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true")
|
||||
monkeypatch.setenv("DISCORD_FREE_RESPONSE_CHANNELS", "222")
|
||||
|
||||
forum = FakeForumChannel(channel_id=222, name="support-forum")
|
||||
thread = FakeThread(channel_id=333, name="Forum topic", parent=forum)
|
||||
message = make_message(channel=thread, content="allowed from forum thread")
|
||||
|
||||
await adapter._handle_message(message)
|
||||
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
event = adapter.handle_message.await_args.args[0]
|
||||
assert event.text == "allowed from forum thread"
|
||||
assert event.source.chat_id == "333"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_accepts_and_strips_bot_mentions_when_required(adapter, monkeypatch):
|
||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true")
|
||||
monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)
|
||||
|
||||
bot_user = adapter._client.user
|
||||
message = make_message(
|
||||
channel=FakeTextChannel(channel_id=321),
|
||||
content=f"<@{bot_user.id}> hello with mention",
|
||||
mentions=[bot_user],
|
||||
)
|
||||
|
||||
await adapter._handle_message(message)
|
||||
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
event = adapter.handle_message.await_args.args[0]
|
||||
assert event.text == "hello with mention"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_dms_ignore_mention_requirement(adapter, monkeypatch):
|
||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true")
|
||||
monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)
|
||||
|
||||
message = make_message(channel=FakeDMChannel(channel_id=654), content="dm without mention")
|
||||
|
||||
await adapter._handle_message(message)
|
||||
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
event = adapter.handle_message.await_args.args[0]
|
||||
assert event.text == "dm without mention"
|
||||
assert event.source.chat_type == "dm"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_auto_thread_enabled_by_default(adapter, monkeypatch):
|
||||
"""Auto-threading should be enabled by default (DISCORD_AUTO_THREAD defaults to 'true')."""
|
||||
monkeypatch.delenv("DISCORD_AUTO_THREAD", raising=False)
|
||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
|
||||
|
||||
# Patch _auto_create_thread to return a fake thread
|
||||
fake_thread = FakeThread(channel_id=999, name="auto-thread")
|
||||
adapter._auto_create_thread = AsyncMock(return_value=fake_thread)
|
||||
|
||||
message = make_message(channel=FakeTextChannel(channel_id=123), content="hello")
|
||||
|
||||
await adapter._handle_message(message)
|
||||
|
||||
adapter._auto_create_thread.assert_awaited_once()
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
event = adapter.handle_message.await_args.args[0]
|
||||
assert event.source.chat_type == "thread"
|
||||
assert event.source.thread_id == "999"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_auto_thread_can_be_disabled(adapter, monkeypatch):
|
||||
"""Setting auto_thread to false skips thread creation."""
|
||||
monkeypatch.setenv("DISCORD_AUTO_THREAD", "false")
|
||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
|
||||
|
||||
adapter._auto_create_thread = AsyncMock()
|
||||
|
||||
message = make_message(channel=FakeTextChannel(channel_id=123), content="hello")
|
||||
|
||||
await adapter._handle_message(message)
|
||||
|
||||
adapter._auto_create_thread.assert_not_awaited()
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
event = adapter.handle_message.await_args.args[0]
|
||||
assert event.source.chat_type == "group"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_bot_thread_skips_mention_requirement(adapter, monkeypatch):
|
||||
"""Messages in a thread the bot has participated in should not require @mention."""
|
||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true")
|
||||
monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)
|
||||
monkeypatch.setenv("DISCORD_AUTO_THREAD", "false")
|
||||
|
||||
# Simulate bot having previously participated in thread 456
|
||||
adapter._bot_participated_threads.add("456")
|
||||
|
||||
thread = FakeThread(channel_id=456, name="existing thread")
|
||||
message = make_message(channel=thread, content="follow-up without mention")
|
||||
|
||||
await adapter._handle_message(message)
|
||||
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
event = adapter.handle_message.await_args.args[0]
|
||||
assert event.text == "follow-up without mention"
|
||||
assert event.source.chat_type == "thread"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_unknown_thread_still_requires_mention(adapter, monkeypatch):
|
||||
"""Messages in a thread the bot hasn't participated in should still require @mention."""
|
||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true")
|
||||
monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)
|
||||
monkeypatch.setenv("DISCORD_AUTO_THREAD", "false")
|
||||
|
||||
# Bot has NOT participated in thread 789
|
||||
thread = FakeThread(channel_id=789, name="some thread")
|
||||
message = make_message(channel=thread, content="hello from unknown thread")
|
||||
|
||||
await adapter._handle_message(message)
|
||||
|
||||
adapter.handle_message.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_auto_thread_tracks_participation(adapter, monkeypatch):
|
||||
"""Auto-created threads should be tracked for future mention-free replies."""
|
||||
monkeypatch.delenv("DISCORD_AUTO_THREAD", raising=False)
|
||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
|
||||
|
||||
fake_thread = FakeThread(channel_id=555, name="auto-thread")
|
||||
adapter._auto_create_thread = AsyncMock(return_value=fake_thread)
|
||||
|
||||
message = make_message(channel=FakeTextChannel(channel_id=123), content="start a thread")
|
||||
|
||||
await adapter._handle_message(message)
|
||||
|
||||
assert "555" in adapter._bot_participated_threads
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_thread_participation_tracked_on_dispatch(adapter, monkeypatch):
|
||||
"""When the bot processes a message in a thread, it tracks participation."""
|
||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
|
||||
monkeypatch.setenv("DISCORD_AUTO_THREAD", "false")
|
||||
|
||||
thread = FakeThread(channel_id=777, name="manually created thread")
|
||||
message = make_message(channel=thread, content="hello in thread")
|
||||
|
||||
await adapter._handle_message(message)
|
||||
|
||||
assert "777" in adapter._bot_participated_threads
|
||||
23
hermes_code/tests/gateway/test_discord_imports.py
Normal file
23
hermes_code/tests/gateway/test_discord_imports.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
"""Import-safety tests for the Discord gateway adapter."""
|
||||
|
||||
import builtins
|
||||
import importlib
|
||||
import sys
|
||||
|
||||
|
||||
class TestDiscordImportSafety:
|
||||
def test_module_imports_even_when_discord_dependency_is_missing(self, monkeypatch):
|
||||
original_import = builtins.__import__
|
||||
|
||||
def fake_import(name, globals=None, locals=None, fromlist=(), level=0):
|
||||
if name == "discord" or name.startswith("discord."):
|
||||
raise ImportError("discord unavailable for test")
|
||||
return original_import(name, globals, locals, fromlist, level)
|
||||
|
||||
monkeypatch.delitem(sys.modules, "gateway.platforms.discord", raising=False)
|
||||
monkeypatch.setattr(builtins, "__import__", fake_import)
|
||||
|
||||
module = importlib.import_module("gateway.platforms.discord")
|
||||
|
||||
assert module.DISCORD_AVAILABLE is False
|
||||
assert module.discord is None
|
||||
9
hermes_code/tests/gateway/test_discord_media_metadata.py
Normal file
9
hermes_code/tests/gateway/test_discord_media_metadata.py
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
import inspect
|
||||
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
|
||||
|
||||
def test_discord_media_methods_accept_metadata_kwarg():
|
||||
for method_name in ("send_voice", "send_image_file", "send_image"):
|
||||
signature = inspect.signature(getattr(DiscordAdapter, method_name))
|
||||
assert "metadata" in signature.parameters, method_name
|
||||
44
hermes_code/tests/gateway/test_discord_opus.py
Normal file
44
hermes_code/tests/gateway/test_discord_opus.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
"""Tests for Discord Opus codec loading — must use ctypes.util.find_library."""
|
||||
|
||||
import inspect
|
||||
|
||||
|
||||
class TestOpusFindLibrary:
|
||||
"""Opus loading must try ctypes.util.find_library first, with platform fallback."""
|
||||
|
||||
def test_uses_find_library_first(self):
|
||||
"""find_library must be the primary lookup strategy."""
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
source = inspect.getsource(DiscordAdapter.connect)
|
||||
assert "find_library" in source, \
|
||||
"Opus loading must use ctypes.util.find_library"
|
||||
|
||||
def test_homebrew_fallback_is_conditional(self):
|
||||
"""Homebrew paths must only be tried when find_library returns None."""
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
source = inspect.getsource(DiscordAdapter.connect)
|
||||
# Homebrew fallback must exist
|
||||
assert "/opt/homebrew" in source or "homebrew" in source, \
|
||||
"Opus loading should have macOS Homebrew fallback"
|
||||
# find_library must appear BEFORE any Homebrew path
|
||||
fl_idx = source.index("find_library")
|
||||
hb_idx = source.index("/opt/homebrew")
|
||||
assert fl_idx < hb_idx, \
|
||||
"find_library must be tried before Homebrew fallback paths"
|
||||
# Fallback must be guarded by platform check
|
||||
assert "sys.platform" in source or "darwin" in source, \
|
||||
"Homebrew fallback must be guarded by macOS platform check"
|
||||
|
||||
def test_opus_decode_error_logged(self):
|
||||
"""Opus decode failure must log the error, not silently return."""
|
||||
from gateway.platforms.discord import VoiceReceiver
|
||||
source = inspect.getsource(VoiceReceiver._on_packet)
|
||||
assert "logger" in source, \
|
||||
"_on_packet must log Opus decode errors"
|
||||
# Must not have bare `except Exception:\n return`
|
||||
lines = source.split("\n")
|
||||
for i, line in enumerate(lines):
|
||||
if "except Exception" in line and i + 1 < len(lines):
|
||||
next_line = lines[i + 1].strip()
|
||||
assert next_line != "return", \
|
||||
f"_on_packet has bare 'except Exception: return' at line {i+1}"
|
||||
80
hermes_code/tests/gateway/test_discord_send.py
Normal file
80
hermes_code/tests/gateway/test_discord_send.py
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import PlatformConfig
|
||||
|
||||
|
||||
def _ensure_discord_mock():
|
||||
if "discord" in sys.modules and hasattr(sys.modules["discord"], "__file__"):
|
||||
return
|
||||
|
||||
discord_mod = MagicMock()
|
||||
discord_mod.Intents.default.return_value = MagicMock()
|
||||
discord_mod.Client = MagicMock
|
||||
discord_mod.File = MagicMock
|
||||
discord_mod.DMChannel = type("DMChannel", (), {})
|
||||
discord_mod.Thread = type("Thread", (), {})
|
||||
discord_mod.ForumChannel = type("ForumChannel", (), {})
|
||||
discord_mod.ui = SimpleNamespace(View=object, button=lambda *a, **k: (lambda fn: fn), Button=object)
|
||||
discord_mod.ButtonStyle = SimpleNamespace(success=1, primary=2, danger=3, green=1, blurple=2, red=3)
|
||||
discord_mod.Color = SimpleNamespace(orange=lambda: 1, green=lambda: 2, blue=lambda: 3, red=lambda: 4)
|
||||
discord_mod.Interaction = object
|
||||
discord_mod.Embed = MagicMock
|
||||
discord_mod.app_commands = SimpleNamespace(
|
||||
describe=lambda **kwargs: (lambda fn: fn),
|
||||
choices=lambda **kwargs: (lambda fn: fn),
|
||||
Choice=lambda **kwargs: SimpleNamespace(**kwargs),
|
||||
)
|
||||
|
||||
ext_mod = MagicMock()
|
||||
commands_mod = MagicMock()
|
||||
commands_mod.Bot = MagicMock
|
||||
ext_mod.commands = commands_mod
|
||||
|
||||
sys.modules.setdefault("discord", discord_mod)
|
||||
sys.modules.setdefault("discord.ext", ext_mod)
|
||||
sys.modules.setdefault("discord.ext.commands", commands_mod)
|
||||
|
||||
|
||||
_ensure_discord_mock()
|
||||
|
||||
from gateway.platforms.discord import DiscordAdapter # noqa: E402
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_retries_without_reference_when_reply_target_is_system_message():
|
||||
adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
|
||||
|
||||
ref_msg = SimpleNamespace(id=99)
|
||||
sent_msg = SimpleNamespace(id=1234)
|
||||
send_calls = []
|
||||
|
||||
async def fake_send(*, content, reference=None):
|
||||
send_calls.append({"content": content, "reference": reference})
|
||||
if len(send_calls) == 1:
|
||||
raise RuntimeError(
|
||||
"400 Bad Request (error code: 50035): Invalid Form Body\n"
|
||||
"In message_reference: Cannot reply to a system message"
|
||||
)
|
||||
return sent_msg
|
||||
|
||||
channel = SimpleNamespace(
|
||||
fetch_message=AsyncMock(return_value=ref_msg),
|
||||
send=AsyncMock(side_effect=fake_send),
|
||||
)
|
||||
adapter._client = SimpleNamespace(
|
||||
get_channel=lambda _chat_id: channel,
|
||||
fetch_channel=AsyncMock(),
|
||||
)
|
||||
|
||||
result = await adapter.send("555", "hello", reply_to="99")
|
||||
|
||||
assert result.success is True
|
||||
assert result.message_id == "1234"
|
||||
assert channel.fetch_message.await_count == 1
|
||||
assert channel.send.await_count == 2
|
||||
assert send_calls[0]["reference"] is ref_msg
|
||||
assert send_calls[1]["reference"] is None
|
||||
499
hermes_code/tests/gateway/test_discord_slash_commands.py
Normal file
499
hermes_code/tests/gateway/test_discord_slash_commands.py
Normal file
|
|
@ -0,0 +1,499 @@
|
|||
"""Tests for native Discord slash command fast-paths (thread creation & auto-thread)."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import PlatformConfig
|
||||
|
||||
|
||||
def _ensure_discord_mock():
|
||||
if "discord" in sys.modules and hasattr(sys.modules["discord"], "__file__"):
|
||||
return
|
||||
|
||||
discord_mod = MagicMock()
|
||||
discord_mod.Intents.default.return_value = MagicMock()
|
||||
discord_mod.DMChannel = type("DMChannel", (), {})
|
||||
discord_mod.Thread = type("Thread", (), {})
|
||||
discord_mod.ForumChannel = type("ForumChannel", (), {})
|
||||
discord_mod.Interaction = object
|
||||
discord_mod.app_commands = SimpleNamespace(
|
||||
describe=lambda **kwargs: (lambda fn: fn),
|
||||
choices=lambda **kwargs: (lambda fn: fn),
|
||||
Choice=lambda **kwargs: SimpleNamespace(**kwargs),
|
||||
)
|
||||
|
||||
ext_mod = MagicMock()
|
||||
commands_mod = MagicMock()
|
||||
commands_mod.Bot = MagicMock
|
||||
ext_mod.commands = commands_mod
|
||||
|
||||
sys.modules.setdefault("discord", discord_mod)
|
||||
sys.modules.setdefault("discord.ext", ext_mod)
|
||||
sys.modules.setdefault("discord.ext.commands", commands_mod)
|
||||
|
||||
|
||||
_ensure_discord_mock()
|
||||
|
||||
from gateway.platforms.discord import DiscordAdapter # noqa: E402
|
||||
|
||||
|
||||
class FakeTree:
|
||||
def __init__(self):
|
||||
self.commands = {}
|
||||
|
||||
def command(self, *, name, description):
|
||||
def decorator(fn):
|
||||
self.commands[name] = fn
|
||||
return fn
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def adapter():
|
||||
config = PlatformConfig(enabled=True, token="***")
|
||||
adapter = DiscordAdapter(config)
|
||||
adapter._client = SimpleNamespace(
|
||||
tree=FakeTree(),
|
||||
get_channel=lambda _id: None,
|
||||
fetch_channel=AsyncMock(),
|
||||
user=SimpleNamespace(id=99999, name="HermesBot"),
|
||||
)
|
||||
return adapter
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# /thread slash command registration
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_registers_native_thread_slash_command(adapter):
|
||||
adapter._handle_thread_create_slash = AsyncMock()
|
||||
adapter._register_slash_commands()
|
||||
|
||||
command = adapter._client.tree.commands["thread"]
|
||||
interaction = SimpleNamespace(
|
||||
response=SimpleNamespace(defer=AsyncMock()),
|
||||
)
|
||||
|
||||
await command(interaction, name="Planning", message="", auto_archive_duration=1440)
|
||||
|
||||
interaction.response.defer.assert_awaited_once_with(ephemeral=True)
|
||||
adapter._handle_thread_create_slash.assert_awaited_once_with(interaction, "Planning", "", 1440)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# _handle_thread_create_slash — success, session dispatch, failure
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_thread_create_slash_reports_success(adapter):
|
||||
created_thread = SimpleNamespace(id=555, name="Planning", send=AsyncMock())
|
||||
parent_channel = SimpleNamespace(create_thread=AsyncMock(return_value=created_thread), send=AsyncMock())
|
||||
interaction_channel = SimpleNamespace(parent=parent_channel)
|
||||
interaction = SimpleNamespace(
|
||||
channel=interaction_channel,
|
||||
channel_id=123,
|
||||
user=SimpleNamespace(display_name="Jezza", id=42),
|
||||
guild=SimpleNamespace(name="TestGuild"),
|
||||
followup=SimpleNamespace(send=AsyncMock()),
|
||||
)
|
||||
|
||||
await adapter._handle_thread_create_slash(interaction, "Planning", "Kickoff", 1440)
|
||||
|
||||
parent_channel.create_thread.assert_awaited_once_with(
|
||||
name="Planning",
|
||||
auto_archive_duration=1440,
|
||||
reason="Requested by Jezza via /thread",
|
||||
)
|
||||
created_thread.send.assert_awaited_once_with("Kickoff")
|
||||
# Thread link shown to user
|
||||
interaction.followup.send.assert_awaited()
|
||||
args, kwargs = interaction.followup.send.await_args
|
||||
assert "<#555>" in args[0]
|
||||
assert kwargs["ephemeral"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_thread_create_slash_dispatches_session_when_message_provided(adapter):
|
||||
"""When a message is given, _dispatch_thread_session should be called."""
|
||||
created_thread = SimpleNamespace(id=555, name="Planning", send=AsyncMock())
|
||||
parent_channel = SimpleNamespace(create_thread=AsyncMock(return_value=created_thread))
|
||||
interaction = SimpleNamespace(
|
||||
channel=SimpleNamespace(parent=parent_channel),
|
||||
channel_id=123,
|
||||
user=SimpleNamespace(display_name="Jezza", id=42),
|
||||
guild=SimpleNamespace(name="TestGuild"),
|
||||
followup=SimpleNamespace(send=AsyncMock()),
|
||||
)
|
||||
|
||||
adapter._dispatch_thread_session = AsyncMock()
|
||||
|
||||
await adapter._handle_thread_create_slash(interaction, "Planning", "Hello Hermes", 1440)
|
||||
|
||||
adapter._dispatch_thread_session.assert_awaited_once_with(
|
||||
interaction, "555", "Planning", "Hello Hermes",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_thread_create_slash_no_dispatch_without_message(adapter):
|
||||
"""Without a message, no session dispatch should occur."""
|
||||
created_thread = SimpleNamespace(id=555, name="Planning", send=AsyncMock())
|
||||
parent_channel = SimpleNamespace(create_thread=AsyncMock(return_value=created_thread))
|
||||
interaction = SimpleNamespace(
|
||||
channel=SimpleNamespace(parent=parent_channel),
|
||||
channel_id=123,
|
||||
user=SimpleNamespace(display_name="Jezza", id=42),
|
||||
guild=SimpleNamespace(name="TestGuild"),
|
||||
followup=SimpleNamespace(send=AsyncMock()),
|
||||
)
|
||||
|
||||
adapter._dispatch_thread_session = AsyncMock()
|
||||
|
||||
await adapter._handle_thread_create_slash(interaction, "Planning", "", 1440)
|
||||
|
||||
adapter._dispatch_thread_session.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_thread_create_slash_falls_back_to_seed_message(adapter):
|
||||
created_thread = SimpleNamespace(id=555, name="Planning")
|
||||
seed_message = SimpleNamespace(id=777, create_thread=AsyncMock(return_value=created_thread))
|
||||
channel = SimpleNamespace(
|
||||
create_thread=AsyncMock(side_effect=RuntimeError("direct failed")),
|
||||
send=AsyncMock(return_value=seed_message),
|
||||
)
|
||||
interaction = SimpleNamespace(
|
||||
channel=channel,
|
||||
channel_id=123,
|
||||
user=SimpleNamespace(display_name="Jezza", id=42),
|
||||
guild=SimpleNamespace(name="TestGuild"),
|
||||
followup=SimpleNamespace(send=AsyncMock()),
|
||||
)
|
||||
|
||||
await adapter._handle_thread_create_slash(interaction, "Planning", "Kickoff", 1440)
|
||||
|
||||
channel.send.assert_awaited_once_with("Kickoff")
|
||||
seed_message.create_thread.assert_awaited_once_with(
|
||||
name="Planning",
|
||||
auto_archive_duration=1440,
|
||||
reason="Requested by Jezza via /thread",
|
||||
)
|
||||
interaction.followup.send.assert_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_thread_create_slash_reports_failure(adapter):
|
||||
channel = SimpleNamespace(
|
||||
create_thread=AsyncMock(side_effect=RuntimeError("direct failed")),
|
||||
send=AsyncMock(side_effect=RuntimeError("nope")),
|
||||
)
|
||||
interaction = SimpleNamespace(
|
||||
channel=channel,
|
||||
channel_id=123,
|
||||
user=SimpleNamespace(display_name="Jezza", id=42),
|
||||
followup=SimpleNamespace(send=AsyncMock()),
|
||||
)
|
||||
|
||||
await adapter._handle_thread_create_slash(interaction, "Planning", "", 1440)
|
||||
|
||||
interaction.followup.send.assert_awaited_once()
|
||||
args, kwargs = interaction.followup.send.await_args
|
||||
assert "Failed to create thread:" in args[0]
|
||||
assert "nope" in args[0]
|
||||
assert kwargs["ephemeral"] is True
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# _dispatch_thread_session — builds correct event and routes it
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_thread_session_builds_thread_event(adapter):
|
||||
"""Dispatched event should have chat_type=thread and chat_id=thread_id."""
|
||||
interaction = SimpleNamespace(
|
||||
user=SimpleNamespace(display_name="Jezza", id=42),
|
||||
guild=SimpleNamespace(name="TestGuild"),
|
||||
)
|
||||
|
||||
captured_events = []
|
||||
|
||||
async def capture_handle(event):
|
||||
captured_events.append(event)
|
||||
|
||||
adapter.handle_message = capture_handle
|
||||
|
||||
await adapter._dispatch_thread_session(interaction, "555", "Planning", "Hello!")
|
||||
|
||||
assert len(captured_events) == 1
|
||||
event = captured_events[0]
|
||||
assert event.text == "Hello!"
|
||||
assert event.source.chat_id == "555"
|
||||
assert event.source.chat_type == "thread"
|
||||
assert event.source.thread_id == "555"
|
||||
assert "TestGuild" in event.source.chat_name
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# _build_slash_event — preserve thread context for native slash commands
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_build_slash_event_preserves_thread_context(adapter):
|
||||
interaction = SimpleNamespace(
|
||||
channel=_FakeThreadChannel(channel_id=555, name="Planning"),
|
||||
channel_id=555,
|
||||
user=SimpleNamespace(display_name="Jezza", id=42),
|
||||
)
|
||||
|
||||
event = adapter._build_slash_event(interaction, "/status")
|
||||
|
||||
assert event.text == "/status"
|
||||
assert event.source.chat_id == "555"
|
||||
assert event.source.chat_type == "thread"
|
||||
assert event.source.thread_id == "555"
|
||||
assert "TestGuild" in event.source.chat_name
|
||||
|
||||
|
||||
def test_build_slash_event_uses_group_context_for_channels(adapter):
|
||||
interaction = SimpleNamespace(
|
||||
channel=_FakeTextChannel(channel_id=123, name="general"),
|
||||
channel_id=123,
|
||||
user=SimpleNamespace(display_name="Jezza", id=42),
|
||||
)
|
||||
|
||||
event = adapter._build_slash_event(interaction, "/status")
|
||||
|
||||
assert event.source.chat_id == "123"
|
||||
assert event.source.chat_type == "group"
|
||||
assert event.source.thread_id is None
|
||||
assert "TestGuild / #general" == event.source.chat_name
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Auto-thread: _auto_create_thread
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_create_thread_uses_message_content_as_name(adapter):
|
||||
thread = SimpleNamespace(id=999, name="Hello world")
|
||||
message = SimpleNamespace(
|
||||
content="Hello world, how are you?",
|
||||
create_thread=AsyncMock(return_value=thread),
|
||||
)
|
||||
|
||||
result = await adapter._auto_create_thread(message)
|
||||
|
||||
assert result is thread
|
||||
message.create_thread.assert_awaited_once()
|
||||
call_kwargs = message.create_thread.await_args[1]
|
||||
assert call_kwargs["name"] == "Hello world, how are you?"
|
||||
assert call_kwargs["auto_archive_duration"] == 1440
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_create_thread_truncates_long_names(adapter):
|
||||
long_text = "a" * 200
|
||||
thread = SimpleNamespace(id=999, name="truncated")
|
||||
message = SimpleNamespace(
|
||||
content=long_text,
|
||||
create_thread=AsyncMock(return_value=thread),
|
||||
)
|
||||
|
||||
result = await adapter._auto_create_thread(message)
|
||||
|
||||
assert result is thread
|
||||
call_kwargs = message.create_thread.await_args[1]
|
||||
assert len(call_kwargs["name"]) <= 80
|
||||
assert call_kwargs["name"].endswith("...")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_create_thread_returns_none_on_failure(adapter):
|
||||
message = SimpleNamespace(
|
||||
content="Hello",
|
||||
create_thread=AsyncMock(side_effect=RuntimeError("no perms")),
|
||||
)
|
||||
|
||||
result = await adapter._auto_create_thread(message)
|
||||
assert result is None
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Auto-thread integration in _handle_message
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
import discord as _discord_mod # noqa: E402 — mock or real, used below
|
||||
|
||||
|
||||
class _FakeTextChannel:
|
||||
"""A channel that is NOT a discord.Thread or discord.DMChannel."""
|
||||
|
||||
def __init__(self, channel_id=100, name="general", guild_name="TestGuild"):
|
||||
self.id = channel_id
|
||||
self.name = name
|
||||
self.guild = SimpleNamespace(name=guild_name, id=1)
|
||||
self.topic = None
|
||||
|
||||
|
||||
class _FakeThreadChannel(_discord_mod.Thread):
|
||||
"""isinstance(ch, discord.Thread) → True."""
|
||||
|
||||
def __init__(self, channel_id=200, name="existing-thread", guild_name="TestGuild", parent_id=100):
|
||||
# Don't call super().__init__ — mock Thread is just an empty type
|
||||
self.id = channel_id
|
||||
self.name = name
|
||||
self.guild = SimpleNamespace(name=guild_name, id=1)
|
||||
self.topic = None
|
||||
self.parent = SimpleNamespace(id=parent_id, name="general", guild=SimpleNamespace(name=guild_name, id=1))
|
||||
|
||||
|
||||
def _fake_message(channel, *, content="Hello", author_id=42, display_name="Jezza"):
|
||||
return SimpleNamespace(
|
||||
author=SimpleNamespace(id=author_id, display_name=display_name, bot=False),
|
||||
content=content,
|
||||
channel=channel,
|
||||
attachments=[],
|
||||
mentions=[],
|
||||
reference=None,
|
||||
created_at=None,
|
||||
id=12345,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_thread_creates_thread_and_redirects(adapter, monkeypatch):
|
||||
"""When DISCORD_AUTO_THREAD=true, a new thread is created and the event routes there."""
|
||||
monkeypatch.setenv("DISCORD_AUTO_THREAD", "true")
|
||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
|
||||
|
||||
thread = SimpleNamespace(id=999, name="Hello")
|
||||
adapter._auto_create_thread = AsyncMock(return_value=thread)
|
||||
|
||||
captured_events = []
|
||||
|
||||
async def capture_handle(event):
|
||||
captured_events.append(event)
|
||||
|
||||
adapter.handle_message = capture_handle
|
||||
|
||||
msg = _fake_message(_FakeTextChannel(), content="Hello world")
|
||||
|
||||
await adapter._handle_message(msg)
|
||||
|
||||
adapter._auto_create_thread.assert_awaited_once_with(msg)
|
||||
assert len(captured_events) == 1
|
||||
event = captured_events[0]
|
||||
assert event.source.chat_id == "999" # redirected to thread
|
||||
assert event.source.chat_type == "thread"
|
||||
assert event.source.thread_id == "999"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_thread_enabled_by_default_slash_commands(adapter, monkeypatch):
|
||||
"""Without DISCORD_AUTO_THREAD env var, auto-threading is enabled (default: true)."""
|
||||
monkeypatch.delenv("DISCORD_AUTO_THREAD", raising=False)
|
||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
|
||||
|
||||
fake_thread = _FakeThreadChannel(channel_id=999, name="auto-thread")
|
||||
adapter._auto_create_thread = AsyncMock(return_value=fake_thread)
|
||||
|
||||
captured_events = []
|
||||
|
||||
async def capture_handle(event):
|
||||
captured_events.append(event)
|
||||
|
||||
adapter.handle_message = capture_handle
|
||||
|
||||
msg = _fake_message(_FakeTextChannel())
|
||||
|
||||
await adapter._handle_message(msg)
|
||||
|
||||
adapter._auto_create_thread.assert_awaited_once()
|
||||
assert len(captured_events) == 1
|
||||
assert captured_events[0].source.chat_id == "999" # redirected to thread
|
||||
assert captured_events[0].source.chat_type == "thread"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_thread_can_be_disabled(adapter, monkeypatch):
|
||||
"""Setting DISCORD_AUTO_THREAD=false keeps messages in the channel."""
|
||||
monkeypatch.setenv("DISCORD_AUTO_THREAD", "false")
|
||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
|
||||
|
||||
adapter._auto_create_thread = AsyncMock()
|
||||
|
||||
captured_events = []
|
||||
|
||||
async def capture_handle(event):
|
||||
captured_events.append(event)
|
||||
|
||||
adapter.handle_message = capture_handle
|
||||
|
||||
msg = _fake_message(_FakeTextChannel())
|
||||
|
||||
await adapter._handle_message(msg)
|
||||
|
||||
adapter._auto_create_thread.assert_not_awaited()
|
||||
assert len(captured_events) == 1
|
||||
assert captured_events[0].source.chat_id == "100" # stays in channel
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_thread_skips_threads_and_dms(adapter, monkeypatch):
|
||||
"""Auto-thread should not create threads inside existing threads."""
|
||||
monkeypatch.setenv("DISCORD_AUTO_THREAD", "true")
|
||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
|
||||
|
||||
adapter._auto_create_thread = AsyncMock()
|
||||
|
||||
captured_events = []
|
||||
|
||||
async def capture_handle(event):
|
||||
captured_events.append(event)
|
||||
|
||||
adapter.handle_message = capture_handle
|
||||
|
||||
msg = _fake_message(_FakeThreadChannel())
|
||||
|
||||
await adapter._handle_message(msg)
|
||||
|
||||
adapter._auto_create_thread.assert_not_awaited() # should NOT auto-thread
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Config bridge
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_discord_auto_thread_config_bridge(monkeypatch, tmp_path):
|
||||
"""discord.auto_thread in config.yaml should be bridged to DISCORD_AUTO_THREAD env var."""
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
|
||||
# Write a config.yaml the loader will find
|
||||
hermes_dir = tmp_path / ".hermes"
|
||||
hermes_dir.mkdir()
|
||||
config_path = hermes_dir / "config.yaml"
|
||||
config_path.write_text(yaml.dump({
|
||||
"discord": {"auto_thread": True},
|
||||
}))
|
||||
|
||||
monkeypatch.delenv("DISCORD_AUTO_THREAD", raising=False)
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_dir))
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
|
||||
from gateway.config import load_gateway_config
|
||||
load_gateway_config()
|
||||
|
||||
import os
|
||||
assert os.getenv("DISCORD_AUTO_THREAD") == "true"
|
||||
99
hermes_code/tests/gateway/test_discord_system_messages.py
Normal file
99
hermes_code/tests/gateway/test_discord_system_messages.py
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
"""Tests for Discord system message filtering (thread renames, pins, etc.)."""
|
||||
|
||||
import pytest
|
||||
import unittest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
discord = pytest.importorskip("discord")
|
||||
|
||||
|
||||
def _make_author(*, bot: bool = False, is_self: bool = False):
|
||||
"""Create a mock Discord author."""
|
||||
author = MagicMock()
|
||||
author.bot = bot
|
||||
author.id = 99999 if is_self else 12345
|
||||
author.name = "TestBot" if bot else "TestUser"
|
||||
author.display_name = author.name
|
||||
return author
|
||||
|
||||
|
||||
def _make_message(*, author=None, content="hello", msg_type=None):
|
||||
"""Create a mock Discord message with a specific type."""
|
||||
msg = MagicMock()
|
||||
msg.author = author or _make_author()
|
||||
msg.content = content
|
||||
msg.attachments = []
|
||||
msg.mentions = []
|
||||
msg.type = msg_type if msg_type is not None else discord.MessageType.default
|
||||
msg.channel = MagicMock()
|
||||
msg.channel.id = 222
|
||||
msg.channel.name = "test-channel"
|
||||
msg.channel.guild = MagicMock()
|
||||
msg.channel.guild.name = "TestServer"
|
||||
return msg
|
||||
|
||||
|
||||
class TestDiscordSystemMessageFilter(unittest.TestCase):
|
||||
"""Test that Discord system messages (thread renames, pins, etc.) are ignored."""
|
||||
|
||||
def _run_filter(self, message, client_user=None):
|
||||
"""Simulate the on_message filter logic and return whether message was accepted.
|
||||
|
||||
Replicates the guard added to discord.py:
|
||||
if message.type not in (discord.MessageType.default, discord.MessageType.reply):
|
||||
return # ignored
|
||||
"""
|
||||
# Own messages always ignored
|
||||
if message.author == client_user:
|
||||
return False
|
||||
|
||||
# System message filter (the fix being tested)
|
||||
if message.type not in (discord.MessageType.default, discord.MessageType.reply):
|
||||
return False
|
||||
|
||||
return True # message accepted
|
||||
|
||||
def test_default_messages_accepted(self):
|
||||
"""Regular user messages (type=default) should be accepted."""
|
||||
msg = _make_message(msg_type=discord.MessageType.default)
|
||||
self.assertTrue(self._run_filter(msg))
|
||||
|
||||
def test_reply_messages_accepted(self):
|
||||
"""Reply messages (type=reply) should be accepted — users reply to bot messages."""
|
||||
msg = _make_message(msg_type=discord.MessageType.reply)
|
||||
self.assertTrue(self._run_filter(msg))
|
||||
|
||||
def test_thread_rename_ignored(self):
|
||||
"""Thread rename system messages should be ignored."""
|
||||
msg = _make_message(msg_type=discord.MessageType.channel_name_change)
|
||||
self.assertFalse(self._run_filter(msg))
|
||||
|
||||
def test_pins_add_ignored(self):
|
||||
"""Pin notifications should be ignored."""
|
||||
msg = _make_message(msg_type=discord.MessageType.pins_add)
|
||||
self.assertFalse(self._run_filter(msg))
|
||||
|
||||
def test_new_member_ignored(self):
|
||||
"""New member join messages should be ignored."""
|
||||
msg = _make_message(msg_type=discord.MessageType.new_member)
|
||||
self.assertFalse(self._run_filter(msg))
|
||||
|
||||
def test_premium_guild_subscription_ignored(self):
|
||||
"""Boost messages should be ignored."""
|
||||
msg = _make_message(msg_type=discord.MessageType.premium_guild_subscription)
|
||||
self.assertFalse(self._run_filter(msg))
|
||||
|
||||
def test_recipient_add_ignored(self):
|
||||
"""Group DM recipient add messages should be ignored."""
|
||||
msg = _make_message(msg_type=discord.MessageType.recipient_add)
|
||||
self.assertFalse(self._run_filter(msg))
|
||||
|
||||
def test_own_default_messages_still_ignored(self):
|
||||
"""Bot's own messages should still be ignored even if type is default."""
|
||||
bot_user = _make_author(is_self=True)
|
||||
msg = _make_message(author=bot_user, msg_type=discord.MessageType.default)
|
||||
self.assertFalse(self._run_filter(msg, client_user=bot_user))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
83
hermes_code/tests/gateway/test_discord_thread_persistence.py
Normal file
83
hermes_code/tests/gateway/test_discord_thread_persistence.py
Normal file
|
|
@ -0,0 +1,83 @@
|
|||
"""Tests for Discord thread participation persistence.
|
||||
|
||||
Verifies that _bot_participated_threads survives adapter restarts by
|
||||
being persisted to ~/.hermes/discord_threads.json.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestDiscordThreadPersistence:
|
||||
"""Thread IDs are saved to disk and reloaded on init."""
|
||||
|
||||
def _make_adapter(self, tmp_path):
|
||||
"""Build a minimal DiscordAdapter with HERMES_HOME pointed at tmp_path."""
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
|
||||
config = PlatformConfig(enabled=True, token="test-token")
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
|
||||
return DiscordAdapter(config=config)
|
||||
|
||||
def test_starts_empty_when_no_state_file(self, tmp_path):
|
||||
adapter = self._make_adapter(tmp_path)
|
||||
assert adapter._bot_participated_threads == set()
|
||||
|
||||
def test_track_thread_persists_to_disk(self, tmp_path):
|
||||
adapter = self._make_adapter(tmp_path)
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
|
||||
adapter._track_thread("111")
|
||||
adapter._track_thread("222")
|
||||
|
||||
state_file = tmp_path / "discord_threads.json"
|
||||
assert state_file.exists()
|
||||
saved = json.loads(state_file.read_text())
|
||||
assert set(saved) == {"111", "222"}
|
||||
|
||||
def test_threads_survive_restart(self, tmp_path):
|
||||
"""Threads tracked by one adapter instance are visible to the next."""
|
||||
adapter1 = self._make_adapter(tmp_path)
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
|
||||
adapter1._track_thread("aaa")
|
||||
adapter1._track_thread("bbb")
|
||||
|
||||
adapter2 = self._make_adapter(tmp_path)
|
||||
assert "aaa" in adapter2._bot_participated_threads
|
||||
assert "bbb" in adapter2._bot_participated_threads
|
||||
|
||||
def test_duplicate_track_does_not_double_save(self, tmp_path):
|
||||
adapter = self._make_adapter(tmp_path)
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
|
||||
adapter._track_thread("111")
|
||||
adapter._track_thread("111") # no-op
|
||||
|
||||
saved = json.loads((tmp_path / "discord_threads.json").read_text())
|
||||
assert saved.count("111") == 1
|
||||
|
||||
def test_caps_at_max_tracked_threads(self, tmp_path):
|
||||
adapter = self._make_adapter(tmp_path)
|
||||
adapter._MAX_TRACKED_THREADS = 5
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
|
||||
for i in range(10):
|
||||
adapter._track_thread(str(i))
|
||||
|
||||
assert len(adapter._bot_participated_threads) == 5
|
||||
|
||||
def test_corrupted_state_file_falls_back_to_empty(self, tmp_path):
|
||||
state_file = tmp_path / "discord_threads.json"
|
||||
state_file.write_text("not valid json{{{")
|
||||
adapter = self._make_adapter(tmp_path)
|
||||
assert adapter._bot_participated_threads == set()
|
||||
|
||||
def test_missing_hermes_home_does_not_crash(self, tmp_path):
|
||||
"""Load/save tolerate missing directories."""
|
||||
fake_home = tmp_path / "nonexistent" / "deep"
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(fake_home)}):
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
# _load should return empty set, not crash
|
||||
threads = DiscordAdapter._load_participated_threads()
|
||||
assert threads == set()
|
||||
157
hermes_code/tests/gateway/test_document_cache.py
Normal file
157
hermes_code/tests/gateway/test_document_cache.py
Normal file
|
|
@ -0,0 +1,157 @@
|
|||
"""
|
||||
Tests for document cache utilities in gateway/platforms/base.py.
|
||||
|
||||
Covers: get_document_cache_dir, cache_document_from_bytes,
|
||||
cleanup_document_cache, SUPPORTED_DOCUMENT_TYPES.
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.platforms.base import (
|
||||
SUPPORTED_DOCUMENT_TYPES,
|
||||
cache_document_from_bytes,
|
||||
cleanup_document_cache,
|
||||
get_document_cache_dir,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixture: redirect DOCUMENT_CACHE_DIR to a temp directory for every test
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _redirect_cache(tmp_path, monkeypatch):
|
||||
"""Point the module-level DOCUMENT_CACHE_DIR to a fresh tmp_path."""
|
||||
monkeypatch.setattr(
|
||||
"gateway.platforms.base.DOCUMENT_CACHE_DIR", tmp_path / "doc_cache"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestGetDocumentCacheDir
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGetDocumentCacheDir:
|
||||
def test_creates_directory(self, tmp_path):
|
||||
cache_dir = get_document_cache_dir()
|
||||
assert cache_dir.exists()
|
||||
assert cache_dir.is_dir()
|
||||
|
||||
def test_returns_existing_directory(self):
|
||||
first = get_document_cache_dir()
|
||||
second = get_document_cache_dir()
|
||||
assert first == second
|
||||
assert first.exists()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestCacheDocumentFromBytes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCacheDocumentFromBytes:
|
||||
def test_basic_caching(self):
|
||||
data = b"hello world"
|
||||
path = cache_document_from_bytes(data, "test.txt")
|
||||
assert os.path.exists(path)
|
||||
assert Path(path).read_bytes() == data
|
||||
|
||||
def test_filename_preserved_in_path(self):
|
||||
path = cache_document_from_bytes(b"data", "report.pdf")
|
||||
assert "report.pdf" in os.path.basename(path)
|
||||
|
||||
def test_empty_filename_uses_fallback(self):
|
||||
path = cache_document_from_bytes(b"data", "")
|
||||
assert "document" in os.path.basename(path)
|
||||
|
||||
def test_unique_filenames(self):
|
||||
p1 = cache_document_from_bytes(b"a", "same.txt")
|
||||
p2 = cache_document_from_bytes(b"b", "same.txt")
|
||||
assert p1 != p2
|
||||
|
||||
def test_path_traversal_blocked(self):
|
||||
"""Malicious directory components are stripped — only the leaf name survives."""
|
||||
path = cache_document_from_bytes(b"data", "../../etc/passwd")
|
||||
basename = os.path.basename(path)
|
||||
assert "passwd" in basename
|
||||
# Must NOT contain directory separators
|
||||
assert ".." not in basename
|
||||
# File must reside inside the cache directory
|
||||
cache_dir = get_document_cache_dir()
|
||||
assert Path(path).resolve().is_relative_to(cache_dir.resolve())
|
||||
|
||||
def test_null_bytes_stripped(self):
|
||||
path = cache_document_from_bytes(b"data", "file\x00.pdf")
|
||||
basename = os.path.basename(path)
|
||||
assert "\x00" not in basename
|
||||
assert "file.pdf" in basename
|
||||
|
||||
def test_dot_dot_filename_handled(self):
|
||||
"""A filename that is literally '..' falls back to 'document'."""
|
||||
path = cache_document_from_bytes(b"data", "..")
|
||||
basename = os.path.basename(path)
|
||||
assert "document" in basename
|
||||
|
||||
def test_none_filename_uses_fallback(self):
|
||||
path = cache_document_from_bytes(b"data", None)
|
||||
assert "document" in os.path.basename(path)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestCleanupDocumentCache
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCleanupDocumentCache:
|
||||
def test_removes_old_files(self, tmp_path):
|
||||
cache_dir = get_document_cache_dir()
|
||||
old_file = cache_dir / "old.txt"
|
||||
old_file.write_text("old")
|
||||
# Set modification time to 48 hours ago
|
||||
old_mtime = time.time() - 48 * 3600
|
||||
os.utime(old_file, (old_mtime, old_mtime))
|
||||
|
||||
removed = cleanup_document_cache(max_age_hours=24)
|
||||
assert removed == 1
|
||||
assert not old_file.exists()
|
||||
|
||||
def test_keeps_recent_files(self):
|
||||
cache_dir = get_document_cache_dir()
|
||||
recent = cache_dir / "recent.txt"
|
||||
recent.write_text("fresh")
|
||||
|
||||
removed = cleanup_document_cache(max_age_hours=24)
|
||||
assert removed == 0
|
||||
assert recent.exists()
|
||||
|
||||
def test_returns_removed_count(self):
|
||||
cache_dir = get_document_cache_dir()
|
||||
old_time = time.time() - 48 * 3600
|
||||
for i in range(3):
|
||||
f = cache_dir / f"old_{i}.txt"
|
||||
f.write_text("x")
|
||||
os.utime(f, (old_time, old_time))
|
||||
|
||||
assert cleanup_document_cache(max_age_hours=24) == 3
|
||||
|
||||
def test_empty_cache_dir(self):
|
||||
assert cleanup_document_cache(max_age_hours=24) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestSupportedDocumentTypes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSupportedDocumentTypes:
|
||||
def test_all_extensions_have_mime_types(self):
|
||||
for ext, mime in SUPPORTED_DOCUMENT_TYPES.items():
|
||||
assert ext.startswith("."), f"{ext} missing leading dot"
|
||||
assert "/" in mime, f"{mime} is not a valid MIME type"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ext",
|
||||
[".pdf", ".md", ".txt", ".docx", ".xlsx", ".pptx"],
|
||||
)
|
||||
def test_expected_extensions_present(self, ext):
|
||||
assert ext in SUPPORTED_DOCUMENT_TYPES
|
||||
1061
hermes_code/tests/gateway/test_email.py
Normal file
1061
hermes_code/tests/gateway/test_email.py
Normal file
File diff suppressed because it is too large
Load diff
317
hermes_code/tests/gateway/test_extract_local_files.py
Normal file
317
hermes_code/tests/gateway/test_extract_local_files.py
Normal file
|
|
@ -0,0 +1,317 @@
|
|||
"""
|
||||
Tests for extract_local_files() — auto-detection of bare local file paths
|
||||
in model response text for native media delivery.
|
||||
|
||||
Covers: path matching, code-block exclusion, URL rejection, tilde expansion,
|
||||
deduplication, text cleanup, and extension routing.
|
||||
|
||||
Based on PR #1636 by sudoingX (salvaged + hardened).
|
||||
"""
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.platforms.base import BasePlatformAdapter
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _extract(content: str, existing_files: set[str] | None = None):
|
||||
"""
|
||||
Run extract_local_files with os.path.isfile mocked to return True
|
||||
for any path in *existing_files* (expanded form). If *existing_files*
|
||||
is None every path passes.
|
||||
"""
|
||||
existing = existing_files
|
||||
|
||||
def fake_isfile(p):
|
||||
if existing is None:
|
||||
return True
|
||||
return p in existing
|
||||
|
||||
def fake_expanduser(p):
|
||||
if p.startswith("~/"):
|
||||
return "/home/user" + p[1:]
|
||||
return p
|
||||
|
||||
with patch("os.path.isfile", side_effect=fake_isfile), \
|
||||
patch("os.path.expanduser", side_effect=fake_expanduser):
|
||||
return BasePlatformAdapter.extract_local_files(content)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Basic detection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBasicDetection:
|
||||
|
||||
def test_absolute_path_image(self):
|
||||
paths, cleaned = _extract("Here is the screenshot /root/screenshots/game.png enjoy")
|
||||
assert paths == ["/root/screenshots/game.png"]
|
||||
assert "/root/screenshots/game.png" not in cleaned
|
||||
assert "Here is the screenshot" in cleaned
|
||||
|
||||
def test_tilde_path_image(self):
|
||||
paths, cleaned = _extract("Check out ~/photos/cat.jpg for the cat")
|
||||
assert paths == ["/home/user/photos/cat.jpg"]
|
||||
assert "~/photos/cat.jpg" not in cleaned
|
||||
|
||||
def test_video_extensions(self):
|
||||
for ext in (".mp4", ".mov", ".avi", ".mkv", ".webm"):
|
||||
text = f"Video at /tmp/clip{ext} here"
|
||||
paths, _ = _extract(text)
|
||||
assert len(paths) == 1, f"Failed for {ext}"
|
||||
assert paths[0] == f"/tmp/clip{ext}"
|
||||
|
||||
def test_image_extensions(self):
|
||||
for ext in (".png", ".jpg", ".jpeg", ".gif", ".webp"):
|
||||
text = f"Image at /tmp/pic{ext} here"
|
||||
paths, _ = _extract(text)
|
||||
assert len(paths) == 1, f"Failed for {ext}"
|
||||
assert paths[0] == f"/tmp/pic{ext}"
|
||||
|
||||
def test_case_insensitive_extension(self):
|
||||
paths, _ = _extract("See /tmp/PHOTO.PNG and /tmp/vid.MP4 now")
|
||||
assert len(paths) == 2
|
||||
|
||||
def test_multiple_paths(self):
|
||||
text = "First /tmp/a.png then /tmp/b.jpg and /tmp/c.mp4 done"
|
||||
paths, cleaned = _extract(text)
|
||||
assert len(paths) == 3
|
||||
assert "/tmp/a.png" in paths
|
||||
assert "/tmp/b.jpg" in paths
|
||||
assert "/tmp/c.mp4" in paths
|
||||
for p in paths:
|
||||
assert p not in cleaned
|
||||
|
||||
def test_path_at_line_start(self):
|
||||
paths, _ = _extract("/var/data/image.png")
|
||||
assert paths == ["/var/data/image.png"]
|
||||
|
||||
def test_path_at_end_of_line(self):
|
||||
paths, _ = _extract("saved to /var/data/image.png")
|
||||
assert paths == ["/var/data/image.png"]
|
||||
|
||||
def test_path_with_dots_in_directory(self):
|
||||
paths, _ = _extract("See /opt/my.app/assets/logo.png here")
|
||||
assert paths == ["/opt/my.app/assets/logo.png"]
|
||||
|
||||
def test_path_with_hyphens(self):
|
||||
paths, _ = _extract("File at /tmp/my-screenshot-2024.png done")
|
||||
assert paths == ["/tmp/my-screenshot-2024.png"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Non-existent files are skipped
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestIsfileGuard:
|
||||
|
||||
def test_nonexistent_path_skipped(self):
|
||||
"""Paths that don't exist on disk are not extracted."""
|
||||
paths, cleaned = _extract(
|
||||
"See /tmp/nope.png here",
|
||||
existing_files=set(), # nothing exists
|
||||
)
|
||||
assert paths == []
|
||||
assert "/tmp/nope.png" in cleaned # not stripped
|
||||
|
||||
def test_only_existing_paths_extracted(self):
|
||||
"""Mix of existing and non-existing — only existing are returned."""
|
||||
paths, cleaned = _extract(
|
||||
"A /tmp/real.png and /tmp/fake.jpg end",
|
||||
existing_files={"/tmp/real.png"},
|
||||
)
|
||||
assert paths == ["/tmp/real.png"]
|
||||
assert "/tmp/real.png" not in cleaned
|
||||
assert "/tmp/fake.jpg" in cleaned
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# URL false-positive prevention
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestURLRejection:
|
||||
|
||||
def test_https_url_not_matched(self):
|
||||
"""Paths embedded in HTTP URLs must not be extracted."""
|
||||
paths, cleaned = _extract("Visit https://example.com/images/photo.png for details")
|
||||
# The regex lookbehind should prevent matching the URL's path segment
|
||||
# Even if it did match, isfile would be False for /images/photo.png
|
||||
# (we mock isfile to True-for-all here, so the lookbehind is the guard)
|
||||
assert paths == []
|
||||
assert "https://example.com/images/photo.png" in cleaned
|
||||
|
||||
def test_http_url_not_matched(self):
|
||||
paths, _ = _extract("See http://cdn.example.com/assets/banner.jpg here")
|
||||
assert paths == []
|
||||
|
||||
def test_file_url_not_matched(self):
|
||||
paths, _ = _extract("Open file:///home/user/doc.png in browser")
|
||||
# file:// has :// before /home so lookbehind blocks it
|
||||
assert paths == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Code block exclusion
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCodeBlockExclusion:
|
||||
|
||||
def test_fenced_code_block_skipped(self):
|
||||
text = "Here's how:\n```python\nimg = open('/tmp/image.png')\n```\nDone."
|
||||
paths, cleaned = _extract(text)
|
||||
assert paths == []
|
||||
assert "/tmp/image.png" in cleaned # not stripped
|
||||
|
||||
def test_inline_code_skipped(self):
|
||||
text = "Use the path `/tmp/image.png` in your config"
|
||||
paths, cleaned = _extract(text)
|
||||
assert paths == []
|
||||
assert "`/tmp/image.png`" in cleaned
|
||||
|
||||
def test_path_outside_code_block_still_matched(self):
|
||||
text = (
|
||||
"```\ncode: /tmp/inside.png\n```\n"
|
||||
"But this one is real: /tmp/outside.png"
|
||||
)
|
||||
paths, _ = _extract(text, existing_files={"/tmp/outside.png"})
|
||||
assert paths == ["/tmp/outside.png"]
|
||||
|
||||
def test_mixed_inline_code_and_bare_path(self):
|
||||
text = "Config uses `/etc/app/bg.png` but output is /tmp/result.jpg"
|
||||
paths, cleaned = _extract(text, existing_files={"/tmp/result.jpg"})
|
||||
assert paths == ["/tmp/result.jpg"]
|
||||
assert "`/etc/app/bg.png`" in cleaned
|
||||
assert "/tmp/result.jpg" not in cleaned
|
||||
|
||||
def test_multiline_fenced_block(self):
|
||||
text = (
|
||||
"```bash\n"
|
||||
"cp /source/a.png /dest/b.png\n"
|
||||
"mv /source/c.mp4 /dest/d.mp4\n"
|
||||
"```\n"
|
||||
"Files are ready."
|
||||
)
|
||||
paths, _ = _extract(text)
|
||||
assert paths == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Deduplication
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDeduplication:
|
||||
|
||||
def test_duplicate_paths_deduplicated(self):
|
||||
text = "See /tmp/img.png and also /tmp/img.png again"
|
||||
paths, _ = _extract(text)
|
||||
assert paths == ["/tmp/img.png"]
|
||||
|
||||
def test_tilde_and_expanded_same_file(self):
|
||||
"""~/photos/a.png and /home/user/photos/a.png are the same file."""
|
||||
text = "See ~/photos/a.png and /home/user/photos/a.png here"
|
||||
paths, _ = _extract(text, existing_files={"/home/user/photos/a.png"})
|
||||
assert len(paths) == 1
|
||||
assert paths[0] == "/home/user/photos/a.png"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Text cleanup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTextCleanup:
|
||||
|
||||
def test_path_removed_from_text(self):
|
||||
paths, cleaned = _extract("Before /tmp/x.png after")
|
||||
assert "Before" in cleaned
|
||||
assert "after" in cleaned
|
||||
assert "/tmp/x.png" not in cleaned
|
||||
|
||||
def test_excessive_blank_lines_collapsed(self):
|
||||
text = "Before\n\n\n/tmp/x.png\n\n\nAfter"
|
||||
_, cleaned = _extract(text)
|
||||
assert "\n\n\n" not in cleaned
|
||||
|
||||
def test_no_paths_text_unchanged(self):
|
||||
text = "This is a normal response with no file paths."
|
||||
paths, cleaned = _extract(text)
|
||||
assert paths == []
|
||||
assert cleaned == text
|
||||
|
||||
def test_tilde_form_cleaned_from_text(self):
|
||||
"""The raw ~/... form should be removed, not the expanded /home/user/... form."""
|
||||
text = "Output saved to ~/result.png for review"
|
||||
paths, cleaned = _extract(text)
|
||||
assert paths == ["/home/user/result.png"]
|
||||
assert "~/result.png" not in cleaned
|
||||
|
||||
def test_only_path_in_text(self):
|
||||
"""If the response is just a path, cleaned text is empty."""
|
||||
paths, cleaned = _extract("/tmp/screenshot.png")
|
||||
assert paths == ["/tmp/screenshot.png"]
|
||||
assert cleaned == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Edge cases
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEdgeCases:
|
||||
|
||||
def test_empty_string(self):
|
||||
paths, cleaned = _extract("")
|
||||
assert paths == []
|
||||
assert cleaned == ""
|
||||
|
||||
def test_no_media_extensions(self):
|
||||
"""Non-media extensions should not be matched."""
|
||||
paths, _ = _extract("See /tmp/data.csv and /tmp/script.py and /tmp/notes.txt")
|
||||
assert paths == []
|
||||
|
||||
def test_path_with_spaces_not_matched(self):
|
||||
"""Paths with spaces are intentionally not matched (avoids false positives)."""
|
||||
paths, _ = _extract("File at /tmp/my file.png here")
|
||||
assert paths == []
|
||||
|
||||
def test_windows_path_not_matched(self):
|
||||
"""Windows-style paths should not match."""
|
||||
paths, _ = _extract("See C:\\Users\\test\\image.png")
|
||||
assert paths == []
|
||||
|
||||
def test_relative_path_not_matched(self):
|
||||
"""Relative paths like ./image.png should not match."""
|
||||
paths, _ = _extract("File at ./screenshots/image.png here")
|
||||
assert paths == []
|
||||
|
||||
def test_bare_filename_not_matched(self):
|
||||
"""Just 'image.png' without a path should not match."""
|
||||
paths, _ = _extract("Open image.png to see")
|
||||
assert paths == []
|
||||
|
||||
def test_path_followed_by_punctuation(self):
|
||||
"""Path followed by comma, period, paren should still match."""
|
||||
for suffix in [",", ".", ")", ":", ";"]:
|
||||
text = f"See /tmp/img.png{suffix} details"
|
||||
paths, _ = _extract(text)
|
||||
assert len(paths) == 1, f"Failed with suffix '{suffix}'"
|
||||
|
||||
def test_path_in_parentheses(self):
|
||||
paths, _ = _extract("(see /tmp/img.png)")
|
||||
assert paths == ["/tmp/img.png"]
|
||||
|
||||
def test_path_in_quotes(self):
|
||||
paths, _ = _extract('The file is "/tmp/img.png" right here')
|
||||
assert paths == ["/tmp/img.png"]
|
||||
|
||||
def test_deep_nested_path(self):
|
||||
paths, _ = _extract("At /a/b/c/d/e/f/g/h/image.png end")
|
||||
assert paths == ["/a/b/c/d/e/f/g/h/image.png"]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
167
hermes_code/tests/gateway/test_flush_memory_stale_guard.py
Normal file
167
hermes_code/tests/gateway/test_flush_memory_stale_guard.py
Normal file
|
|
@ -0,0 +1,167 @@
|
|||
"""Tests for memory flush stale-overwrite prevention (#2670).
|
||||
|
||||
Verifies that:
|
||||
1. Cron sessions are skipped (no flush for headless cron runs)
|
||||
2. Current memory state is injected into the flush prompt so the
|
||||
flush agent can see what's already saved and avoid overwrites
|
||||
3. The flush still works normally when memory files don't exist
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch, call
|
||||
|
||||
|
||||
def _make_runner():
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner._honcho_managers = {}
|
||||
runner._honcho_configs = {}
|
||||
runner._running_agents = {}
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner.adapters = {}
|
||||
runner.hooks = MagicMock()
|
||||
runner.session_store = MagicMock()
|
||||
return runner
|
||||
|
||||
|
||||
_TRANSCRIPT_4_MSGS = [
|
||||
{"role": "user", "content": "hello"},
|
||||
{"role": "assistant", "content": "hi there"},
|
||||
{"role": "user", "content": "remember my name is Alice"},
|
||||
{"role": "assistant", "content": "Got it, Alice!"},
|
||||
]
|
||||
|
||||
|
||||
class TestCronSessionBypass:
|
||||
"""Cron sessions should never trigger a memory flush."""
|
||||
|
||||
def test_cron_session_skipped(self):
|
||||
runner = _make_runner()
|
||||
runner._flush_memories_for_session("cron_job123_20260323_120000")
|
||||
# session_store.load_transcript should never be called
|
||||
runner.session_store.load_transcript.assert_not_called()
|
||||
|
||||
def test_cron_session_with_honcho_key_skipped(self):
|
||||
runner = _make_runner()
|
||||
runner._flush_memories_for_session("cron_daily_20260323", "some-honcho-key")
|
||||
runner.session_store.load_transcript.assert_not_called()
|
||||
|
||||
def test_non_cron_session_proceeds(self):
|
||||
"""Non-cron sessions should still attempt the flush."""
|
||||
runner = _make_runner()
|
||||
runner.session_store.load_transcript.return_value = []
|
||||
runner._flush_memories_for_session("session_abc123")
|
||||
runner.session_store.load_transcript.assert_called_once_with("session_abc123")
|
||||
|
||||
|
||||
class TestMemoryInjection:
|
||||
"""The flush prompt should include current memory state from disk."""
|
||||
|
||||
def test_memory_content_injected_into_flush_prompt(self, tmp_path):
|
||||
"""When memory files exist, their content appears in the flush prompt."""
|
||||
runner = _make_runner()
|
||||
runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS
|
||||
|
||||
tmp_agent = MagicMock()
|
||||
memory_dir = tmp_path / "memories"
|
||||
memory_dir.mkdir()
|
||||
(memory_dir / "MEMORY.md").write_text("Agent knows Python\n§\nUser prefers dark mode")
|
||||
(memory_dir / "USER.md").write_text("Name: Alice\n§\nTimezone: PST")
|
||||
|
||||
with (
|
||||
patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}),
|
||||
patch("gateway.run._resolve_gateway_model", return_value="test-model"),
|
||||
patch("run_agent.AIAgent", return_value=tmp_agent),
|
||||
# Intercept `from tools.memory_tool import MEMORY_DIR` inside the function
|
||||
patch.dict("sys.modules", {"tools.memory_tool": MagicMock(MEMORY_DIR=memory_dir)}),
|
||||
):
|
||||
runner._flush_memories_for_session("session_123")
|
||||
|
||||
tmp_agent.run_conversation.assert_called_once()
|
||||
call_kwargs = tmp_agent.run_conversation.call_args.kwargs
|
||||
flush_prompt = call_kwargs.get("user_message", "")
|
||||
|
||||
# Verify both memory sections appear in the prompt
|
||||
assert "Agent knows Python" in flush_prompt
|
||||
assert "User prefers dark mode" in flush_prompt
|
||||
assert "Name: Alice" in flush_prompt
|
||||
assert "Timezone: PST" in flush_prompt
|
||||
# Verify the stale-overwrite warning is present
|
||||
assert "Do NOT overwrite or remove entries" in flush_prompt
|
||||
assert "current live state of memory" in flush_prompt
|
||||
|
||||
def test_flush_works_without_memory_files(self, tmp_path):
|
||||
"""When no memory files exist, flush still runs without the guard."""
|
||||
runner = _make_runner()
|
||||
runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS
|
||||
|
||||
tmp_agent = MagicMock()
|
||||
empty_dir = tmp_path / "no_memories"
|
||||
empty_dir.mkdir()
|
||||
|
||||
with (
|
||||
patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}),
|
||||
patch("gateway.run._resolve_gateway_model", return_value="test-model"),
|
||||
patch("run_agent.AIAgent", return_value=tmp_agent),
|
||||
patch.dict("sys.modules", {"tools.memory_tool": MagicMock(MEMORY_DIR=empty_dir)}),
|
||||
):
|
||||
runner._flush_memories_for_session("session_456")
|
||||
|
||||
# Should still run, just without the memory guard section
|
||||
tmp_agent.run_conversation.assert_called_once()
|
||||
flush_prompt = tmp_agent.run_conversation.call_args.kwargs.get("user_message", "")
|
||||
assert "Do NOT overwrite or remove entries" not in flush_prompt
|
||||
assert "Review the conversation above" in flush_prompt
|
||||
|
||||
def test_empty_memory_files_no_injection(self, tmp_path):
|
||||
"""Empty memory files should not trigger the guard section."""
|
||||
runner = _make_runner()
|
||||
runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS
|
||||
|
||||
tmp_agent = MagicMock()
|
||||
memory_dir = tmp_path / "memories"
|
||||
memory_dir.mkdir()
|
||||
(memory_dir / "MEMORY.md").write_text("")
|
||||
(memory_dir / "USER.md").write_text(" \n ") # whitespace only
|
||||
|
||||
with (
|
||||
patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}),
|
||||
patch("gateway.run._resolve_gateway_model", return_value="test-model"),
|
||||
patch("run_agent.AIAgent", return_value=tmp_agent),
|
||||
patch.dict("sys.modules", {"tools.memory_tool": MagicMock(MEMORY_DIR=memory_dir)}),
|
||||
):
|
||||
runner._flush_memories_for_session("session_789")
|
||||
|
||||
tmp_agent.run_conversation.assert_called_once()
|
||||
flush_prompt = tmp_agent.run_conversation.call_args.kwargs.get("user_message", "")
|
||||
# No memory content → no guard section
|
||||
assert "current live state of memory" not in flush_prompt
|
||||
|
||||
|
||||
class TestFlushPromptStructure:
|
||||
"""Verify the flush prompt retains its core instructions."""
|
||||
|
||||
def test_core_instructions_present(self):
|
||||
"""The flush prompt should still contain the original guidance."""
|
||||
runner = _make_runner()
|
||||
runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS
|
||||
|
||||
tmp_agent = MagicMock()
|
||||
|
||||
with (
|
||||
patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}),
|
||||
patch("gateway.run._resolve_gateway_model", return_value="test-model"),
|
||||
patch("run_agent.AIAgent", return_value=tmp_agent),
|
||||
# Make the import fail gracefully so we test without memory files
|
||||
patch.dict("sys.modules", {"tools.memory_tool": MagicMock(MEMORY_DIR=Path("/nonexistent"))}),
|
||||
):
|
||||
runner._flush_memories_for_session("session_struct")
|
||||
|
||||
flush_prompt = tmp_agent.run_conversation.call_args.kwargs.get("user_message", "")
|
||||
assert "automatically reset" in flush_prompt
|
||||
assert "Save any important facts" in flush_prompt
|
||||
assert "consider saving it as a skill" in flush_prompt
|
||||
assert "Do NOT respond to the user" in flush_prompt
|
||||
106
hermes_code/tests/gateway/test_gateway_shutdown.py
Normal file
106
hermes_code/tests/gateway/test_gateway_shutdown.py
Normal file
|
|
@ -0,0 +1,106 @@
|
|||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, SendResult
|
||||
from gateway.run import GatewayRunner
|
||||
from gateway.session import SessionSource, build_session_key
|
||||
|
||||
|
||||
class StubAdapter(BasePlatformAdapter):
|
||||
def __init__(self):
|
||||
super().__init__(PlatformConfig(enabled=True, token="***"), Platform.TELEGRAM)
|
||||
|
||||
async def connect(self):
|
||||
return True
|
||||
|
||||
async def disconnect(self):
|
||||
return None
|
||||
|
||||
async def send(self, chat_id, content, reply_to=None, metadata=None):
|
||||
return SendResult(success=True, message_id="1")
|
||||
|
||||
async def send_typing(self, chat_id, metadata=None):
|
||||
return None
|
||||
|
||||
async def get_chat_info(self, chat_id):
|
||||
return {"id": chat_id}
|
||||
|
||||
|
||||
def _source(chat_id="123456", chat_type="dm"):
|
||||
return SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id=chat_id,
|
||||
chat_type=chat_type,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_background_tasks_cancels_inflight_message_processing():
|
||||
adapter = StubAdapter()
|
||||
release = asyncio.Event()
|
||||
|
||||
async def block_forever(_event):
|
||||
await release.wait()
|
||||
return None
|
||||
|
||||
adapter.set_message_handler(block_forever)
|
||||
event = MessageEvent(text="work", source=_source(), message_id="1")
|
||||
|
||||
await adapter.handle_message(event)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
session_key = build_session_key(event.source)
|
||||
assert session_key in adapter._active_sessions
|
||||
assert adapter._background_tasks
|
||||
|
||||
await adapter.cancel_background_tasks()
|
||||
|
||||
assert adapter._background_tasks == set()
|
||||
assert adapter._active_sessions == {}
|
||||
assert adapter._pending_messages == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks():
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.config = GatewayConfig(platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")})
|
||||
runner._running = True
|
||||
runner._shutdown_event = asyncio.Event()
|
||||
runner._exit_reason = None
|
||||
runner._pending_messages = {"session": "pending text"}
|
||||
runner._pending_approvals = {"session": {"command": "rm -rf /tmp/x"}}
|
||||
runner._shutdown_all_gateway_honcho = lambda: None
|
||||
|
||||
adapter = StubAdapter()
|
||||
release = asyncio.Event()
|
||||
|
||||
async def block_forever(_event):
|
||||
await release.wait()
|
||||
return None
|
||||
|
||||
adapter.set_message_handler(block_forever)
|
||||
event = MessageEvent(text="work", source=_source(), message_id="1")
|
||||
await adapter.handle_message(event)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
disconnect_mock = AsyncMock()
|
||||
adapter.disconnect = disconnect_mock
|
||||
|
||||
session_key = build_session_key(event.source)
|
||||
running_agent = MagicMock()
|
||||
runner._running_agents = {session_key: running_agent}
|
||||
runner.adapters = {Platform.TELEGRAM: adapter}
|
||||
|
||||
with patch("gateway.status.remove_pid_file"), patch("gateway.status.write_runtime_status"):
|
||||
await runner.stop()
|
||||
|
||||
running_agent.interrupt.assert_called_once_with("Gateway shutting down")
|
||||
disconnect_mock.assert_awaited_once()
|
||||
assert runner.adapters == {}
|
||||
assert runner._running_agents == {}
|
||||
assert runner._pending_messages == {}
|
||||
assert runner._pending_approvals == {}
|
||||
assert runner._shutdown_event.is_set() is True
|
||||
622
hermes_code/tests/gateway/test_homeassistant.py
Normal file
622
hermes_code/tests/gateway/test_homeassistant.py
Normal file
|
|
@ -0,0 +1,622 @@
|
|||
"""Tests for the Home Assistant gateway adapter.
|
||||
|
||||
Tests real logic: state change formatting, event filtering pipeline,
|
||||
cooldown behavior, config integration, and adapter initialization.
|
||||
"""
|
||||
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import (
|
||||
GatewayConfig,
|
||||
Platform,
|
||||
PlatformConfig,
|
||||
)
|
||||
from gateway.platforms.homeassistant import (
|
||||
HomeAssistantAdapter,
|
||||
check_ha_requirements,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_ha_requirements
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCheckRequirements:
|
||||
def test_returns_false_without_token(self, monkeypatch):
|
||||
monkeypatch.delenv("HASS_TOKEN", raising=False)
|
||||
assert check_ha_requirements() is False
|
||||
|
||||
def test_returns_true_with_token(self, monkeypatch):
|
||||
monkeypatch.setenv("HASS_TOKEN", "test-token")
|
||||
assert check_ha_requirements() is True
|
||||
|
||||
@patch("gateway.platforms.homeassistant.AIOHTTP_AVAILABLE", False)
|
||||
def test_returns_false_without_aiohttp(self, monkeypatch):
|
||||
monkeypatch.setenv("HASS_TOKEN", "test-token")
|
||||
assert check_ha_requirements() is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _format_state_change - pure function, all domain branches
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFormatStateChange:
|
||||
@staticmethod
|
||||
def fmt(entity_id, old_state, new_state):
|
||||
return HomeAssistantAdapter._format_state_change(entity_id, old_state, new_state)
|
||||
|
||||
def test_climate_includes_temperatures(self):
|
||||
msg = self.fmt(
|
||||
"climate.thermostat",
|
||||
{"state": "off"},
|
||||
{"state": "heat", "attributes": {
|
||||
"friendly_name": "Main Thermostat",
|
||||
"current_temperature": 21.5,
|
||||
"temperature": 23,
|
||||
}},
|
||||
)
|
||||
assert "Main Thermostat" in msg
|
||||
assert "'off'" in msg and "'heat'" in msg
|
||||
assert "21.5" in msg and "23" in msg
|
||||
|
||||
def test_sensor_includes_unit(self):
|
||||
msg = self.fmt(
|
||||
"sensor.temperature",
|
||||
{"state": "22.5"},
|
||||
{"state": "25.1", "attributes": {
|
||||
"friendly_name": "Living Room Temp",
|
||||
"unit_of_measurement": "C",
|
||||
}},
|
||||
)
|
||||
assert "22.5C" in msg and "25.1C" in msg
|
||||
assert "Living Room Temp" in msg
|
||||
|
||||
def test_sensor_without_unit(self):
|
||||
msg = self.fmt(
|
||||
"sensor.count",
|
||||
{"state": "5"},
|
||||
{"state": "10", "attributes": {"friendly_name": "Counter"}},
|
||||
)
|
||||
assert "5" in msg and "10" in msg
|
||||
|
||||
def test_binary_sensor_on(self):
|
||||
msg = self.fmt(
|
||||
"binary_sensor.motion",
|
||||
{"state": "off"},
|
||||
{"state": "on", "attributes": {"friendly_name": "Hallway Motion"}},
|
||||
)
|
||||
assert "triggered" in msg
|
||||
assert "Hallway Motion" in msg
|
||||
|
||||
def test_binary_sensor_off(self):
|
||||
msg = self.fmt(
|
||||
"binary_sensor.door",
|
||||
{"state": "on"},
|
||||
{"state": "off", "attributes": {"friendly_name": "Front Door"}},
|
||||
)
|
||||
assert "cleared" in msg
|
||||
|
||||
def test_light_turned_on(self):
|
||||
msg = self.fmt(
|
||||
"light.bedroom",
|
||||
{"state": "off"},
|
||||
{"state": "on", "attributes": {"friendly_name": "Bedroom Light"}},
|
||||
)
|
||||
assert "turned on" in msg
|
||||
|
||||
def test_switch_turned_off(self):
|
||||
msg = self.fmt(
|
||||
"switch.heater",
|
||||
{"state": "on"},
|
||||
{"state": "off", "attributes": {"friendly_name": "Heater"}},
|
||||
)
|
||||
assert "turned off" in msg
|
||||
|
||||
def test_fan_domain_uses_light_switch_branch(self):
|
||||
msg = self.fmt(
|
||||
"fan.ceiling",
|
||||
{"state": "off"},
|
||||
{"state": "on", "attributes": {"friendly_name": "Ceiling Fan"}},
|
||||
)
|
||||
assert "turned on" in msg
|
||||
|
||||
def test_alarm_panel(self):
|
||||
msg = self.fmt(
|
||||
"alarm_control_panel.home",
|
||||
{"state": "disarmed"},
|
||||
{"state": "armed_away", "attributes": {"friendly_name": "Home Alarm"}},
|
||||
)
|
||||
assert "Home Alarm" in msg
|
||||
assert "armed_away" in msg and "disarmed" in msg
|
||||
|
||||
def test_generic_domain_includes_entity_id(self):
|
||||
msg = self.fmt(
|
||||
"automation.morning",
|
||||
{"state": "off"},
|
||||
{"state": "on", "attributes": {"friendly_name": "Morning Routine"}},
|
||||
)
|
||||
assert "automation.morning" in msg
|
||||
assert "Morning Routine" in msg
|
||||
|
||||
def test_same_state_returns_none(self):
|
||||
assert self.fmt(
|
||||
"sensor.temp",
|
||||
{"state": "22"},
|
||||
{"state": "22", "attributes": {"friendly_name": "Temp"}},
|
||||
) is None
|
||||
|
||||
def test_empty_new_state_returns_none(self):
|
||||
assert self.fmt("light.x", {"state": "on"}, {}) is None
|
||||
|
||||
def test_no_old_state_uses_unknown(self):
|
||||
msg = self.fmt(
|
||||
"light.new",
|
||||
None,
|
||||
{"state": "on", "attributes": {"friendly_name": "New Light"}},
|
||||
)
|
||||
assert msg is not None
|
||||
assert "New Light" in msg
|
||||
|
||||
def test_uses_entity_id_when_no_friendly_name(self):
|
||||
msg = self.fmt(
|
||||
"sensor.unnamed",
|
||||
{"state": "1"},
|
||||
{"state": "2", "attributes": {}},
|
||||
)
|
||||
assert "sensor.unnamed" in msg
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Adapter initialization from config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAdapterInit:
|
||||
def test_url_and_token_from_config_extra(self, monkeypatch):
|
||||
monkeypatch.delenv("HASS_URL", raising=False)
|
||||
monkeypatch.delenv("HASS_TOKEN", raising=False)
|
||||
|
||||
config = PlatformConfig(
|
||||
enabled=True,
|
||||
token="config-token",
|
||||
extra={"url": "http://192.168.1.50:8123"},
|
||||
)
|
||||
adapter = HomeAssistantAdapter(config)
|
||||
assert adapter._hass_token == "config-token"
|
||||
assert adapter._hass_url == "http://192.168.1.50:8123"
|
||||
|
||||
def test_url_fallback_to_env(self, monkeypatch):
|
||||
monkeypatch.setenv("HASS_URL", "http://env-host:8123")
|
||||
monkeypatch.setenv("HASS_TOKEN", "env-tok")
|
||||
|
||||
config = PlatformConfig(enabled=True, token="env-tok")
|
||||
adapter = HomeAssistantAdapter(config)
|
||||
assert adapter._hass_url == "http://env-host:8123"
|
||||
|
||||
def test_trailing_slash_stripped(self):
|
||||
config = PlatformConfig(
|
||||
enabled=True, token="t",
|
||||
extra={"url": "http://ha.local:8123/"},
|
||||
)
|
||||
adapter = HomeAssistantAdapter(config)
|
||||
assert adapter._hass_url == "http://ha.local:8123"
|
||||
|
||||
def test_watch_filters_parsed(self):
|
||||
config = PlatformConfig(
|
||||
enabled=True, token="***",
|
||||
extra={
|
||||
"watch_domains": ["climate", "binary_sensor"],
|
||||
"watch_entities": ["sensor.special"],
|
||||
"ignore_entities": ["sensor.uptime", "sensor.cpu"],
|
||||
"cooldown_seconds": 120,
|
||||
},
|
||||
)
|
||||
adapter = HomeAssistantAdapter(config)
|
||||
assert adapter._watch_domains == {"climate", "binary_sensor"}
|
||||
assert adapter._watch_entities == {"sensor.special"}
|
||||
assert adapter._ignore_entities == {"sensor.uptime", "sensor.cpu"}
|
||||
assert adapter._watch_all is False
|
||||
assert adapter._cooldown_seconds == 120
|
||||
|
||||
def test_watch_all_parsed(self):
|
||||
config = PlatformConfig(
|
||||
enabled=True, token="***",
|
||||
extra={"watch_all": True},
|
||||
)
|
||||
adapter = HomeAssistantAdapter(config)
|
||||
assert adapter._watch_all is True
|
||||
|
||||
def test_defaults_when_no_extra(self, monkeypatch):
|
||||
monkeypatch.setenv("HASS_TOKEN", "tok")
|
||||
config = PlatformConfig(enabled=True, token="***")
|
||||
adapter = HomeAssistantAdapter(config)
|
||||
assert adapter._watch_domains == set()
|
||||
assert adapter._watch_entities == set()
|
||||
assert adapter._ignore_entities == set()
|
||||
assert adapter._watch_all is False
|
||||
assert adapter._cooldown_seconds == 30
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Event filtering pipeline (_handle_ha_event)
|
||||
#
|
||||
# We mock handle_message (not our code, it's the base class pipeline) to
|
||||
# capture the MessageEvent that _handle_ha_event produces.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_adapter(**extra) -> HomeAssistantAdapter:
|
||||
config = PlatformConfig(enabled=True, token="tok", extra=extra)
|
||||
adapter = HomeAssistantAdapter(config)
|
||||
adapter.handle_message = AsyncMock()
|
||||
return adapter
|
||||
|
||||
|
||||
def _make_event(entity_id, old_state, new_state, old_attrs=None, new_attrs=None):
|
||||
return {
|
||||
"data": {
|
||||
"entity_id": entity_id,
|
||||
"old_state": {"state": old_state, "attributes": old_attrs or {}},
|
||||
"new_state": {"state": new_state, "attributes": new_attrs or {"friendly_name": entity_id}},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class TestEventFilteringPipeline:
|
||||
@pytest.mark.asyncio
|
||||
async def test_ignored_entity_not_forwarded(self):
|
||||
adapter = _make_adapter(watch_all=True, ignore_entities=["sensor.uptime"])
|
||||
await adapter._handle_ha_event(_make_event("sensor.uptime", "100", "101"))
|
||||
adapter.handle_message.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unwatched_domain_not_forwarded(self):
|
||||
adapter = _make_adapter(watch_domains=["climate"])
|
||||
await adapter._handle_ha_event(_make_event("light.bedroom", "off", "on"))
|
||||
adapter.handle_message.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_watched_domain_forwarded(self):
|
||||
adapter = _make_adapter(watch_domains=["climate"], cooldown_seconds=0)
|
||||
await adapter._handle_ha_event(
|
||||
_make_event("climate.thermostat", "off", "heat",
|
||||
new_attrs={"friendly_name": "Thermostat", "current_temperature": 20, "temperature": 22})
|
||||
)
|
||||
adapter.handle_message.assert_called_once()
|
||||
|
||||
# Verify the actual MessageEvent text content
|
||||
msg_event = adapter.handle_message.call_args[0][0]
|
||||
assert "Thermostat" in msg_event.text
|
||||
assert "heat" in msg_event.text
|
||||
assert msg_event.source.platform == Platform.HOMEASSISTANT
|
||||
assert msg_event.source.chat_id == "ha_events"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_watched_entity_forwarded(self):
|
||||
adapter = _make_adapter(watch_entities=["sensor.important"], cooldown_seconds=0)
|
||||
await adapter._handle_ha_event(
|
||||
_make_event("sensor.important", "10", "20",
|
||||
new_attrs={"friendly_name": "Important Sensor", "unit_of_measurement": "W"})
|
||||
)
|
||||
adapter.handle_message.assert_called_once()
|
||||
msg_event = adapter.handle_message.call_args[0][0]
|
||||
assert "10W" in msg_event.text and "20W" in msg_event.text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_filters_blocks_everything(self):
|
||||
"""Without watch_domains, watch_entities, or watch_all, events are dropped."""
|
||||
adapter = _make_adapter(cooldown_seconds=0)
|
||||
await adapter._handle_ha_event(_make_event("cover.blinds", "closed", "open"))
|
||||
adapter.handle_message.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_watch_all_passes_everything(self):
|
||||
"""With watch_all=True and no specific filters, all events pass through."""
|
||||
adapter = _make_adapter(watch_all=True, cooldown_seconds=0)
|
||||
await adapter._handle_ha_event(_make_event("cover.blinds", "closed", "open"))
|
||||
adapter.handle_message.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_same_state_not_forwarded(self):
|
||||
adapter = _make_adapter(watch_all=True, cooldown_seconds=0)
|
||||
await adapter._handle_ha_event(_make_event("light.x", "on", "on"))
|
||||
adapter.handle_message.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_entity_id_skipped(self):
|
||||
adapter = _make_adapter(watch_all=True)
|
||||
await adapter._handle_ha_event({"data": {"entity_id": ""}})
|
||||
adapter.handle_message.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_event_has_correct_source(self):
|
||||
adapter = _make_adapter(watch_all=True, cooldown_seconds=0)
|
||||
await adapter._handle_ha_event(
|
||||
_make_event("light.test", "off", "on",
|
||||
new_attrs={"friendly_name": "Test Light"})
|
||||
)
|
||||
msg_event = adapter.handle_message.call_args[0][0]
|
||||
assert msg_event.source.user_name == "Home Assistant"
|
||||
assert msg_event.source.chat_type == "channel"
|
||||
assert msg_event.message_id.startswith("ha_light.test_")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cooldown behavior
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCooldown:
|
||||
@pytest.mark.asyncio
|
||||
async def test_cooldown_blocks_rapid_events(self):
|
||||
adapter = _make_adapter(watch_all=True, cooldown_seconds=60)
|
||||
|
||||
event = _make_event("sensor.temp", "20", "21",
|
||||
new_attrs={"friendly_name": "Temp"})
|
||||
await adapter._handle_ha_event(event)
|
||||
assert adapter.handle_message.call_count == 1
|
||||
|
||||
# Second event immediately after should be blocked
|
||||
event2 = _make_event("sensor.temp", "21", "22",
|
||||
new_attrs={"friendly_name": "Temp"})
|
||||
await adapter._handle_ha_event(event2)
|
||||
assert adapter.handle_message.call_count == 1 # Still 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cooldown_expires(self):
|
||||
adapter = _make_adapter(watch_all=True, cooldown_seconds=1)
|
||||
|
||||
event = _make_event("sensor.temp", "20", "21",
|
||||
new_attrs={"friendly_name": "Temp"})
|
||||
await adapter._handle_ha_event(event)
|
||||
assert adapter.handle_message.call_count == 1
|
||||
|
||||
# Simulate time passing beyond cooldown
|
||||
adapter._last_event_time["sensor.temp"] = time.time() - 2
|
||||
|
||||
event2 = _make_event("sensor.temp", "21", "22",
|
||||
new_attrs={"friendly_name": "Temp"})
|
||||
await adapter._handle_ha_event(event2)
|
||||
assert adapter.handle_message.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_different_entities_independent_cooldowns(self):
|
||||
adapter = _make_adapter(watch_all=True, cooldown_seconds=60)
|
||||
|
||||
await adapter._handle_ha_event(
|
||||
_make_event("sensor.a", "1", "2", new_attrs={"friendly_name": "A"})
|
||||
)
|
||||
await adapter._handle_ha_event(
|
||||
_make_event("sensor.b", "3", "4", new_attrs={"friendly_name": "B"})
|
||||
)
|
||||
# Both should pass - different entities
|
||||
assert adapter.handle_message.call_count == 2
|
||||
|
||||
# Same entity again - should be blocked
|
||||
await adapter._handle_ha_event(
|
||||
_make_event("sensor.a", "2", "3", new_attrs={"friendly_name": "A"})
|
||||
)
|
||||
assert adapter.handle_message.call_count == 2 # Still 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zero_cooldown_passes_all(self):
|
||||
adapter = _make_adapter(watch_all=True, cooldown_seconds=0)
|
||||
|
||||
for i in range(5):
|
||||
await adapter._handle_ha_event(
|
||||
_make_event("sensor.temp", str(i), str(i + 1),
|
||||
new_attrs={"friendly_name": "Temp"})
|
||||
)
|
||||
assert adapter.handle_message.call_count == 5
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config integration (env overrides, round-trip)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestConfigIntegration:
|
||||
def test_env_override_creates_ha_platform(self, monkeypatch):
|
||||
monkeypatch.setenv("HASS_TOKEN", "env-token")
|
||||
monkeypatch.setenv("HASS_URL", "http://10.0.0.5:8123")
|
||||
# Clear other platform tokens
|
||||
for v in ["TELEGRAM_BOT_TOKEN", "DISCORD_BOT_TOKEN", "SLACK_BOT_TOKEN"]:
|
||||
monkeypatch.delenv(v, raising=False)
|
||||
|
||||
from gateway.config import load_gateway_config
|
||||
config = load_gateway_config()
|
||||
|
||||
assert Platform.HOMEASSISTANT in config.platforms
|
||||
ha = config.platforms[Platform.HOMEASSISTANT]
|
||||
assert ha.enabled is True
|
||||
assert ha.token == "env-token"
|
||||
assert ha.extra["url"] == "http://10.0.0.5:8123"
|
||||
|
||||
def test_no_env_no_platform(self, monkeypatch):
|
||||
for v in ["HASS_TOKEN", "HASS_URL", "TELEGRAM_BOT_TOKEN",
|
||||
"DISCORD_BOT_TOKEN", "SLACK_BOT_TOKEN"]:
|
||||
monkeypatch.delenv(v, raising=False)
|
||||
|
||||
from gateway.config import load_gateway_config
|
||||
config = load_gateway_config()
|
||||
assert Platform.HOMEASSISTANT not in config.platforms
|
||||
|
||||
def test_config_roundtrip_preserves_extra(self):
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.HOMEASSISTANT: PlatformConfig(
|
||||
enabled=True,
|
||||
token="tok",
|
||||
extra={
|
||||
"url": "http://ha:8123",
|
||||
"watch_domains": ["climate"],
|
||||
"cooldown_seconds": 45,
|
||||
},
|
||||
),
|
||||
},
|
||||
)
|
||||
d = config.to_dict()
|
||||
restored = GatewayConfig.from_dict(d)
|
||||
|
||||
ha = restored.platforms[Platform.HOMEASSISTANT]
|
||||
assert ha.enabled is True
|
||||
assert ha.token == "tok"
|
||||
assert ha.extra["watch_domains"] == ["climate"]
|
||||
assert ha.extra["cooldown_seconds"] == 45
|
||||
|
||||
def test_connected_platforms_includes_ha(self):
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.HOMEASSISTANT: PlatformConfig(enabled=True, token="tok"),
|
||||
Platform.TELEGRAM: PlatformConfig(enabled=False, token="t"),
|
||||
},
|
||||
)
|
||||
connected = config.get_connected_platforms()
|
||||
assert Platform.HOMEASSISTANT in connected
|
||||
assert Platform.TELEGRAM not in connected
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# send() via REST API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSendViaRestApi:
|
||||
"""send() uses REST API (not WebSocket) to avoid race conditions."""
|
||||
|
||||
@staticmethod
|
||||
def _mock_aiohttp_session(response_status=200, response_text="OK"):
|
||||
"""Build a mock aiohttp session + response for async-with patterns.
|
||||
|
||||
aiohttp.ClientSession() is a sync constructor whose return value
|
||||
is used as ``async with session:``. ``session.post(...)`` returns a
|
||||
context-manager (not a coroutine), so both layers use MagicMock for
|
||||
the call and AsyncMock only for ``__aenter__`` / ``__aexit__``.
|
||||
"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = response_status
|
||||
mock_response.text = AsyncMock(return_value=response_text)
|
||||
mock_response.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_response.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.post = MagicMock(return_value=mock_response)
|
||||
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_session.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
return mock_session
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_success(self):
|
||||
adapter = _make_adapter()
|
||||
mock_session = self._mock_aiohttp_session(200)
|
||||
|
||||
with patch("gateway.platforms.homeassistant.aiohttp") as mock_aiohttp:
|
||||
mock_aiohttp.ClientSession = MagicMock(return_value=mock_session)
|
||||
mock_aiohttp.ClientTimeout = lambda total: total
|
||||
|
||||
result = await adapter.send("ha_events", "Test notification")
|
||||
|
||||
assert result.success is True
|
||||
# Verify the REST API was called with correct payload
|
||||
call_args = mock_session.post.call_args
|
||||
assert "/api/services/persistent_notification/create" in call_args[0][0]
|
||||
assert call_args[1]["json"]["title"] == "Hermes Agent"
|
||||
assert call_args[1]["json"]["message"] == "Test notification"
|
||||
assert "Bearer tok" in call_args[1]["headers"]["Authorization"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_http_error(self):
|
||||
adapter = _make_adapter()
|
||||
mock_session = self._mock_aiohttp_session(401, "Unauthorized")
|
||||
|
||||
with patch("gateway.platforms.homeassistant.aiohttp") as mock_aiohttp:
|
||||
mock_aiohttp.ClientSession = MagicMock(return_value=mock_session)
|
||||
mock_aiohttp.ClientTimeout = lambda total: total
|
||||
|
||||
result = await adapter.send("ha_events", "Test")
|
||||
|
||||
assert result.success is False
|
||||
assert "401" in result.error
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_truncates_long_message(self):
|
||||
adapter = _make_adapter()
|
||||
mock_session = self._mock_aiohttp_session(200)
|
||||
long_message = "x" * 10000
|
||||
|
||||
with patch("gateway.platforms.homeassistant.aiohttp") as mock_aiohttp:
|
||||
mock_aiohttp.ClientSession = MagicMock(return_value=mock_session)
|
||||
mock_aiohttp.ClientTimeout = lambda total: total
|
||||
|
||||
await adapter.send("ha_events", long_message)
|
||||
|
||||
sent_message = mock_session.post.call_args[1]["json"]["message"]
|
||||
assert len(sent_message) == 4096
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_does_not_use_websocket(self):
|
||||
"""send() must use REST API, not the WS connection (race condition fix)."""
|
||||
adapter = _make_adapter()
|
||||
adapter._ws = AsyncMock() # Simulate an active WS
|
||||
mock_session = self._mock_aiohttp_session(200)
|
||||
|
||||
with patch("gateway.platforms.homeassistant.aiohttp") as mock_aiohttp:
|
||||
mock_aiohttp.ClientSession = MagicMock(return_value=mock_session)
|
||||
mock_aiohttp.ClientTimeout = lambda total: total
|
||||
|
||||
await adapter.send("ha_events", "Test")
|
||||
|
||||
# WS should NOT have been used for sending
|
||||
adapter._ws.send_json.assert_not_called()
|
||||
adapter._ws.receive_json.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Toolset integration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestToolsetIntegration:
|
||||
def test_homeassistant_toolset_resolves(self):
|
||||
from toolsets import resolve_toolset
|
||||
|
||||
tools = resolve_toolset("homeassistant")
|
||||
assert set(tools) == {"ha_list_entities", "ha_get_state", "ha_call_service", "ha_list_services"}
|
||||
|
||||
def test_gateway_toolset_includes_ha_tools(self):
|
||||
from toolsets import resolve_toolset
|
||||
|
||||
gateway_tools = resolve_toolset("hermes-gateway")
|
||||
for tool in ("ha_list_entities", "ha_get_state", "ha_call_service", "ha_list_services"):
|
||||
assert tool in gateway_tools
|
||||
|
||||
def test_hermes_core_tools_includes_ha(self):
|
||||
from toolsets import _HERMES_CORE_TOOLS
|
||||
|
||||
for tool in ("ha_list_entities", "ha_get_state", "ha_call_service", "ha_list_services"):
|
||||
assert tool in _HERMES_CORE_TOOLS
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WebSocket URL construction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestWsUrlConstruction:
|
||||
def test_http_to_ws(self):
|
||||
config = PlatformConfig(enabled=True, token="t", extra={"url": "http://ha:8123"})
|
||||
adapter = HomeAssistantAdapter(config)
|
||||
ws_url = adapter._hass_url.replace("http://", "ws://").replace("https://", "wss://")
|
||||
assert ws_url == "ws://ha:8123"
|
||||
|
||||
def test_https_to_wss(self):
|
||||
config = PlatformConfig(enabled=True, token="t", extra={"url": "https://ha.example.com"})
|
||||
adapter = HomeAssistantAdapter(config)
|
||||
ws_url = adapter._hass_url.replace("http://", "ws://").replace("https://", "wss://")
|
||||
assert ws_url == "wss://ha.example.com"
|
||||
131
hermes_code/tests/gateway/test_honcho_lifecycle.py
Normal file
131
hermes_code/tests/gateway/test_honcho_lifecycle.py
Normal file
|
|
@ -0,0 +1,131 @@
|
|||
"""Tests for gateway-owned Honcho lifecycle helpers."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform
|
||||
from gateway.platforms.base import MessageEvent
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
def _make_runner():
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner._honcho_managers = {}
|
||||
runner._honcho_configs = {}
|
||||
runner._running_agents = {}
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner.adapters = {}
|
||||
runner.hooks = MagicMock()
|
||||
runner.hooks.emit = AsyncMock()
|
||||
return runner
|
||||
|
||||
|
||||
def _make_event(text="/reset"):
|
||||
return MessageEvent(
|
||||
text=text,
|
||||
source=SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="chat-1",
|
||||
user_id="user-1",
|
||||
user_name="alice",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class TestGatewayHonchoLifecycle:
|
||||
def test_gateway_reuses_honcho_manager_for_session_key(self):
|
||||
runner = _make_runner()
|
||||
hcfg = SimpleNamespace(
|
||||
enabled=True,
|
||||
api_key="honcho-key",
|
||||
ai_peer="hermes",
|
||||
peer_name="alice",
|
||||
context_tokens=123,
|
||||
peer_memory_mode=lambda peer: "hybrid",
|
||||
)
|
||||
manager = MagicMock()
|
||||
|
||||
with (
|
||||
patch("honcho_integration.client.HonchoClientConfig.from_global_config", return_value=hcfg),
|
||||
patch("honcho_integration.client.get_honcho_client", return_value=MagicMock()),
|
||||
patch("honcho_integration.session.HonchoSessionManager", return_value=manager) as mock_mgr_cls,
|
||||
):
|
||||
first_mgr, first_cfg = runner._get_or_create_gateway_honcho("session-key")
|
||||
second_mgr, second_cfg = runner._get_or_create_gateway_honcho("session-key")
|
||||
|
||||
assert first_mgr is manager
|
||||
assert second_mgr is manager
|
||||
assert first_cfg is hcfg
|
||||
assert second_cfg is hcfg
|
||||
mock_mgr_cls.assert_called_once()
|
||||
|
||||
def test_gateway_skips_honcho_manager_when_disabled(self):
|
||||
runner = _make_runner()
|
||||
hcfg = SimpleNamespace(
|
||||
enabled=False,
|
||||
api_key="honcho-key",
|
||||
ai_peer="hermes",
|
||||
peer_name="alice",
|
||||
)
|
||||
|
||||
with (
|
||||
patch("honcho_integration.client.HonchoClientConfig.from_global_config", return_value=hcfg),
|
||||
patch("honcho_integration.client.get_honcho_client") as mock_client,
|
||||
patch("honcho_integration.session.HonchoSessionManager") as mock_mgr_cls,
|
||||
):
|
||||
manager, cfg = runner._get_or_create_gateway_honcho("session-key")
|
||||
|
||||
assert manager is None
|
||||
assert cfg is hcfg
|
||||
mock_client.assert_not_called()
|
||||
mock_mgr_cls.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_shuts_down_gateway_honcho_manager(self):
|
||||
runner = _make_runner()
|
||||
event = _make_event()
|
||||
runner._shutdown_gateway_honcho = MagicMock()
|
||||
runner._async_flush_memories = AsyncMock()
|
||||
runner.session_store = MagicMock()
|
||||
runner.session_store._generate_session_key.return_value = "gateway-key"
|
||||
runner.session_store._entries = {
|
||||
"gateway-key": SimpleNamespace(session_id="old-session"),
|
||||
}
|
||||
runner.session_store.reset_session.return_value = SimpleNamespace(session_id="new-session")
|
||||
|
||||
result = await runner._handle_reset_command(event)
|
||||
|
||||
runner._shutdown_gateway_honcho.assert_called_once_with("gateway-key")
|
||||
runner._async_flush_memories.assert_called_once_with("old-session", "gateway-key")
|
||||
assert "Session reset" in result
|
||||
|
||||
def test_flush_memories_reuses_gateway_session_key_and_skips_honcho_sync(self):
|
||||
runner = _make_runner()
|
||||
runner.session_store = MagicMock()
|
||||
runner.session_store.load_transcript.return_value = [
|
||||
{"role": "user", "content": "a"},
|
||||
{"role": "assistant", "content": "b"},
|
||||
{"role": "user", "content": "c"},
|
||||
{"role": "assistant", "content": "d"},
|
||||
]
|
||||
tmp_agent = MagicMock()
|
||||
|
||||
with (
|
||||
patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "test-key"}),
|
||||
patch("gateway.run._resolve_gateway_model", return_value="model-name"),
|
||||
patch("run_agent.AIAgent", return_value=tmp_agent) as mock_agent_cls,
|
||||
):
|
||||
runner._flush_memories_for_session("old-session", "gateway-key")
|
||||
|
||||
mock_agent_cls.assert_called_once()
|
||||
_, kwargs = mock_agent_cls.call_args
|
||||
assert kwargs["session_id"] == "old-session"
|
||||
assert kwargs["honcho_session_key"] == "gateway-key"
|
||||
tmp_agent.run_conversation.assert_called_once()
|
||||
_, run_kwargs = tmp_agent.run_conversation.call_args
|
||||
assert run_kwargs["sync_honcho"] is False
|
||||
217
hermes_code/tests/gateway/test_hooks.py
Normal file
217
hermes_code/tests/gateway/test_hooks.py
Normal file
|
|
@ -0,0 +1,217 @@
|
|||
"""Tests for gateway/hooks.py — event hook system."""
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.hooks import HookRegistry
|
||||
|
||||
|
||||
def _create_hook(hooks_dir, hook_name, events, handler_code):
|
||||
"""Helper to create a hook directory with HOOK.yaml and handler.py."""
|
||||
hook_dir = hooks_dir / hook_name
|
||||
hook_dir.mkdir(parents=True)
|
||||
(hook_dir / "HOOK.yaml").write_text(
|
||||
f"name: {hook_name}\n"
|
||||
f"description: Test hook\n"
|
||||
f"events: {events}\n"
|
||||
)
|
||||
(hook_dir / "handler.py").write_text(handler_code)
|
||||
return hook_dir
|
||||
|
||||
|
||||
class TestHookRegistryInit:
|
||||
def test_empty_registry(self):
|
||||
reg = HookRegistry()
|
||||
assert reg.loaded_hooks == []
|
||||
assert reg._handlers == {}
|
||||
|
||||
|
||||
class TestDiscoverAndLoad:
|
||||
def test_loads_valid_hook(self, tmp_path):
|
||||
_create_hook(tmp_path, "my-hook", '["agent:start"]',
|
||||
"def handle(event_type, context):\n pass\n")
|
||||
|
||||
reg = HookRegistry()
|
||||
with patch("gateway.hooks.HOOKS_DIR", tmp_path):
|
||||
reg.discover_and_load()
|
||||
|
||||
assert len(reg.loaded_hooks) == 1
|
||||
assert reg.loaded_hooks[0]["name"] == "my-hook"
|
||||
assert "agent:start" in reg.loaded_hooks[0]["events"]
|
||||
|
||||
def test_skips_missing_hook_yaml(self, tmp_path):
|
||||
hook_dir = tmp_path / "bad-hook"
|
||||
hook_dir.mkdir()
|
||||
(hook_dir / "handler.py").write_text("def handle(e, c): pass\n")
|
||||
|
||||
reg = HookRegistry()
|
||||
with patch("gateway.hooks.HOOKS_DIR", tmp_path):
|
||||
reg.discover_and_load()
|
||||
|
||||
assert len(reg.loaded_hooks) == 0
|
||||
|
||||
def test_skips_missing_handler_py(self, tmp_path):
|
||||
hook_dir = tmp_path / "bad-hook"
|
||||
hook_dir.mkdir()
|
||||
(hook_dir / "HOOK.yaml").write_text("name: bad\nevents: ['agent:start']\n")
|
||||
|
||||
reg = HookRegistry()
|
||||
with patch("gateway.hooks.HOOKS_DIR", tmp_path):
|
||||
reg.discover_and_load()
|
||||
|
||||
assert len(reg.loaded_hooks) == 0
|
||||
|
||||
def test_skips_no_events(self, tmp_path):
|
||||
hook_dir = tmp_path / "empty-hook"
|
||||
hook_dir.mkdir()
|
||||
(hook_dir / "HOOK.yaml").write_text("name: empty\nevents: []\n")
|
||||
(hook_dir / "handler.py").write_text("def handle(e, c): pass\n")
|
||||
|
||||
reg = HookRegistry()
|
||||
with patch("gateway.hooks.HOOKS_DIR", tmp_path):
|
||||
reg.discover_and_load()
|
||||
|
||||
assert len(reg.loaded_hooks) == 0
|
||||
|
||||
def test_skips_no_handle_function(self, tmp_path):
|
||||
hook_dir = tmp_path / "no-handle"
|
||||
hook_dir.mkdir()
|
||||
(hook_dir / "HOOK.yaml").write_text("name: no-handle\nevents: ['agent:start']\n")
|
||||
(hook_dir / "handler.py").write_text("def something_else(): pass\n")
|
||||
|
||||
reg = HookRegistry()
|
||||
with patch("gateway.hooks.HOOKS_DIR", tmp_path):
|
||||
reg.discover_and_load()
|
||||
|
||||
assert len(reg.loaded_hooks) == 0
|
||||
|
||||
def test_nonexistent_hooks_dir(self, tmp_path):
|
||||
reg = HookRegistry()
|
||||
with patch("gateway.hooks.HOOKS_DIR", tmp_path / "nonexistent"):
|
||||
reg.discover_and_load()
|
||||
|
||||
assert len(reg.loaded_hooks) == 0
|
||||
|
||||
def test_multiple_hooks(self, tmp_path):
|
||||
_create_hook(tmp_path, "hook-a", '["agent:start"]',
|
||||
"def handle(e, c): pass\n")
|
||||
_create_hook(tmp_path, "hook-b", '["session:start", "session:reset"]',
|
||||
"def handle(e, c): pass\n")
|
||||
|
||||
reg = HookRegistry()
|
||||
with patch("gateway.hooks.HOOKS_DIR", tmp_path):
|
||||
reg.discover_and_load()
|
||||
|
||||
assert len(reg.loaded_hooks) == 2
|
||||
|
||||
|
||||
class TestEmit:
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_calls_sync_handler(self, tmp_path):
|
||||
results = []
|
||||
|
||||
_create_hook(tmp_path, "sync-hook", '["agent:start"]',
|
||||
"results = []\n"
|
||||
"def handle(event_type, context):\n"
|
||||
" results.append(event_type)\n")
|
||||
|
||||
reg = HookRegistry()
|
||||
with patch("gateway.hooks.HOOKS_DIR", tmp_path):
|
||||
reg.discover_and_load()
|
||||
|
||||
# Inject our results list into the handler's module globals
|
||||
handler_fn = reg._handlers["agent:start"][0]
|
||||
handler_fn.__globals__["results"] = results
|
||||
|
||||
await reg.emit("agent:start", {"test": True})
|
||||
assert "agent:start" in results
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_calls_async_handler(self, tmp_path):
|
||||
results = []
|
||||
|
||||
hook_dir = tmp_path / "async-hook"
|
||||
hook_dir.mkdir()
|
||||
(hook_dir / "HOOK.yaml").write_text(
|
||||
"name: async-hook\nevents: ['agent:end']\n"
|
||||
)
|
||||
(hook_dir / "handler.py").write_text(
|
||||
"import asyncio\n"
|
||||
"results = []\n"
|
||||
"async def handle(event_type, context):\n"
|
||||
" results.append(event_type)\n"
|
||||
)
|
||||
|
||||
reg = HookRegistry()
|
||||
with patch("gateway.hooks.HOOKS_DIR", tmp_path):
|
||||
reg.discover_and_load()
|
||||
|
||||
handler_fn = reg._handlers["agent:end"][0]
|
||||
handler_fn.__globals__["results"] = results
|
||||
|
||||
await reg.emit("agent:end", {})
|
||||
assert "agent:end" in results
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wildcard_matching(self, tmp_path):
|
||||
results = []
|
||||
|
||||
_create_hook(tmp_path, "wildcard-hook", '["command:*"]',
|
||||
"results = []\n"
|
||||
"def handle(event_type, context):\n"
|
||||
" results.append(event_type)\n")
|
||||
|
||||
reg = HookRegistry()
|
||||
with patch("gateway.hooks.HOOKS_DIR", tmp_path):
|
||||
reg.discover_and_load()
|
||||
|
||||
handler_fn = reg._handlers["command:*"][0]
|
||||
handler_fn.__globals__["results"] = results
|
||||
|
||||
await reg.emit("command:reset", {})
|
||||
assert "command:reset" in results
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_handlers_for_event(self, tmp_path):
|
||||
reg = HookRegistry()
|
||||
# Should not raise and should have no handlers registered
|
||||
result = await reg.emit("unknown:event", {})
|
||||
assert result is None
|
||||
assert not reg._handlers.get("unknown:event")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handler_error_does_not_propagate(self, tmp_path):
|
||||
_create_hook(tmp_path, "bad-hook", '["agent:start"]',
|
||||
"def handle(event_type, context):\n"
|
||||
" raise ValueError('boom')\n")
|
||||
|
||||
reg = HookRegistry()
|
||||
with patch("gateway.hooks.HOOKS_DIR", tmp_path):
|
||||
reg.discover_and_load()
|
||||
|
||||
assert len(reg._handlers.get("agent:start", [])) == 1
|
||||
# Should not raise even though handler throws
|
||||
result = await reg.emit("agent:start", {})
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_default_context(self, tmp_path):
|
||||
captured = []
|
||||
|
||||
_create_hook(tmp_path, "ctx-hook", '["agent:start"]',
|
||||
"captured = []\n"
|
||||
"def handle(event_type, context):\n"
|
||||
" captured.append(context)\n")
|
||||
|
||||
reg = HookRegistry()
|
||||
with patch("gateway.hooks.HOOKS_DIR", tmp_path):
|
||||
reg.discover_and_load()
|
||||
|
||||
handler_fn = reg._handlers["agent:start"][0]
|
||||
handler_fn.__globals__["captured"] = captured
|
||||
|
||||
await reg.emit("agent:start") # no context arg
|
||||
assert captured[0] == {}
|
||||
150
hermes_code/tests/gateway/test_interrupt_key_match.py
Normal file
150
hermes_code/tests/gateway/test_interrupt_key_match.py
Normal file
|
|
@ -0,0 +1,150 @@
|
|||
"""Tests verifying interrupt key consistency between adapter and gateway.
|
||||
|
||||
Regression test for a bug where monitor_for_interrupt() in _run_agent used
|
||||
source.chat_id to query the adapter, but the adapter stores interrupts under
|
||||
the full session key (build_session_key output). This mismatch meant
|
||||
interrupts were never detected, causing subagents to ignore new messages.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType, SendResult
|
||||
from gateway.session import SessionSource, build_session_key
|
||||
|
||||
|
||||
class StubAdapter(BasePlatformAdapter):
|
||||
"""Minimal adapter for interrupt tests."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(PlatformConfig(enabled=True, token="test"), Platform.TELEGRAM)
|
||||
|
||||
async def connect(self):
|
||||
return True
|
||||
|
||||
async def disconnect(self):
|
||||
pass
|
||||
|
||||
async def send(self, chat_id, content, reply_to=None, metadata=None):
|
||||
return SendResult(success=True, message_id="1")
|
||||
|
||||
async def send_typing(self, chat_id, metadata=None):
|
||||
pass
|
||||
|
||||
async def get_chat_info(self, chat_id):
|
||||
return {"id": chat_id}
|
||||
|
||||
|
||||
def _source(chat_id="123456", chat_type="dm", thread_id=None):
|
||||
return SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id=chat_id,
|
||||
chat_type=chat_type,
|
||||
thread_id=thread_id,
|
||||
)
|
||||
|
||||
|
||||
class TestInterruptKeyConsistency:
|
||||
"""Ensure adapter interrupt methods are queried with session_key, not chat_id."""
|
||||
|
||||
def test_session_key_differs_from_chat_id_for_dm(self):
|
||||
"""Session key for a DM is namespaced and includes the DM chat_id."""
|
||||
source = _source("123456", "dm")
|
||||
session_key = build_session_key(source)
|
||||
assert session_key != source.chat_id
|
||||
assert session_key == "agent:main:telegram:dm:123456"
|
||||
|
||||
def test_session_key_differs_from_chat_id_for_group(self):
|
||||
"""Session key for a group chat includes prefix, unlike raw chat_id."""
|
||||
source = _source("-1001234", "group")
|
||||
session_key = build_session_key(source)
|
||||
assert session_key != source.chat_id
|
||||
assert "agent:main:" in session_key
|
||||
assert source.chat_id in session_key
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_pending_interrupt_requires_session_key(self):
|
||||
"""has_pending_interrupt returns True only when queried with session_key."""
|
||||
adapter = StubAdapter()
|
||||
source = _source("123456", "dm")
|
||||
session_key = build_session_key(source)
|
||||
|
||||
# Simulate adapter storing interrupt under session_key
|
||||
interrupt_event = asyncio.Event()
|
||||
adapter._active_sessions[session_key] = interrupt_event
|
||||
interrupt_event.set()
|
||||
|
||||
# Using session_key → found
|
||||
assert adapter.has_pending_interrupt(session_key) is True
|
||||
|
||||
# Using chat_id → NOT found (this was the bug)
|
||||
assert adapter.has_pending_interrupt(source.chat_id) is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_pending_message_requires_session_key(self):
|
||||
"""get_pending_message returns the event only with session_key."""
|
||||
adapter = StubAdapter()
|
||||
source = _source("123456", "dm")
|
||||
session_key = build_session_key(source)
|
||||
|
||||
event = MessageEvent(text="hello", source=source, message_id="42")
|
||||
adapter._pending_messages[session_key] = event
|
||||
|
||||
# Using chat_id → None (the bug)
|
||||
assert adapter.get_pending_message(source.chat_id) is None
|
||||
|
||||
# Using session_key → found
|
||||
result = adapter.get_pending_message(session_key)
|
||||
assert result is event
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_stores_under_session_key(self):
|
||||
"""handle_message stores pending messages under session_key, not chat_id."""
|
||||
adapter = StubAdapter()
|
||||
adapter.set_message_handler(lambda event: asyncio.sleep(0, result=None))
|
||||
|
||||
source = _source("-1001234", "group")
|
||||
session_key = build_session_key(source)
|
||||
|
||||
# Mark session as active
|
||||
adapter._active_sessions[session_key] = asyncio.Event()
|
||||
|
||||
# Send a second message while session is active
|
||||
event = MessageEvent(text="interrupt!", source=source, message_id="2")
|
||||
await adapter.handle_message(event)
|
||||
|
||||
# Stored under session_key
|
||||
assert session_key in adapter._pending_messages
|
||||
# NOT stored under chat_id
|
||||
assert source.chat_id not in adapter._pending_messages
|
||||
|
||||
# Interrupt event was set
|
||||
assert adapter._active_sessions[session_key].is_set()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_photo_followup_is_queued_without_interrupt(self):
|
||||
"""Photo follow-ups should queue behind the active run instead of interrupting it."""
|
||||
adapter = StubAdapter()
|
||||
adapter.set_message_handler(lambda event: asyncio.sleep(0, result=None))
|
||||
|
||||
source = _source("-1001234", "group")
|
||||
session_key = build_session_key(source)
|
||||
interrupt_event = asyncio.Event()
|
||||
adapter._active_sessions[session_key] = interrupt_event
|
||||
|
||||
event = MessageEvent(
|
||||
text="caption",
|
||||
source=source,
|
||||
message_type=MessageType.PHOTO,
|
||||
message_id="2",
|
||||
media_urls=["/tmp/photo-a.jpg"],
|
||||
media_types=["image/jpeg"],
|
||||
)
|
||||
await adapter.handle_message(event)
|
||||
|
||||
queued = adapter._pending_messages[session_key]
|
||||
assert queued is event
|
||||
assert queued.media_urls == ["/tmp/photo-a.jpg"]
|
||||
assert interrupt_event.is_set() is False
|
||||
448
hermes_code/tests/gateway/test_matrix.py
Normal file
448
hermes_code/tests/gateway/test_matrix.py
Normal file
|
|
@ -0,0 +1,448 @@
|
|||
"""Tests for Matrix platform adapter."""
|
||||
import json
|
||||
import re
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Platform & Config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMatrixPlatformEnum:
|
||||
def test_matrix_enum_exists(self):
|
||||
assert Platform.MATRIX.value == "matrix"
|
||||
|
||||
def test_matrix_in_platform_list(self):
|
||||
platforms = [p.value for p in Platform]
|
||||
assert "matrix" in platforms
|
||||
|
||||
|
||||
class TestMatrixConfigLoading:
|
||||
def test_apply_env_overrides_with_access_token(self, monkeypatch):
|
||||
monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_abc123")
|
||||
monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org")
|
||||
|
||||
from gateway.config import GatewayConfig, _apply_env_overrides
|
||||
config = GatewayConfig()
|
||||
_apply_env_overrides(config)
|
||||
|
||||
assert Platform.MATRIX in config.platforms
|
||||
mc = config.platforms[Platform.MATRIX]
|
||||
assert mc.enabled is True
|
||||
assert mc.token == "syt_abc123"
|
||||
assert mc.extra.get("homeserver") == "https://matrix.example.org"
|
||||
|
||||
def test_apply_env_overrides_with_password(self, monkeypatch):
|
||||
monkeypatch.delenv("MATRIX_ACCESS_TOKEN", raising=False)
|
||||
monkeypatch.setenv("MATRIX_PASSWORD", "secret123")
|
||||
monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org")
|
||||
monkeypatch.setenv("MATRIX_USER_ID", "@bot:example.org")
|
||||
|
||||
from gateway.config import GatewayConfig, _apply_env_overrides
|
||||
config = GatewayConfig()
|
||||
_apply_env_overrides(config)
|
||||
|
||||
assert Platform.MATRIX in config.platforms
|
||||
mc = config.platforms[Platform.MATRIX]
|
||||
assert mc.enabled is True
|
||||
assert mc.extra.get("password") == "secret123"
|
||||
assert mc.extra.get("user_id") == "@bot:example.org"
|
||||
|
||||
def test_matrix_not_loaded_without_creds(self, monkeypatch):
|
||||
monkeypatch.delenv("MATRIX_ACCESS_TOKEN", raising=False)
|
||||
monkeypatch.delenv("MATRIX_PASSWORD", raising=False)
|
||||
monkeypatch.delenv("MATRIX_HOMESERVER", raising=False)
|
||||
|
||||
from gateway.config import GatewayConfig, _apply_env_overrides
|
||||
config = GatewayConfig()
|
||||
_apply_env_overrides(config)
|
||||
|
||||
assert Platform.MATRIX not in config.platforms
|
||||
|
||||
def test_matrix_encryption_flag(self, monkeypatch):
|
||||
monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_abc123")
|
||||
monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org")
|
||||
monkeypatch.setenv("MATRIX_ENCRYPTION", "true")
|
||||
|
||||
from gateway.config import GatewayConfig, _apply_env_overrides
|
||||
config = GatewayConfig()
|
||||
_apply_env_overrides(config)
|
||||
|
||||
mc = config.platforms[Platform.MATRIX]
|
||||
assert mc.extra.get("encryption") is True
|
||||
|
||||
def test_matrix_encryption_default_off(self, monkeypatch):
|
||||
monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_abc123")
|
||||
monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org")
|
||||
monkeypatch.delenv("MATRIX_ENCRYPTION", raising=False)
|
||||
|
||||
from gateway.config import GatewayConfig, _apply_env_overrides
|
||||
config = GatewayConfig()
|
||||
_apply_env_overrides(config)
|
||||
|
||||
mc = config.platforms[Platform.MATRIX]
|
||||
assert mc.extra.get("encryption") is False
|
||||
|
||||
def test_matrix_home_room(self, monkeypatch):
|
||||
monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_abc123")
|
||||
monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org")
|
||||
monkeypatch.setenv("MATRIX_HOME_ROOM", "!room123:example.org")
|
||||
monkeypatch.setenv("MATRIX_HOME_ROOM_NAME", "Bot Room")
|
||||
|
||||
from gateway.config import GatewayConfig, _apply_env_overrides
|
||||
config = GatewayConfig()
|
||||
_apply_env_overrides(config)
|
||||
|
||||
home = config.get_home_channel(Platform.MATRIX)
|
||||
assert home is not None
|
||||
assert home.chat_id == "!room123:example.org"
|
||||
assert home.name == "Bot Room"
|
||||
|
||||
def test_matrix_user_id_stored_in_extra(self, monkeypatch):
|
||||
monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_abc123")
|
||||
monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org")
|
||||
monkeypatch.setenv("MATRIX_USER_ID", "@hermes:example.org")
|
||||
|
||||
from gateway.config import GatewayConfig, _apply_env_overrides
|
||||
config = GatewayConfig()
|
||||
_apply_env_overrides(config)
|
||||
|
||||
mc = config.platforms[Platform.MATRIX]
|
||||
assert mc.extra.get("user_id") == "@hermes:example.org"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Adapter helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_adapter():
|
||||
"""Create a MatrixAdapter with mocked config."""
|
||||
from gateway.platforms.matrix import MatrixAdapter
|
||||
config = PlatformConfig(
|
||||
enabled=True,
|
||||
token="syt_test_token",
|
||||
extra={
|
||||
"homeserver": "https://matrix.example.org",
|
||||
"user_id": "@bot:example.org",
|
||||
},
|
||||
)
|
||||
adapter = MatrixAdapter(config)
|
||||
return adapter
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# mxc:// URL conversion
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMatrixMxcToHttp:
|
||||
def setup_method(self):
|
||||
self.adapter = _make_adapter()
|
||||
|
||||
def test_basic_mxc_conversion(self):
|
||||
"""mxc://server/media_id should become an authenticated HTTP URL."""
|
||||
mxc = "mxc://matrix.org/abc123"
|
||||
result = self.adapter._mxc_to_http(mxc)
|
||||
assert result == "https://matrix.example.org/_matrix/client/v1/media/download/matrix.org/abc123"
|
||||
|
||||
def test_mxc_with_different_server(self):
|
||||
"""mxc:// from a different server should still use our homeserver."""
|
||||
mxc = "mxc://other.server/media456"
|
||||
result = self.adapter._mxc_to_http(mxc)
|
||||
assert result.startswith("https://matrix.example.org/")
|
||||
assert "other.server/media456" in result
|
||||
|
||||
def test_non_mxc_url_passthrough(self):
|
||||
"""Non-mxc URLs should be returned unchanged."""
|
||||
url = "https://example.com/image.png"
|
||||
assert self.adapter._mxc_to_http(url) == url
|
||||
|
||||
def test_mxc_uses_client_v1_endpoint(self):
|
||||
"""Should use /_matrix/client/v1/media/download/ not the deprecated path."""
|
||||
mxc = "mxc://example.com/test123"
|
||||
result = self.adapter._mxc_to_http(mxc)
|
||||
assert "/_matrix/client/v1/media/download/" in result
|
||||
assert "/_matrix/media/v3/download/" not in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DM detection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMatrixDmDetection:
|
||||
def setup_method(self):
|
||||
self.adapter = _make_adapter()
|
||||
|
||||
def test_room_in_m_direct_is_dm(self):
|
||||
"""A room listed in m.direct should be detected as DM."""
|
||||
self.adapter._joined_rooms = {"!dm_room:ex.org", "!group_room:ex.org"}
|
||||
self.adapter._dm_rooms = {
|
||||
"!dm_room:ex.org": True,
|
||||
"!group_room:ex.org": False,
|
||||
}
|
||||
|
||||
assert self.adapter._dm_rooms.get("!dm_room:ex.org") is True
|
||||
assert self.adapter._dm_rooms.get("!group_room:ex.org") is False
|
||||
|
||||
def test_unknown_room_not_in_cache(self):
|
||||
"""Unknown rooms should not be in the DM cache."""
|
||||
self.adapter._dm_rooms = {}
|
||||
assert self.adapter._dm_rooms.get("!unknown:ex.org") is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_dm_cache_with_m_direct(self):
|
||||
"""_refresh_dm_cache should populate _dm_rooms from m.direct data."""
|
||||
self.adapter._joined_rooms = {"!room_a:ex.org", "!room_b:ex.org", "!room_c:ex.org"}
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.content = {
|
||||
"@alice:ex.org": ["!room_a:ex.org"],
|
||||
"@bob:ex.org": ["!room_b:ex.org"],
|
||||
}
|
||||
mock_client.get_account_data = AsyncMock(return_value=mock_resp)
|
||||
self.adapter._client = mock_client
|
||||
|
||||
await self.adapter._refresh_dm_cache()
|
||||
|
||||
assert self.adapter._dm_rooms["!room_a:ex.org"] is True
|
||||
assert self.adapter._dm_rooms["!room_b:ex.org"] is True
|
||||
assert self.adapter._dm_rooms["!room_c:ex.org"] is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Reply fallback stripping
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMatrixReplyFallbackStripping:
|
||||
"""Test that Matrix reply fallback lines ('> ' prefix) are stripped."""
|
||||
|
||||
def setup_method(self):
|
||||
self.adapter = _make_adapter()
|
||||
self.adapter._user_id = "@bot:example.org"
|
||||
self.adapter._startup_ts = 0.0
|
||||
self.adapter._dm_rooms = {}
|
||||
self.adapter._message_handler = AsyncMock()
|
||||
|
||||
def _strip_fallback(self, body: str, has_reply: bool = True) -> str:
|
||||
"""Simulate the reply fallback stripping logic from _on_room_message."""
|
||||
reply_to = "some_event_id" if has_reply else None
|
||||
if reply_to and body.startswith("> "):
|
||||
lines = body.split("\n")
|
||||
stripped = []
|
||||
past_fallback = False
|
||||
for line in lines:
|
||||
if not past_fallback:
|
||||
if line.startswith("> ") or line == ">":
|
||||
continue
|
||||
if line == "":
|
||||
past_fallback = True
|
||||
continue
|
||||
past_fallback = True
|
||||
stripped.append(line)
|
||||
body = "\n".join(stripped) if stripped else body
|
||||
return body
|
||||
|
||||
def test_simple_reply_fallback(self):
|
||||
body = "> <@alice:ex.org> Original message\n\nActual reply"
|
||||
result = self._strip_fallback(body)
|
||||
assert result == "Actual reply"
|
||||
|
||||
def test_multiline_reply_fallback(self):
|
||||
body = "> <@alice:ex.org> Line 1\n> Line 2\n\nMy response"
|
||||
result = self._strip_fallback(body)
|
||||
assert result == "My response"
|
||||
|
||||
def test_no_reply_fallback_preserved(self):
|
||||
body = "Just a normal message"
|
||||
result = self._strip_fallback(body, has_reply=False)
|
||||
assert result == "Just a normal message"
|
||||
|
||||
def test_quote_without_reply_preserved(self):
|
||||
"""'> ' lines without a reply_to context should be preserved."""
|
||||
body = "> This is a blockquote"
|
||||
result = self._strip_fallback(body, has_reply=False)
|
||||
assert result == "> This is a blockquote"
|
||||
|
||||
def test_empty_fallback_separator(self):
|
||||
"""The blank line between fallback and actual content should be stripped."""
|
||||
body = "> <@alice:ex.org> hi\n>\n\nResponse"
|
||||
result = self._strip_fallback(body)
|
||||
assert result == "Response"
|
||||
|
||||
def test_multiline_response_after_fallback(self):
|
||||
body = "> <@alice:ex.org> Original\n\nLine 1\nLine 2\nLine 3"
|
||||
result = self._strip_fallback(body)
|
||||
assert result == "Line 1\nLine 2\nLine 3"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Thread detection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMatrixThreadDetection:
|
||||
def test_thread_id_from_m_relates_to(self):
|
||||
"""m.relates_to with rel_type=m.thread should extract the event_id."""
|
||||
relates_to = {
|
||||
"rel_type": "m.thread",
|
||||
"event_id": "$thread_root_event",
|
||||
"is_falling_back": True,
|
||||
"m.in_reply_to": {"event_id": "$some_event"},
|
||||
}
|
||||
# Simulate the extraction logic from _on_room_message
|
||||
thread_id = None
|
||||
if relates_to.get("rel_type") == "m.thread":
|
||||
thread_id = relates_to.get("event_id")
|
||||
assert thread_id == "$thread_root_event"
|
||||
|
||||
def test_no_thread_for_reply(self):
|
||||
"""m.in_reply_to without m.thread should not set thread_id."""
|
||||
relates_to = {
|
||||
"m.in_reply_to": {"event_id": "$reply_event"},
|
||||
}
|
||||
thread_id = None
|
||||
if relates_to.get("rel_type") == "m.thread":
|
||||
thread_id = relates_to.get("event_id")
|
||||
assert thread_id is None
|
||||
|
||||
def test_no_thread_for_edit(self):
|
||||
"""m.replace relation should not set thread_id."""
|
||||
relates_to = {
|
||||
"rel_type": "m.replace",
|
||||
"event_id": "$edited_event",
|
||||
}
|
||||
thread_id = None
|
||||
if relates_to.get("rel_type") == "m.thread":
|
||||
thread_id = relates_to.get("event_id")
|
||||
assert thread_id is None
|
||||
|
||||
def test_empty_relates_to(self):
|
||||
"""Empty m.relates_to should not set thread_id."""
|
||||
relates_to = {}
|
||||
thread_id = None
|
||||
if relates_to.get("rel_type") == "m.thread":
|
||||
thread_id = relates_to.get("event_id")
|
||||
assert thread_id is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Format message
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMatrixFormatMessage:
|
||||
def setup_method(self):
|
||||
self.adapter = _make_adapter()
|
||||
|
||||
def test_image_markdown_stripped(self):
|
||||
""" should be converted to just the URL."""
|
||||
result = self.adapter.format_message("")
|
||||
assert result == "https://img.example.com/cat.png"
|
||||
|
||||
def test_regular_markdown_preserved(self):
|
||||
"""Standard markdown should be preserved (Matrix supports it)."""
|
||||
content = "**bold** and *italic* and `code`"
|
||||
assert self.adapter.format_message(content) == content
|
||||
|
||||
def test_plain_text_unchanged(self):
|
||||
content = "Hello, world!"
|
||||
assert self.adapter.format_message(content) == content
|
||||
|
||||
def test_multiple_images_stripped(self):
|
||||
content = " and "
|
||||
result = self.adapter.format_message(content)
|
||||
assert "![" not in result
|
||||
assert "http://a.com/1.png" in result
|
||||
assert "http://b.com/2.png" in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Markdown to HTML conversion
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMatrixMarkdownToHtml:
|
||||
def setup_method(self):
|
||||
self.adapter = _make_adapter()
|
||||
|
||||
def test_bold_conversion(self):
|
||||
"""**bold** should produce <strong> tags."""
|
||||
result = self.adapter._markdown_to_html("**bold**")
|
||||
assert "<strong>" in result or "<b>" in result
|
||||
assert "bold" in result
|
||||
|
||||
def test_italic_conversion(self):
|
||||
"""*italic* should produce <em> tags."""
|
||||
result = self.adapter._markdown_to_html("*italic*")
|
||||
assert "<em>" in result or "<i>" in result
|
||||
|
||||
def test_inline_code(self):
|
||||
"""`code` should produce <code> tags."""
|
||||
result = self.adapter._markdown_to_html("`code`")
|
||||
assert "<code>" in result
|
||||
|
||||
def test_plain_text_returns_html(self):
|
||||
"""Plain text should still be returned (possibly with <br> or <p>)."""
|
||||
result = self.adapter._markdown_to_html("Hello world")
|
||||
assert "Hello world" in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper: display name extraction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMatrixDisplayName:
|
||||
def setup_method(self):
|
||||
self.adapter = _make_adapter()
|
||||
|
||||
def test_get_display_name_from_room_users(self):
|
||||
"""Should get display name from room's users dict."""
|
||||
mock_room = MagicMock()
|
||||
mock_user = MagicMock()
|
||||
mock_user.display_name = "Alice"
|
||||
mock_room.users = {"@alice:ex.org": mock_user}
|
||||
|
||||
name = self.adapter._get_display_name(mock_room, "@alice:ex.org")
|
||||
assert name == "Alice"
|
||||
|
||||
def test_get_display_name_fallback_to_localpart(self):
|
||||
"""Should extract localpart from @user:server format."""
|
||||
mock_room = MagicMock()
|
||||
mock_room.users = {}
|
||||
|
||||
name = self.adapter._get_display_name(mock_room, "@bob:example.org")
|
||||
assert name == "bob"
|
||||
|
||||
def test_get_display_name_no_room(self):
|
||||
"""Should handle None room gracefully."""
|
||||
name = self.adapter._get_display_name(None, "@charlie:ex.org")
|
||||
assert name == "charlie"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Requirements check
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMatrixRequirements:
|
||||
def test_check_requirements_with_token(self, monkeypatch):
|
||||
monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_test")
|
||||
monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org")
|
||||
from gateway.platforms.matrix import check_matrix_requirements
|
||||
try:
|
||||
import nio # noqa: F401
|
||||
assert check_matrix_requirements() is True
|
||||
except ImportError:
|
||||
assert check_matrix_requirements() is False
|
||||
|
||||
def test_check_requirements_without_creds(self, monkeypatch):
|
||||
monkeypatch.delenv("MATRIX_ACCESS_TOKEN", raising=False)
|
||||
monkeypatch.delenv("MATRIX_PASSWORD", raising=False)
|
||||
monkeypatch.delenv("MATRIX_HOMESERVER", raising=False)
|
||||
from gateway.platforms.matrix import check_matrix_requirements
|
||||
assert check_matrix_requirements() is False
|
||||
|
||||
def test_check_requirements_without_homeserver(self, monkeypatch):
|
||||
monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_test")
|
||||
monkeypatch.delenv("MATRIX_HOMESERVER", raising=False)
|
||||
from gateway.platforms.matrix import check_matrix_requirements
|
||||
assert check_matrix_requirements() is False
|
||||
673
hermes_code/tests/gateway/test_mattermost.py
Normal file
673
hermes_code/tests/gateway/test_mattermost.py
Normal file
|
|
@ -0,0 +1,673 @@
|
|||
"""Tests for Mattermost platform adapter."""
|
||||
import json
|
||||
import time
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Platform & Config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMattermostPlatformEnum:
|
||||
def test_mattermost_enum_exists(self):
|
||||
assert Platform.MATTERMOST.value == "mattermost"
|
||||
|
||||
def test_mattermost_in_platform_list(self):
|
||||
platforms = [p.value for p in Platform]
|
||||
assert "mattermost" in platforms
|
||||
|
||||
|
||||
class TestMattermostConfigLoading:
|
||||
def test_apply_env_overrides_mattermost(self, monkeypatch):
|
||||
monkeypatch.setenv("MATTERMOST_TOKEN", "mm-tok-abc123")
|
||||
monkeypatch.setenv("MATTERMOST_URL", "https://mm.example.com")
|
||||
|
||||
from gateway.config import GatewayConfig, _apply_env_overrides
|
||||
config = GatewayConfig()
|
||||
_apply_env_overrides(config)
|
||||
|
||||
assert Platform.MATTERMOST in config.platforms
|
||||
mc = config.platforms[Platform.MATTERMOST]
|
||||
assert mc.enabled is True
|
||||
assert mc.token == "mm-tok-abc123"
|
||||
assert mc.extra.get("url") == "https://mm.example.com"
|
||||
|
||||
def test_mattermost_not_loaded_without_token(self, monkeypatch):
|
||||
monkeypatch.delenv("MATTERMOST_TOKEN", raising=False)
|
||||
monkeypatch.delenv("MATTERMOST_URL", raising=False)
|
||||
|
||||
from gateway.config import GatewayConfig, _apply_env_overrides
|
||||
config = GatewayConfig()
|
||||
_apply_env_overrides(config)
|
||||
|
||||
assert Platform.MATTERMOST not in config.platforms
|
||||
|
||||
def test_connected_platforms_includes_mattermost(self, monkeypatch):
|
||||
monkeypatch.setenv("MATTERMOST_TOKEN", "mm-tok-abc123")
|
||||
monkeypatch.setenv("MATTERMOST_URL", "https://mm.example.com")
|
||||
|
||||
from gateway.config import GatewayConfig, _apply_env_overrides
|
||||
config = GatewayConfig()
|
||||
_apply_env_overrides(config)
|
||||
|
||||
connected = config.get_connected_platforms()
|
||||
assert Platform.MATTERMOST in connected
|
||||
|
||||
def test_mattermost_home_channel(self, monkeypatch):
|
||||
monkeypatch.setenv("MATTERMOST_TOKEN", "mm-tok-abc123")
|
||||
monkeypatch.setenv("MATTERMOST_URL", "https://mm.example.com")
|
||||
monkeypatch.setenv("MATTERMOST_HOME_CHANNEL", "ch_abc123")
|
||||
monkeypatch.setenv("MATTERMOST_HOME_CHANNEL_NAME", "General")
|
||||
|
||||
from gateway.config import GatewayConfig, _apply_env_overrides
|
||||
config = GatewayConfig()
|
||||
_apply_env_overrides(config)
|
||||
|
||||
home = config.get_home_channel(Platform.MATTERMOST)
|
||||
assert home is not None
|
||||
assert home.chat_id == "ch_abc123"
|
||||
assert home.name == "General"
|
||||
|
||||
def test_mattermost_url_warning_without_url(self, monkeypatch):
|
||||
"""MATTERMOST_TOKEN set but MATTERMOST_URL missing should still load."""
|
||||
monkeypatch.setenv("MATTERMOST_TOKEN", "mm-tok-abc123")
|
||||
monkeypatch.delenv("MATTERMOST_URL", raising=False)
|
||||
|
||||
from gateway.config import GatewayConfig, _apply_env_overrides
|
||||
config = GatewayConfig()
|
||||
_apply_env_overrides(config)
|
||||
|
||||
assert Platform.MATTERMOST in config.platforms
|
||||
assert config.platforms[Platform.MATTERMOST].extra.get("url") == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Adapter format / truncate
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_adapter():
|
||||
"""Create a MattermostAdapter with mocked config."""
|
||||
from gateway.platforms.mattermost import MattermostAdapter
|
||||
config = PlatformConfig(
|
||||
enabled=True,
|
||||
token="test-token",
|
||||
extra={"url": "https://mm.example.com"},
|
||||
)
|
||||
adapter = MattermostAdapter(config)
|
||||
return adapter
|
||||
|
||||
|
||||
class TestMattermostFormatMessage:
|
||||
def setup_method(self):
|
||||
self.adapter = _make_adapter()
|
||||
|
||||
def test_image_markdown_to_url(self):
|
||||
""" should be converted to just the URL."""
|
||||
result = self.adapter.format_message("")
|
||||
assert result == "https://img.example.com/cat.png"
|
||||
|
||||
def test_image_markdown_strips_alt_text(self):
|
||||
result = self.adapter.format_message("Here:  done")
|
||||
assert ""
|
||||
assert self.adapter.format_message(content) == content
|
||||
|
||||
def test_plain_text_unchanged(self):
|
||||
content = "Hello, world!"
|
||||
assert self.adapter.format_message(content) == content
|
||||
|
||||
def test_multiple_images(self):
|
||||
content = " text "
|
||||
result = self.adapter.format_message(content)
|
||||
assert "![" not in result
|
||||
assert "http://a.com/1.png" in result
|
||||
assert "http://b.com/2.png" in result
|
||||
|
||||
|
||||
class TestMattermostTruncateMessage:
|
||||
def setup_method(self):
|
||||
self.adapter = _make_adapter()
|
||||
|
||||
def test_short_message_single_chunk(self):
|
||||
msg = "Hello, world!"
|
||||
chunks = self.adapter.truncate_message(msg, 4000)
|
||||
assert len(chunks) == 1
|
||||
assert chunks[0] == msg
|
||||
|
||||
def test_long_message_splits(self):
|
||||
msg = "a " * 2500 # 5000 chars
|
||||
chunks = self.adapter.truncate_message(msg, 4000)
|
||||
assert len(chunks) >= 2
|
||||
for chunk in chunks:
|
||||
assert len(chunk) <= 4000
|
||||
|
||||
def test_custom_max_length(self):
|
||||
msg = "Hello " * 20
|
||||
chunks = self.adapter.truncate_message(msg, max_length=50)
|
||||
assert all(len(c) <= 50 for c in chunks)
|
||||
|
||||
def test_exactly_at_limit(self):
|
||||
msg = "x" * 4000
|
||||
chunks = self.adapter.truncate_message(msg, 4000)
|
||||
assert len(chunks) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Send
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMattermostSend:
|
||||
def setup_method(self):
|
||||
self.adapter = _make_adapter()
|
||||
self.adapter._session = MagicMock()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_calls_api_post(self):
|
||||
"""send() should POST to /api/v4/posts with channel_id and message."""
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 200
|
||||
mock_resp.json = AsyncMock(return_value={"id": "post123"})
|
||||
mock_resp.text = AsyncMock(return_value="")
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
self.adapter._session.post = MagicMock(return_value=mock_resp)
|
||||
|
||||
result = await self.adapter.send("channel_1", "Hello!")
|
||||
|
||||
assert result.success is True
|
||||
assert result.message_id == "post123"
|
||||
|
||||
# Verify post was called with correct URL
|
||||
call_args = self.adapter._session.post.call_args
|
||||
assert "/api/v4/posts" in call_args[0][0]
|
||||
# Verify payload
|
||||
payload = call_args[1]["json"]
|
||||
assert payload["channel_id"] == "channel_1"
|
||||
assert payload["message"] == "Hello!"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_empty_content_succeeds(self):
|
||||
"""Empty content should return success without calling the API."""
|
||||
result = await self.adapter.send("channel_1", "")
|
||||
assert result.success is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_with_thread_reply(self):
|
||||
"""When reply_mode is 'thread', reply_to should become root_id."""
|
||||
self.adapter._reply_mode = "thread"
|
||||
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 200
|
||||
mock_resp.json = AsyncMock(return_value={"id": "post456"})
|
||||
mock_resp.text = AsyncMock(return_value="")
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
self.adapter._session.post = MagicMock(return_value=mock_resp)
|
||||
|
||||
result = await self.adapter.send("channel_1", "Reply!", reply_to="root_post")
|
||||
|
||||
assert result.success is True
|
||||
payload = self.adapter._session.post.call_args[1]["json"]
|
||||
assert payload["root_id"] == "root_post"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_without_thread_no_root_id(self):
|
||||
"""When reply_mode is 'off', reply_to should NOT set root_id."""
|
||||
self.adapter._reply_mode = "off"
|
||||
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 200
|
||||
mock_resp.json = AsyncMock(return_value={"id": "post789"})
|
||||
mock_resp.text = AsyncMock(return_value="")
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
self.adapter._session.post = MagicMock(return_value=mock_resp)
|
||||
|
||||
result = await self.adapter.send("channel_1", "Reply!", reply_to="root_post")
|
||||
|
||||
assert result.success is True
|
||||
payload = self.adapter._session.post.call_args[1]["json"]
|
||||
assert "root_id" not in payload
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_api_failure(self):
|
||||
"""When API returns error, send should return failure."""
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 500
|
||||
mock_resp.json = AsyncMock(return_value={})
|
||||
mock_resp.text = AsyncMock(return_value="Internal Server Error")
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
self.adapter._session.post = MagicMock(return_value=mock_resp)
|
||||
|
||||
result = await self.adapter.send("channel_1", "Hello!")
|
||||
|
||||
assert result.success is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WebSocket event parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMattermostWebSocketParsing:
|
||||
def setup_method(self):
|
||||
self.adapter = _make_adapter()
|
||||
self.adapter._bot_user_id = "bot_user_id"
|
||||
# Mock handle_message to capture the MessageEvent without processing
|
||||
self.adapter.handle_message = AsyncMock()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_posted_event(self):
|
||||
"""'posted' events should extract message from double-encoded post JSON."""
|
||||
post_data = {
|
||||
"id": "post_abc",
|
||||
"user_id": "user_123",
|
||||
"channel_id": "chan_456",
|
||||
"message": "@bot_user_id Hello from Matrix!",
|
||||
}
|
||||
event = {
|
||||
"event": "posted",
|
||||
"data": {
|
||||
"post": json.dumps(post_data), # double-encoded JSON string
|
||||
"channel_type": "O",
|
||||
"sender_name": "@alice",
|
||||
},
|
||||
}
|
||||
|
||||
await self.adapter._handle_ws_event(event)
|
||||
assert self.adapter.handle_message.called
|
||||
msg_event = self.adapter.handle_message.call_args[0][0]
|
||||
assert msg_event.text == "@bot_user_id Hello from Matrix!"
|
||||
assert msg_event.message_id == "post_abc"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ignore_own_messages(self):
|
||||
"""Messages from the bot's own user_id should be ignored."""
|
||||
post_data = {
|
||||
"id": "post_self",
|
||||
"user_id": "bot_user_id", # same as bot
|
||||
"channel_id": "chan_456",
|
||||
"message": "Bot echo",
|
||||
}
|
||||
event = {
|
||||
"event": "posted",
|
||||
"data": {
|
||||
"post": json.dumps(post_data),
|
||||
"channel_type": "O",
|
||||
},
|
||||
}
|
||||
|
||||
await self.adapter._handle_ws_event(event)
|
||||
assert not self.adapter.handle_message.called
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ignore_non_posted_events(self):
|
||||
"""Non-'posted' events should be ignored."""
|
||||
event = {
|
||||
"event": "typing",
|
||||
"data": {"user_id": "user_123"},
|
||||
}
|
||||
|
||||
await self.adapter._handle_ws_event(event)
|
||||
assert not self.adapter.handle_message.called
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ignore_system_posts(self):
|
||||
"""Posts with a 'type' field (system messages) should be ignored."""
|
||||
post_data = {
|
||||
"id": "sys_post",
|
||||
"user_id": "user_123",
|
||||
"channel_id": "chan_456",
|
||||
"message": "user joined",
|
||||
"type": "system_join_channel",
|
||||
}
|
||||
event = {
|
||||
"event": "posted",
|
||||
"data": {
|
||||
"post": json.dumps(post_data),
|
||||
"channel_type": "O",
|
||||
},
|
||||
}
|
||||
|
||||
await self.adapter._handle_ws_event(event)
|
||||
assert not self.adapter.handle_message.called
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_channel_type_mapping(self):
|
||||
"""channel_type 'D' should map to 'dm'."""
|
||||
post_data = {
|
||||
"id": "post_dm",
|
||||
"user_id": "user_123",
|
||||
"channel_id": "chan_dm",
|
||||
"message": "DM message",
|
||||
}
|
||||
event = {
|
||||
"event": "posted",
|
||||
"data": {
|
||||
"post": json.dumps(post_data),
|
||||
"channel_type": "D",
|
||||
"sender_name": "@bob",
|
||||
},
|
||||
}
|
||||
|
||||
await self.adapter._handle_ws_event(event)
|
||||
assert self.adapter.handle_message.called
|
||||
msg_event = self.adapter.handle_message.call_args[0][0]
|
||||
assert msg_event.source.chat_type == "dm"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_thread_id_from_root_id(self):
|
||||
"""Post with root_id should have thread_id set."""
|
||||
post_data = {
|
||||
"id": "post_reply",
|
||||
"user_id": "user_123",
|
||||
"channel_id": "chan_456",
|
||||
"message": "@bot_user_id Thread reply",
|
||||
"root_id": "root_post_123",
|
||||
}
|
||||
event = {
|
||||
"event": "posted",
|
||||
"data": {
|
||||
"post": json.dumps(post_data),
|
||||
"channel_type": "O",
|
||||
"sender_name": "@alice",
|
||||
},
|
||||
}
|
||||
|
||||
await self.adapter._handle_ws_event(event)
|
||||
assert self.adapter.handle_message.called
|
||||
msg_event = self.adapter.handle_message.call_args[0][0]
|
||||
assert msg_event.source.thread_id == "root_post_123"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_post_json_ignored(self):
|
||||
"""Invalid JSON in data.post should be silently ignored."""
|
||||
event = {
|
||||
"event": "posted",
|
||||
"data": {
|
||||
"post": "not-valid-json{{{",
|
||||
"channel_type": "O",
|
||||
},
|
||||
}
|
||||
|
||||
await self.adapter._handle_ws_event(event)
|
||||
assert not self.adapter.handle_message.called
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# File upload (send_image)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMattermostFileUpload:
|
||||
def setup_method(self):
|
||||
self.adapter = _make_adapter()
|
||||
self.adapter._session = MagicMock()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_image_downloads_and_uploads(self):
|
||||
"""send_image should download the URL, upload via /api/v4/files, then post."""
|
||||
# Mock the download (GET)
|
||||
mock_dl_resp = AsyncMock()
|
||||
mock_dl_resp.status = 200
|
||||
mock_dl_resp.read = AsyncMock(return_value=b"\x89PNG\x00fake-image-data")
|
||||
mock_dl_resp.content_type = "image/png"
|
||||
mock_dl_resp.__aenter__ = AsyncMock(return_value=mock_dl_resp)
|
||||
mock_dl_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
# Mock the upload (POST to /files)
|
||||
mock_upload_resp = AsyncMock()
|
||||
mock_upload_resp.status = 200
|
||||
mock_upload_resp.json = AsyncMock(return_value={
|
||||
"file_infos": [{"id": "file_abc123"}]
|
||||
})
|
||||
mock_upload_resp.text = AsyncMock(return_value="")
|
||||
mock_upload_resp.__aenter__ = AsyncMock(return_value=mock_upload_resp)
|
||||
mock_upload_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
# Mock the post (POST to /posts)
|
||||
mock_post_resp = AsyncMock()
|
||||
mock_post_resp.status = 200
|
||||
mock_post_resp.json = AsyncMock(return_value={"id": "post_with_file"})
|
||||
mock_post_resp.text = AsyncMock(return_value="")
|
||||
mock_post_resp.__aenter__ = AsyncMock(return_value=mock_post_resp)
|
||||
mock_post_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
# Route calls: first GET (download), then POST (upload), then POST (create post)
|
||||
self.adapter._session.get = MagicMock(return_value=mock_dl_resp)
|
||||
post_call_count = 0
|
||||
original_post_returns = [mock_upload_resp, mock_post_resp]
|
||||
|
||||
def post_side_effect(*args, **kwargs):
|
||||
nonlocal post_call_count
|
||||
resp = original_post_returns[min(post_call_count, len(original_post_returns) - 1)]
|
||||
post_call_count += 1
|
||||
return resp
|
||||
|
||||
self.adapter._session.post = MagicMock(side_effect=post_side_effect)
|
||||
|
||||
result = await self.adapter.send_image(
|
||||
"channel_1", "https://img.example.com/cat.png", caption="A cat"
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.message_id == "post_with_file"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dedup cache
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMattermostDedup:
|
||||
def setup_method(self):
|
||||
self.adapter = _make_adapter()
|
||||
self.adapter._bot_user_id = "bot_user_id"
|
||||
# Mock handle_message to capture calls without processing
|
||||
self.adapter.handle_message = AsyncMock()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_duplicate_post_ignored(self):
|
||||
"""The same post_id within the TTL window should be ignored."""
|
||||
post_data = {
|
||||
"id": "post_dup",
|
||||
"user_id": "user_123",
|
||||
"channel_id": "chan_456",
|
||||
"message": "@bot_user_id Hello!",
|
||||
}
|
||||
event = {
|
||||
"event": "posted",
|
||||
"data": {
|
||||
"post": json.dumps(post_data),
|
||||
"channel_type": "O",
|
||||
"sender_name": "@alice",
|
||||
},
|
||||
}
|
||||
|
||||
# First time: should process
|
||||
await self.adapter._handle_ws_event(event)
|
||||
assert self.adapter.handle_message.call_count == 1
|
||||
|
||||
# Second time (same post_id): should be deduped
|
||||
await self.adapter._handle_ws_event(event)
|
||||
assert self.adapter.handle_message.call_count == 1 # still 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_different_post_ids_both_processed(self):
|
||||
"""Different post IDs should both be processed."""
|
||||
for i, pid in enumerate(["post_a", "post_b"]):
|
||||
post_data = {
|
||||
"id": pid,
|
||||
"user_id": "user_123",
|
||||
"channel_id": "chan_456",
|
||||
"message": f"@bot_user_id Message {i}",
|
||||
}
|
||||
event = {
|
||||
"event": "posted",
|
||||
"data": {
|
||||
"post": json.dumps(post_data),
|
||||
"channel_type": "O",
|
||||
"sender_name": "@alice",
|
||||
},
|
||||
}
|
||||
await self.adapter._handle_ws_event(event)
|
||||
|
||||
assert self.adapter.handle_message.call_count == 2
|
||||
|
||||
def test_prune_seen_clears_expired(self):
|
||||
"""_prune_seen should remove entries older than _SEEN_TTL."""
|
||||
now = time.time()
|
||||
# Fill with enough expired entries to trigger pruning
|
||||
for i in range(self.adapter._SEEN_MAX + 10):
|
||||
self.adapter._seen_posts[f"old_{i}"] = now - 600 # 10 min ago
|
||||
|
||||
# Add a fresh one
|
||||
self.adapter._seen_posts["fresh"] = now
|
||||
|
||||
self.adapter._prune_seen()
|
||||
|
||||
# Old entries should be pruned, fresh one kept
|
||||
assert "fresh" in self.adapter._seen_posts
|
||||
assert len(self.adapter._seen_posts) < self.adapter._SEEN_MAX
|
||||
|
||||
def test_seen_cache_tracks_post_ids(self):
|
||||
"""Posts are tracked in _seen_posts dict."""
|
||||
self.adapter._seen_posts["test_post"] = time.time()
|
||||
assert "test_post" in self.adapter._seen_posts
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Requirements check
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMattermostRequirements:
|
||||
def test_check_requirements_with_token_and_url(self, monkeypatch):
|
||||
monkeypatch.setenv("MATTERMOST_TOKEN", "test-token")
|
||||
monkeypatch.setenv("MATTERMOST_URL", "https://mm.example.com")
|
||||
from gateway.platforms.mattermost import check_mattermost_requirements
|
||||
assert check_mattermost_requirements() is True
|
||||
|
||||
def test_check_requirements_without_token(self, monkeypatch):
|
||||
monkeypatch.delenv("MATTERMOST_TOKEN", raising=False)
|
||||
monkeypatch.delenv("MATTERMOST_URL", raising=False)
|
||||
from gateway.platforms.mattermost import check_mattermost_requirements
|
||||
assert check_mattermost_requirements() is False
|
||||
|
||||
def test_check_requirements_without_url(self, monkeypatch):
|
||||
monkeypatch.setenv("MATTERMOST_TOKEN", "test-token")
|
||||
monkeypatch.delenv("MATTERMOST_URL", raising=False)
|
||||
from gateway.platforms.mattermost import check_mattermost_requirements
|
||||
assert check_mattermost_requirements() is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Media type propagation (MIME types, not bare strings)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMattermostMediaTypes:
|
||||
"""Verify that media_types contains actual MIME types (e.g. 'image/png')
|
||||
rather than bare category strings ('image'), so downstream
|
||||
``mtype.startswith("image/")`` checks in run.py work correctly."""
|
||||
|
||||
def setup_method(self):
|
||||
self.adapter = _make_adapter()
|
||||
self.adapter._bot_user_id = "bot_user_id"
|
||||
self.adapter.handle_message = AsyncMock()
|
||||
|
||||
def _make_event(self, file_ids):
|
||||
post_data = {
|
||||
"id": "post_media",
|
||||
"user_id": "user_123",
|
||||
"channel_id": "chan_456",
|
||||
"message": "@bot_user_id file attached",
|
||||
"file_ids": file_ids,
|
||||
}
|
||||
return {
|
||||
"event": "posted",
|
||||
"data": {
|
||||
"post": json.dumps(post_data),
|
||||
"channel_type": "O",
|
||||
"sender_name": "@alice",
|
||||
},
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_media_type_is_full_mime(self):
|
||||
"""An image attachment should produce 'image/png', not 'image'."""
|
||||
file_info = {"name": "photo.png", "mime_type": "image/png"}
|
||||
self.adapter._api_get = AsyncMock(return_value=file_info)
|
||||
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 200
|
||||
mock_resp.read = AsyncMock(return_value=b"\x89PNG fake")
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
self.adapter._session = MagicMock()
|
||||
self.adapter._session.get = MagicMock(return_value=mock_resp)
|
||||
|
||||
with patch("gateway.platforms.base.cache_image_from_bytes", return_value="/tmp/photo.png"):
|
||||
await self.adapter._handle_ws_event(self._make_event(["file1"]))
|
||||
|
||||
msg = self.adapter.handle_message.call_args[0][0]
|
||||
assert msg.media_types == ["image/png"]
|
||||
assert msg.media_types[0].startswith("image/")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audio_media_type_is_full_mime(self):
|
||||
"""An audio attachment should produce 'audio/ogg', not 'audio'."""
|
||||
file_info = {"name": "voice.ogg", "mime_type": "audio/ogg"}
|
||||
self.adapter._api_get = AsyncMock(return_value=file_info)
|
||||
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 200
|
||||
mock_resp.read = AsyncMock(return_value=b"OGG fake")
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
self.adapter._session = MagicMock()
|
||||
self.adapter._session.get = MagicMock(return_value=mock_resp)
|
||||
|
||||
with patch("gateway.platforms.base.cache_audio_from_bytes", return_value="/tmp/voice.ogg"), \
|
||||
patch("gateway.platforms.base.cache_image_from_bytes"), \
|
||||
patch("gateway.platforms.base.cache_document_from_bytes"):
|
||||
await self.adapter._handle_ws_event(self._make_event(["file2"]))
|
||||
|
||||
msg = self.adapter.handle_message.call_args[0][0]
|
||||
assert msg.media_types == ["audio/ogg"]
|
||||
assert msg.media_types[0].startswith("audio/")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_media_type_is_full_mime(self):
|
||||
"""A document attachment should produce 'application/pdf', not 'document'."""
|
||||
file_info = {"name": "report.pdf", "mime_type": "application/pdf"}
|
||||
self.adapter._api_get = AsyncMock(return_value=file_info)
|
||||
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 200
|
||||
mock_resp.read = AsyncMock(return_value=b"PDF fake")
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
self.adapter._session = MagicMock()
|
||||
self.adapter._session.get = MagicMock(return_value=mock_resp)
|
||||
|
||||
with patch("gateway.platforms.base.cache_document_from_bytes", return_value="/tmp/report.pdf"), \
|
||||
patch("gateway.platforms.base.cache_image_from_bytes"):
|
||||
await self.adapter._handle_ws_event(self._make_event(["file3"]))
|
||||
|
||||
msg = self.adapter.handle_message.call_args[0][0]
|
||||
assert msg.media_types == ["application/pdf"]
|
||||
assert not msg.media_types[0].startswith("image/")
|
||||
assert not msg.media_types[0].startswith("audio/")
|
||||
184
hermes_code/tests/gateway/test_media_extraction.py
Normal file
184
hermes_code/tests/gateway/test_media_extraction.py
Normal file
|
|
@ -0,0 +1,184 @@
|
|||
"""
|
||||
Tests for MEDIA tag extraction from tool results.
|
||||
|
||||
Verifies that MEDIA tags (e.g., from TTS tool) are only extracted from
|
||||
messages in the CURRENT turn, not from the full conversation history.
|
||||
This prevents voice messages from accumulating and being sent multiple
|
||||
times per reply. (Regression test for #160)
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import re
|
||||
|
||||
|
||||
def extract_media_tags_fixed(result_messages, history_len):
|
||||
"""
|
||||
Extract MEDIA tags from tool results, but ONLY from new messages
|
||||
(those added after history_len). This is the fixed behavior.
|
||||
|
||||
Args:
|
||||
result_messages: Full list of messages including history + new
|
||||
history_len: Length of history before this turn
|
||||
|
||||
Returns:
|
||||
Tuple of (media_tags list, has_voice_directive bool)
|
||||
"""
|
||||
media_tags = []
|
||||
has_voice_directive = False
|
||||
|
||||
# Only process new messages from this turn
|
||||
new_messages = result_messages[history_len:] if len(result_messages) > history_len else []
|
||||
|
||||
for msg in new_messages:
|
||||
if msg.get("role") == "tool" or msg.get("role") == "function":
|
||||
content = msg.get("content", "")
|
||||
if "MEDIA:" in content:
|
||||
for match in re.finditer(r'MEDIA:(\S+)', content):
|
||||
path = match.group(1).strip().rstrip('",}')
|
||||
if path:
|
||||
media_tags.append(f"MEDIA:{path}")
|
||||
if "[[audio_as_voice]]" in content:
|
||||
has_voice_directive = True
|
||||
|
||||
return media_tags, has_voice_directive
|
||||
|
||||
|
||||
def extract_media_tags_broken(result_messages):
|
||||
"""
|
||||
The BROKEN behavior: extract MEDIA tags from ALL messages including history.
|
||||
This causes TTS voice messages to accumulate and be re-sent on every reply.
|
||||
"""
|
||||
media_tags = []
|
||||
has_voice_directive = False
|
||||
|
||||
for msg in result_messages:
|
||||
if msg.get("role") == "tool" or msg.get("role") == "function":
|
||||
content = msg.get("content", "")
|
||||
if "MEDIA:" in content:
|
||||
for match in re.finditer(r'MEDIA:(\S+)', content):
|
||||
path = match.group(1).strip().rstrip('",}')
|
||||
if path:
|
||||
media_tags.append(f"MEDIA:{path}")
|
||||
if "[[audio_as_voice]]" in content:
|
||||
has_voice_directive = True
|
||||
|
||||
return media_tags, has_voice_directive
|
||||
|
||||
|
||||
class TestMediaExtraction:
|
||||
"""Tests for MEDIA tag extraction from tool results."""
|
||||
|
||||
def test_media_tags_not_extracted_from_history(self):
|
||||
"""MEDIA tags from previous turns should NOT be extracted again."""
|
||||
# Simulate conversation history with a TTS call from a previous turn
|
||||
history = [
|
||||
{"role": "user", "content": "Say hello as audio"},
|
||||
{"role": "assistant", "content": None, "tool_calls": [{"id": "1", "function": {"name": "text_to_speech"}}]},
|
||||
{"role": "tool", "tool_call_id": "1", "content": '{"success": true, "media_tag": "[[audio_as_voice]]\\nMEDIA:/path/to/audio1.ogg"}'},
|
||||
{"role": "assistant", "content": "I've said hello for you!"},
|
||||
]
|
||||
|
||||
# New turn: user asks a simple question
|
||||
new_messages = [
|
||||
{"role": "user", "content": "What time is it?"},
|
||||
{"role": "assistant", "content": "It's 3:30 AM."},
|
||||
]
|
||||
|
||||
all_messages = history + new_messages
|
||||
history_len = len(history)
|
||||
|
||||
# Fixed behavior: should extract NO media tags (none in new messages)
|
||||
tags, voice_directive = extract_media_tags_fixed(all_messages, history_len)
|
||||
assert tags == [], "Fixed extraction should not find tags in history"
|
||||
assert voice_directive is False
|
||||
|
||||
# Broken behavior: would incorrectly extract the old media tag
|
||||
broken_tags, broken_voice = extract_media_tags_broken(all_messages)
|
||||
assert len(broken_tags) == 1, "Broken extraction finds tags in history"
|
||||
assert "audio1.ogg" in broken_tags[0]
|
||||
|
||||
def test_media_tags_extracted_from_current_turn(self):
|
||||
"""MEDIA tags from the current turn SHOULD be extracted."""
|
||||
# History without TTS
|
||||
history = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
]
|
||||
|
||||
# New turn with TTS call
|
||||
new_messages = [
|
||||
{"role": "user", "content": "Say goodbye as audio"},
|
||||
{"role": "assistant", "content": None, "tool_calls": [{"id": "2", "function": {"name": "text_to_speech"}}]},
|
||||
{"role": "tool", "tool_call_id": "2", "content": '{"success": true, "media_tag": "[[audio_as_voice]]\\nMEDIA:/path/to/audio2.ogg"}'},
|
||||
{"role": "assistant", "content": "I've said goodbye!"},
|
||||
]
|
||||
|
||||
all_messages = history + new_messages
|
||||
history_len = len(history)
|
||||
|
||||
# Fixed behavior: should extract the new media tag
|
||||
tags, voice_directive = extract_media_tags_fixed(all_messages, history_len)
|
||||
assert len(tags) == 1, "Should extract media tag from current turn"
|
||||
assert "audio2.ogg" in tags[0]
|
||||
assert voice_directive is True
|
||||
|
||||
def test_multiple_tts_calls_in_history_not_accumulated(self):
|
||||
"""Multiple TTS calls in history should NOT accumulate in new responses."""
|
||||
# History with multiple TTS calls
|
||||
history = [
|
||||
{"role": "user", "content": "Say hello"},
|
||||
{"role": "tool", "tool_call_id": "1", "content": 'MEDIA:/audio/hello.ogg'},
|
||||
{"role": "assistant", "content": "Done!"},
|
||||
{"role": "user", "content": "Say goodbye"},
|
||||
{"role": "tool", "tool_call_id": "2", "content": 'MEDIA:/audio/goodbye.ogg'},
|
||||
{"role": "assistant", "content": "Done!"},
|
||||
{"role": "user", "content": "Say thanks"},
|
||||
{"role": "tool", "tool_call_id": "3", "content": 'MEDIA:/audio/thanks.ogg'},
|
||||
{"role": "assistant", "content": "Done!"},
|
||||
]
|
||||
|
||||
# New turn: no TTS
|
||||
new_messages = [
|
||||
{"role": "user", "content": "What time is it?"},
|
||||
{"role": "assistant", "content": "3 PM"},
|
||||
]
|
||||
|
||||
all_messages = history + new_messages
|
||||
history_len = len(history)
|
||||
|
||||
# Fixed: no tags
|
||||
tags, _ = extract_media_tags_fixed(all_messages, history_len)
|
||||
assert tags == [], "Should not accumulate tags from history"
|
||||
|
||||
# Broken: would have 3 tags (all the old ones)
|
||||
broken_tags, _ = extract_media_tags_broken(all_messages)
|
||||
assert len(broken_tags) == 3, "Broken version accumulates all history tags"
|
||||
|
||||
def test_deduplication_within_current_turn(self):
|
||||
"""Multiple MEDIA tags in current turn should be deduplicated."""
|
||||
history = []
|
||||
|
||||
# Current turn with multiple tool calls producing same media
|
||||
new_messages = [
|
||||
{"role": "user", "content": "Multiple TTS"},
|
||||
{"role": "tool", "tool_call_id": "1", "content": 'MEDIA:/audio/same.ogg'},
|
||||
{"role": "tool", "tool_call_id": "2", "content": 'MEDIA:/audio/same.ogg'}, # duplicate
|
||||
{"role": "tool", "tool_call_id": "3", "content": 'MEDIA:/audio/different.ogg'},
|
||||
{"role": "assistant", "content": "Done!"},
|
||||
]
|
||||
|
||||
all_messages = history + new_messages
|
||||
|
||||
tags, _ = extract_media_tags_fixed(all_messages, 0)
|
||||
# Even though same.ogg appears twice, deduplication happens after extraction
|
||||
# The extraction itself should get both, then caller deduplicates
|
||||
assert len(tags) == 3 # Raw extraction gets all
|
||||
|
||||
# Deduplication as done in the actual code:
|
||||
seen = set()
|
||||
unique = [t for t in tags if t not in seen and not seen.add(t)]
|
||||
assert len(unique) == 2 # After dedup: same.ogg and different.ogg
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
229
hermes_code/tests/gateway/test_mirror.py
Normal file
229
hermes_code/tests/gateway/test_mirror.py
Normal file
|
|
@ -0,0 +1,229 @@
|
|||
"""Tests for gateway/mirror.py — session mirroring."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import gateway.mirror as mirror_mod
|
||||
from gateway.mirror import (
|
||||
mirror_to_session,
|
||||
_find_session_id,
|
||||
_append_to_jsonl,
|
||||
)
|
||||
|
||||
|
||||
def _setup_sessions(tmp_path, sessions_data):
|
||||
"""Helper to write a fake sessions.json and patch module-level paths."""
|
||||
sessions_dir = tmp_path / "sessions"
|
||||
sessions_dir.mkdir(parents=True, exist_ok=True)
|
||||
index_file = sessions_dir / "sessions.json"
|
||||
index_file.write_text(json.dumps(sessions_data))
|
||||
return sessions_dir, index_file
|
||||
|
||||
|
||||
class TestFindSessionId:
|
||||
def test_finds_matching_session(self, tmp_path):
|
||||
sessions_dir, index_file = _setup_sessions(tmp_path, {
|
||||
"agent:main:telegram:dm": {
|
||||
"session_id": "sess_abc",
|
||||
"origin": {"platform": "telegram", "chat_id": "12345"},
|
||||
"updated_at": "2026-01-01T00:00:00",
|
||||
}
|
||||
})
|
||||
|
||||
with patch.object(mirror_mod, "_SESSIONS_DIR", sessions_dir), \
|
||||
patch.object(mirror_mod, "_SESSIONS_INDEX", index_file):
|
||||
result = _find_session_id("telegram", "12345")
|
||||
|
||||
assert result == "sess_abc"
|
||||
|
||||
def test_returns_most_recent(self, tmp_path):
|
||||
sessions_dir, index_file = _setup_sessions(tmp_path, {
|
||||
"old": {
|
||||
"session_id": "sess_old",
|
||||
"origin": {"platform": "telegram", "chat_id": "12345"},
|
||||
"updated_at": "2026-01-01T00:00:00",
|
||||
},
|
||||
"new": {
|
||||
"session_id": "sess_new",
|
||||
"origin": {"platform": "telegram", "chat_id": "12345"},
|
||||
"updated_at": "2026-02-01T00:00:00",
|
||||
},
|
||||
})
|
||||
|
||||
with patch.object(mirror_mod, "_SESSIONS_DIR", sessions_dir), \
|
||||
patch.object(mirror_mod, "_SESSIONS_INDEX", index_file):
|
||||
result = _find_session_id("telegram", "12345")
|
||||
|
||||
assert result == "sess_new"
|
||||
|
||||
def test_thread_id_disambiguates_same_chat(self, tmp_path):
|
||||
sessions_dir, index_file = _setup_sessions(tmp_path, {
|
||||
"topic_a": {
|
||||
"session_id": "sess_topic_a",
|
||||
"origin": {"platform": "telegram", "chat_id": "-1001", "thread_id": "10"},
|
||||
"updated_at": "2026-01-01T00:00:00",
|
||||
},
|
||||
"topic_b": {
|
||||
"session_id": "sess_topic_b",
|
||||
"origin": {"platform": "telegram", "chat_id": "-1001", "thread_id": "11"},
|
||||
"updated_at": "2026-02-01T00:00:00",
|
||||
},
|
||||
})
|
||||
|
||||
with patch.object(mirror_mod, "_SESSIONS_DIR", sessions_dir), \
|
||||
patch.object(mirror_mod, "_SESSIONS_INDEX", index_file):
|
||||
result = _find_session_id("telegram", "-1001", thread_id="10")
|
||||
|
||||
assert result == "sess_topic_a"
|
||||
|
||||
def test_no_match_returns_none(self, tmp_path):
|
||||
sessions_dir, index_file = _setup_sessions(tmp_path, {
|
||||
"sess": {
|
||||
"session_id": "sess_1",
|
||||
"origin": {"platform": "discord", "chat_id": "999"},
|
||||
"updated_at": "2026-01-01T00:00:00",
|
||||
}
|
||||
})
|
||||
|
||||
with patch.object(mirror_mod, "_SESSIONS_INDEX", index_file):
|
||||
result = _find_session_id("telegram", "12345")
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_missing_sessions_file(self, tmp_path):
|
||||
with patch.object(mirror_mod, "_SESSIONS_INDEX", tmp_path / "nope.json"):
|
||||
result = _find_session_id("telegram", "12345")
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_platform_case_insensitive(self, tmp_path):
|
||||
sessions_dir, index_file = _setup_sessions(tmp_path, {
|
||||
"s1": {
|
||||
"session_id": "sess_1",
|
||||
"origin": {"platform": "Telegram", "chat_id": "123"},
|
||||
"updated_at": "2026-01-01T00:00:00",
|
||||
}
|
||||
})
|
||||
|
||||
with patch.object(mirror_mod, "_SESSIONS_INDEX", index_file):
|
||||
result = _find_session_id("telegram", "123")
|
||||
|
||||
assert result == "sess_1"
|
||||
|
||||
|
||||
class TestAppendToJsonl:
|
||||
def test_appends_message(self, tmp_path):
|
||||
sessions_dir = tmp_path / "sessions"
|
||||
sessions_dir.mkdir()
|
||||
|
||||
with patch.object(mirror_mod, "_SESSIONS_DIR", sessions_dir):
|
||||
_append_to_jsonl("sess_1", {"role": "assistant", "content": "Hello"})
|
||||
|
||||
transcript = sessions_dir / "sess_1.jsonl"
|
||||
lines = transcript.read_text().strip().splitlines()
|
||||
assert len(lines) == 1
|
||||
msg = json.loads(lines[0])
|
||||
assert msg["role"] == "assistant"
|
||||
assert msg["content"] == "Hello"
|
||||
|
||||
def test_appends_multiple_messages(self, tmp_path):
|
||||
sessions_dir = tmp_path / "sessions"
|
||||
sessions_dir.mkdir()
|
||||
|
||||
with patch.object(mirror_mod, "_SESSIONS_DIR", sessions_dir):
|
||||
_append_to_jsonl("sess_1", {"role": "assistant", "content": "msg1"})
|
||||
_append_to_jsonl("sess_1", {"role": "assistant", "content": "msg2"})
|
||||
|
||||
transcript = sessions_dir / "sess_1.jsonl"
|
||||
lines = transcript.read_text().strip().splitlines()
|
||||
assert len(lines) == 2
|
||||
|
||||
|
||||
class TestMirrorToSession:
|
||||
def test_successful_mirror(self, tmp_path):
|
||||
sessions_dir, index_file = _setup_sessions(tmp_path, {
|
||||
"s1": {
|
||||
"session_id": "sess_abc",
|
||||
"origin": {"platform": "telegram", "chat_id": "12345"},
|
||||
"updated_at": "2026-01-01T00:00:00",
|
||||
}
|
||||
})
|
||||
|
||||
with patch.object(mirror_mod, "_SESSIONS_DIR", sessions_dir), \
|
||||
patch.object(mirror_mod, "_SESSIONS_INDEX", index_file), \
|
||||
patch("gateway.mirror._append_to_sqlite"):
|
||||
result = mirror_to_session("telegram", "12345", "Hello!", source_label="cli")
|
||||
|
||||
assert result is True
|
||||
|
||||
# Check JSONL was written
|
||||
transcript = sessions_dir / "sess_abc.jsonl"
|
||||
assert transcript.exists()
|
||||
msg = json.loads(transcript.read_text().strip())
|
||||
assert msg["content"] == "Hello!"
|
||||
assert msg["role"] == "assistant"
|
||||
assert msg["mirror"] is True
|
||||
assert msg["mirror_source"] == "cli"
|
||||
|
||||
def test_successful_mirror_uses_thread_id(self, tmp_path):
|
||||
sessions_dir, index_file = _setup_sessions(tmp_path, {
|
||||
"topic_a": {
|
||||
"session_id": "sess_topic_a",
|
||||
"origin": {"platform": "telegram", "chat_id": "-1001", "thread_id": "10"},
|
||||
"updated_at": "2026-01-01T00:00:00",
|
||||
},
|
||||
"topic_b": {
|
||||
"session_id": "sess_topic_b",
|
||||
"origin": {"platform": "telegram", "chat_id": "-1001", "thread_id": "11"},
|
||||
"updated_at": "2026-02-01T00:00:00",
|
||||
},
|
||||
})
|
||||
|
||||
with patch.object(mirror_mod, "_SESSIONS_DIR", sessions_dir), \
|
||||
patch.object(mirror_mod, "_SESSIONS_INDEX", index_file), \
|
||||
patch("gateway.mirror._append_to_sqlite"):
|
||||
result = mirror_to_session("telegram", "-1001", "Hello topic!", source_label="cron", thread_id="10")
|
||||
|
||||
assert result is True
|
||||
assert (sessions_dir / "sess_topic_a.jsonl").exists()
|
||||
assert not (sessions_dir / "sess_topic_b.jsonl").exists()
|
||||
|
||||
def test_no_matching_session(self, tmp_path):
|
||||
sessions_dir, index_file = _setup_sessions(tmp_path, {})
|
||||
|
||||
with patch.object(mirror_mod, "_SESSIONS_DIR", sessions_dir), \
|
||||
patch.object(mirror_mod, "_SESSIONS_INDEX", index_file):
|
||||
result = mirror_to_session("telegram", "99999", "Hello!")
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_error_returns_false(self, tmp_path):
|
||||
with patch("gateway.mirror._find_session_id", side_effect=Exception("boom")):
|
||||
result = mirror_to_session("telegram", "123", "msg")
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestAppendToSqlite:
|
||||
def test_connection_is_closed_after_use(self, tmp_path):
|
||||
"""Verify _append_to_sqlite closes the SessionDB connection."""
|
||||
from gateway.mirror import _append_to_sqlite
|
||||
mock_db = MagicMock()
|
||||
|
||||
with patch("hermes_state.SessionDB", return_value=mock_db):
|
||||
_append_to_sqlite("sess_1", {"role": "assistant", "content": "hello"})
|
||||
|
||||
mock_db.append_message.assert_called_once()
|
||||
mock_db.close.assert_called_once()
|
||||
|
||||
def test_connection_closed_even_on_error(self, tmp_path):
|
||||
"""Verify connection is closed even when append_message raises."""
|
||||
from gateway.mirror import _append_to_sqlite
|
||||
mock_db = MagicMock()
|
||||
mock_db.append_message.side_effect = Exception("db error")
|
||||
|
||||
with patch("hermes_state.SessionDB", return_value=mock_db):
|
||||
_append_to_sqlite("sess_1", {"role": "assistant", "content": "hello"})
|
||||
|
||||
mock_db.close.assert_called_once()
|
||||
356
hermes_code/tests/gateway/test_pairing.py
Normal file
356
hermes_code/tests/gateway/test_pairing.py
Normal file
|
|
@ -0,0 +1,356 @@
|
|||
"""Tests for gateway/pairing.py — DM pairing security system."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from gateway.pairing import (
|
||||
PairingStore,
|
||||
ALPHABET,
|
||||
CODE_LENGTH,
|
||||
CODE_TTL_SECONDS,
|
||||
RATE_LIMIT_SECONDS,
|
||||
MAX_PENDING_PER_PLATFORM,
|
||||
MAX_FAILED_ATTEMPTS,
|
||||
LOCKOUT_SECONDS,
|
||||
_secure_write,
|
||||
)
|
||||
|
||||
|
||||
def _make_store(tmp_path):
|
||||
"""Create a PairingStore with PAIRING_DIR pointed to tmp_path."""
|
||||
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
|
||||
return PairingStore()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _secure_write
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSecureWrite:
|
||||
def test_creates_parent_dirs(self, tmp_path):
|
||||
target = tmp_path / "sub" / "dir" / "file.json"
|
||||
_secure_write(target, '{"hello": "world"}')
|
||||
assert target.exists()
|
||||
assert json.loads(target.read_text()) == {"hello": "world"}
|
||||
|
||||
def test_sets_file_permissions(self, tmp_path):
|
||||
target = tmp_path / "secret.json"
|
||||
_secure_write(target, "data")
|
||||
mode = oct(target.stat().st_mode & 0o777)
|
||||
assert mode == "0o600"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Code generation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCodeGeneration:
|
||||
def test_code_format(self, tmp_path):
|
||||
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
|
||||
store = PairingStore()
|
||||
code = store.generate_code("telegram", "user1", "Alice")
|
||||
assert isinstance(code, str) and len(code) == CODE_LENGTH
|
||||
assert len(code) == CODE_LENGTH
|
||||
assert all(c in ALPHABET for c in code)
|
||||
|
||||
def test_code_uniqueness(self, tmp_path):
|
||||
"""Multiple codes for different users should be distinct."""
|
||||
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
|
||||
store = PairingStore()
|
||||
codes = set()
|
||||
for i in range(3):
|
||||
code = store.generate_code("telegram", f"user{i}")
|
||||
assert isinstance(code, str) and len(code) == CODE_LENGTH
|
||||
codes.add(code)
|
||||
assert len(codes) == 3
|
||||
|
||||
def test_stores_pending_entry(self, tmp_path):
|
||||
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
|
||||
store = PairingStore()
|
||||
code = store.generate_code("telegram", "user1", "Alice")
|
||||
pending = store.list_pending("telegram")
|
||||
assert len(pending) == 1
|
||||
assert pending[0]["code"] == code
|
||||
assert pending[0]["user_id"] == "user1"
|
||||
assert pending[0]["user_name"] == "Alice"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Rate limiting
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRateLimiting:
|
||||
def test_same_user_rate_limited(self, tmp_path):
|
||||
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
|
||||
store = PairingStore()
|
||||
code1 = store.generate_code("telegram", "user1")
|
||||
code2 = store.generate_code("telegram", "user1")
|
||||
assert isinstance(code1, str) and len(code1) == CODE_LENGTH
|
||||
assert code2 is None # rate limited
|
||||
|
||||
def test_different_users_not_rate_limited(self, tmp_path):
|
||||
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
|
||||
store = PairingStore()
|
||||
code1 = store.generate_code("telegram", "user1")
|
||||
code2 = store.generate_code("telegram", "user2")
|
||||
assert isinstance(code1, str) and len(code1) == CODE_LENGTH
|
||||
assert isinstance(code2, str) and len(code2) == CODE_LENGTH
|
||||
|
||||
def test_rate_limit_expires(self, tmp_path):
|
||||
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
|
||||
store = PairingStore()
|
||||
code1 = store.generate_code("telegram", "user1")
|
||||
assert isinstance(code1, str) and len(code1) == CODE_LENGTH
|
||||
|
||||
# Simulate rate limit expiry
|
||||
limits = store._load_json(store._rate_limit_path())
|
||||
limits["telegram:user1"] = time.time() - RATE_LIMIT_SECONDS - 1
|
||||
store._save_json(store._rate_limit_path(), limits)
|
||||
|
||||
code2 = store.generate_code("telegram", "user1")
|
||||
assert isinstance(code2, str) and len(code2) == CODE_LENGTH
|
||||
assert code2 != code1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Max pending limit
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMaxPending:
|
||||
def test_max_pending_per_platform(self, tmp_path):
|
||||
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
|
||||
store = PairingStore()
|
||||
codes = []
|
||||
for i in range(MAX_PENDING_PER_PLATFORM + 1):
|
||||
code = store.generate_code("telegram", f"user{i}")
|
||||
codes.append(code)
|
||||
|
||||
# First MAX_PENDING_PER_PLATFORM should succeed
|
||||
assert all(isinstance(c, str) and len(c) == CODE_LENGTH for c in codes[:MAX_PENDING_PER_PLATFORM])
|
||||
# Next one should be blocked
|
||||
assert codes[MAX_PENDING_PER_PLATFORM] is None
|
||||
|
||||
def test_different_platforms_independent(self, tmp_path):
|
||||
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
|
||||
store = PairingStore()
|
||||
for i in range(MAX_PENDING_PER_PLATFORM):
|
||||
store.generate_code("telegram", f"user{i}")
|
||||
# Different platform should still work
|
||||
code = store.generate_code("discord", "user0")
|
||||
assert isinstance(code, str) and len(code) == CODE_LENGTH
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Approval flow
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestApprovalFlow:
|
||||
def test_approve_valid_code(self, tmp_path):
|
||||
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
|
||||
store = PairingStore()
|
||||
code = store.generate_code("telegram", "user1", "Alice")
|
||||
result = store.approve_code("telegram", code)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "user_id" in result
|
||||
assert "user_name" in result
|
||||
assert result["user_id"] == "user1"
|
||||
assert result["user_name"] == "Alice"
|
||||
|
||||
def test_approved_user_is_approved(self, tmp_path):
|
||||
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
|
||||
store = PairingStore()
|
||||
code = store.generate_code("telegram", "user1", "Alice")
|
||||
store.approve_code("telegram", code)
|
||||
assert store.is_approved("telegram", "user1") is True
|
||||
|
||||
def test_unapproved_user_not_approved(self, tmp_path):
|
||||
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
|
||||
store = PairingStore()
|
||||
assert store.is_approved("telegram", "nonexistent") is False
|
||||
|
||||
def test_approve_removes_from_pending(self, tmp_path):
|
||||
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
|
||||
store = PairingStore()
|
||||
code = store.generate_code("telegram", "user1")
|
||||
store.approve_code("telegram", code)
|
||||
pending = store.list_pending("telegram")
|
||||
assert len(pending) == 0
|
||||
|
||||
def test_approve_case_insensitive(self, tmp_path):
|
||||
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
|
||||
store = PairingStore()
|
||||
code = store.generate_code("telegram", "user1", "Alice")
|
||||
result = store.approve_code("telegram", code.lower())
|
||||
assert isinstance(result, dict)
|
||||
assert result["user_id"] == "user1"
|
||||
assert result["user_name"] == "Alice"
|
||||
|
||||
def test_approve_strips_whitespace(self, tmp_path):
|
||||
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
|
||||
store = PairingStore()
|
||||
code = store.generate_code("telegram", "user1", "Alice")
|
||||
result = store.approve_code("telegram", f" {code} ")
|
||||
assert isinstance(result, dict)
|
||||
assert result["user_id"] == "user1"
|
||||
assert result["user_name"] == "Alice"
|
||||
|
||||
def test_invalid_code_returns_none(self, tmp_path):
|
||||
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
|
||||
store = PairingStore()
|
||||
result = store.approve_code("telegram", "INVALIDCODE")
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Lockout after failed attempts
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLockout:
|
||||
def test_lockout_after_max_failures(self, tmp_path):
|
||||
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
|
||||
store = PairingStore()
|
||||
# Generate a valid code so platform has data
|
||||
store.generate_code("telegram", "user1")
|
||||
|
||||
# Exhaust failed attempts
|
||||
for _ in range(MAX_FAILED_ATTEMPTS):
|
||||
store.approve_code("telegram", "WRONGCODE")
|
||||
|
||||
# Platform should now be locked out — can't generate new codes
|
||||
assert store._is_locked_out("telegram") is True
|
||||
|
||||
def test_lockout_blocks_code_generation(self, tmp_path):
|
||||
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
|
||||
store = PairingStore()
|
||||
for _ in range(MAX_FAILED_ATTEMPTS):
|
||||
store.approve_code("telegram", "WRONG")
|
||||
|
||||
code = store.generate_code("telegram", "newuser")
|
||||
assert code is None
|
||||
|
||||
def test_lockout_expires(self, tmp_path):
|
||||
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
|
||||
store = PairingStore()
|
||||
for _ in range(MAX_FAILED_ATTEMPTS):
|
||||
store.approve_code("telegram", "WRONG")
|
||||
|
||||
# Simulate lockout expiry
|
||||
limits = store._load_json(store._rate_limit_path())
|
||||
lockout_key = "_lockout:telegram"
|
||||
limits[lockout_key] = time.time() - 1 # expired
|
||||
store._save_json(store._rate_limit_path(), limits)
|
||||
|
||||
assert store._is_locked_out("telegram") is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Code expiry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCodeExpiry:
|
||||
def test_expired_codes_cleaned_up(self, tmp_path):
|
||||
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
|
||||
store = PairingStore()
|
||||
code = store.generate_code("telegram", "user1")
|
||||
|
||||
# Manually expire the code
|
||||
pending = store._load_json(store._pending_path("telegram"))
|
||||
pending[code]["created_at"] = time.time() - CODE_TTL_SECONDS - 1
|
||||
store._save_json(store._pending_path("telegram"), pending)
|
||||
|
||||
# Cleanup happens on next operation
|
||||
remaining = store.list_pending("telegram")
|
||||
assert len(remaining) == 0
|
||||
|
||||
def test_expired_code_cannot_be_approved(self, tmp_path):
|
||||
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
|
||||
store = PairingStore()
|
||||
code = store.generate_code("telegram", "user1")
|
||||
|
||||
# Expire it
|
||||
pending = store._load_json(store._pending_path("telegram"))
|
||||
pending[code]["created_at"] = time.time() - CODE_TTL_SECONDS - 1
|
||||
store._save_json(store._pending_path("telegram"), pending)
|
||||
|
||||
result = store.approve_code("telegram", code)
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Revoke
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRevoke:
|
||||
def test_revoke_approved_user(self, tmp_path):
|
||||
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
|
||||
store = PairingStore()
|
||||
code = store.generate_code("telegram", "user1", "Alice")
|
||||
store.approve_code("telegram", code)
|
||||
assert store.is_approved("telegram", "user1") is True
|
||||
|
||||
revoked = store.revoke("telegram", "user1")
|
||||
assert revoked is True
|
||||
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
|
||||
assert store.is_approved("telegram", "user1") is False
|
||||
|
||||
def test_revoke_nonexistent_returns_false(self, tmp_path):
|
||||
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
|
||||
store = PairingStore()
|
||||
assert store.revoke("telegram", "nobody") is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# List & clear
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestListAndClear:
|
||||
def test_list_approved(self, tmp_path):
|
||||
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
|
||||
store = PairingStore()
|
||||
code = store.generate_code("telegram", "user1", "Alice")
|
||||
store.approve_code("telegram", code)
|
||||
approved = store.list_approved("telegram")
|
||||
assert len(approved) == 1
|
||||
assert approved[0]["user_id"] == "user1"
|
||||
assert approved[0]["platform"] == "telegram"
|
||||
|
||||
def test_list_approved_all_platforms(self, tmp_path):
|
||||
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
|
||||
store = PairingStore()
|
||||
c1 = store.generate_code("telegram", "user1")
|
||||
store.approve_code("telegram", c1)
|
||||
c2 = store.generate_code("discord", "user2")
|
||||
store.approve_code("discord", c2)
|
||||
approved = store.list_approved()
|
||||
assert len(approved) == 2
|
||||
|
||||
def test_clear_pending(self, tmp_path):
|
||||
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
|
||||
store = PairingStore()
|
||||
store.generate_code("telegram", "user1")
|
||||
store.generate_code("telegram", "user2")
|
||||
count = store.clear_pending("telegram")
|
||||
remaining = store.list_pending("telegram")
|
||||
assert count == 2
|
||||
assert len(remaining) == 0
|
||||
|
||||
def test_clear_pending_all_platforms(self, tmp_path):
|
||||
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
|
||||
store = PairingStore()
|
||||
store.generate_code("telegram", "user1")
|
||||
store.generate_code("discord", "user2")
|
||||
count = store.clear_pending()
|
||||
assert count == 2
|
||||
156
hermes_code/tests/gateway/test_pii_redaction.py
Normal file
156
hermes_code/tests/gateway/test_pii_redaction.py
Normal file
|
|
@ -0,0 +1,156 @@
|
|||
"""Tests for PII redaction in gateway session context prompts."""
|
||||
|
||||
from gateway.session import (
|
||||
SessionContext,
|
||||
SessionSource,
|
||||
build_session_context_prompt,
|
||||
_hash_id,
|
||||
_hash_sender_id,
|
||||
_hash_chat_id,
|
||||
_looks_like_phone,
|
||||
)
|
||||
from gateway.config import Platform, HomeChannel
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Low-level helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestHashHelpers:
|
||||
def test_hash_id_deterministic(self):
|
||||
assert _hash_id("12345") == _hash_id("12345")
|
||||
|
||||
def test_hash_id_12_hex_chars(self):
|
||||
h = _hash_id("user-abc")
|
||||
assert len(h) == 12
|
||||
assert all(c in "0123456789abcdef" for c in h)
|
||||
|
||||
def test_hash_sender_id_prefix(self):
|
||||
assert _hash_sender_id("12345").startswith("user_")
|
||||
assert len(_hash_sender_id("12345")) == 17 # "user_" + 12
|
||||
|
||||
def test_hash_chat_id_preserves_prefix(self):
|
||||
result = _hash_chat_id("telegram:12345")
|
||||
assert result.startswith("telegram:")
|
||||
assert "12345" not in result
|
||||
|
||||
def test_hash_chat_id_no_prefix(self):
|
||||
result = _hash_chat_id("12345")
|
||||
assert len(result) == 12
|
||||
assert "12345" not in result
|
||||
|
||||
def test_looks_like_phone(self):
|
||||
assert _looks_like_phone("+15551234567")
|
||||
assert _looks_like_phone("15551234567")
|
||||
assert _looks_like_phone("+1-555-123-4567")
|
||||
assert not _looks_like_phone("alice")
|
||||
assert not _looks_like_phone("user-123")
|
||||
assert not _looks_like_phone("")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration: build_session_context_prompt
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_context(
|
||||
user_id="user-123",
|
||||
user_name=None,
|
||||
chat_id="telegram:99999",
|
||||
platform=Platform.TELEGRAM,
|
||||
home_channels=None,
|
||||
):
|
||||
source = SessionSource(
|
||||
platform=platform,
|
||||
chat_id=chat_id,
|
||||
chat_type="dm",
|
||||
user_id=user_id,
|
||||
user_name=user_name,
|
||||
)
|
||||
return SessionContext(
|
||||
source=source,
|
||||
connected_platforms=[platform],
|
||||
home_channels=home_channels or {},
|
||||
)
|
||||
|
||||
|
||||
class TestBuildSessionContextPromptRedaction:
|
||||
def test_no_redaction_by_default(self):
|
||||
ctx = _make_context(user_id="user-123")
|
||||
prompt = build_session_context_prompt(ctx)
|
||||
assert "user-123" in prompt
|
||||
|
||||
def test_user_id_hashed_when_redact_pii(self):
|
||||
ctx = _make_context(user_id="user-123")
|
||||
prompt = build_session_context_prompt(ctx, redact_pii=True)
|
||||
assert "user-123" not in prompt
|
||||
assert "user_" in prompt # hashed ID present
|
||||
|
||||
def test_user_name_not_redacted(self):
|
||||
ctx = _make_context(user_id="user-123", user_name="Alice")
|
||||
prompt = build_session_context_prompt(ctx, redact_pii=True)
|
||||
assert "Alice" in prompt
|
||||
# user_id should not appear when user_name is present (name takes priority)
|
||||
assert "user-123" not in prompt
|
||||
|
||||
def test_home_channel_id_hashed(self):
|
||||
hc = {
|
||||
Platform.TELEGRAM: HomeChannel(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="telegram:99999",
|
||||
name="Home Chat",
|
||||
)
|
||||
}
|
||||
ctx = _make_context(home_channels=hc)
|
||||
prompt = build_session_context_prompt(ctx, redact_pii=True)
|
||||
assert "99999" not in prompt
|
||||
assert "telegram:" in prompt # prefix preserved
|
||||
assert "Home Chat" in prompt # name not redacted
|
||||
|
||||
def test_home_channel_id_preserved_without_redaction(self):
|
||||
hc = {
|
||||
Platform.TELEGRAM: HomeChannel(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="telegram:99999",
|
||||
name="Home Chat",
|
||||
)
|
||||
}
|
||||
ctx = _make_context(home_channels=hc)
|
||||
prompt = build_session_context_prompt(ctx, redact_pii=False)
|
||||
assert "99999" in prompt
|
||||
|
||||
def test_redaction_is_deterministic(self):
|
||||
ctx = _make_context(user_id="+15551234567")
|
||||
prompt1 = build_session_context_prompt(ctx, redact_pii=True)
|
||||
prompt2 = build_session_context_prompt(ctx, redact_pii=True)
|
||||
assert prompt1 == prompt2
|
||||
|
||||
def test_different_ids_produce_different_hashes(self):
|
||||
ctx1 = _make_context(user_id="user-A")
|
||||
ctx2 = _make_context(user_id="user-B")
|
||||
p1 = build_session_context_prompt(ctx1, redact_pii=True)
|
||||
p2 = build_session_context_prompt(ctx2, redact_pii=True)
|
||||
assert p1 != p2
|
||||
|
||||
def test_discord_ids_not_redacted_even_with_flag(self):
|
||||
"""Discord needs real IDs for <@user_id> mentions."""
|
||||
ctx = _make_context(user_id="123456789", platform=Platform.DISCORD)
|
||||
prompt = build_session_context_prompt(ctx, redact_pii=True)
|
||||
assert "123456789" in prompt
|
||||
|
||||
def test_whatsapp_ids_redacted(self):
|
||||
ctx = _make_context(user_id="+15551234567", platform=Platform.WHATSAPP)
|
||||
prompt = build_session_context_prompt(ctx, redact_pii=True)
|
||||
assert "+15551234567" not in prompt
|
||||
assert "user_" in prompt
|
||||
|
||||
def test_signal_ids_redacted(self):
|
||||
ctx = _make_context(user_id="+15551234567", platform=Platform.SIGNAL)
|
||||
prompt = build_session_context_prompt(ctx, redact_pii=True)
|
||||
assert "+15551234567" not in prompt
|
||||
assert "user_" in prompt
|
||||
|
||||
def test_slack_ids_not_redacted(self):
|
||||
"""Slack may need IDs for mentions too."""
|
||||
ctx = _make_context(user_id="U12345ABC", platform=Platform.SLACK)
|
||||
prompt = build_session_context_prompt(ctx, redact_pii=True)
|
||||
assert "U12345ABC" in prompt
|
||||
129
hermes_code/tests/gateway/test_plan_command.py
Normal file
129
hermes_code/tests/gateway/test_plan_command.py
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
"""Tests for the /plan gateway slash command."""
|
||||
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from agent.skill_commands import scan_skill_commands
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import MessageEvent
|
||||
from gateway.session import SessionEntry, SessionSource
|
||||
|
||||
|
||||
def _make_runner():
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.config = GatewayConfig(
|
||||
platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")}
|
||||
)
|
||||
runner.adapters = {}
|
||||
runner._voice_mode = {}
|
||||
runner.hooks = SimpleNamespace(emit=AsyncMock(), loaded_hooks=False)
|
||||
runner.session_store = MagicMock()
|
||||
runner.session_store.get_or_create_session.return_value = SessionEntry(
|
||||
session_key="agent:main:telegram:dm:c1:u1",
|
||||
session_id="sess-1",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_type="dm",
|
||||
)
|
||||
runner.session_store.load_transcript.return_value = []
|
||||
runner.session_store.has_any_sessions.return_value = True
|
||||
runner.session_store.append_to_transcript = MagicMock()
|
||||
runner.session_store.rewrite_transcript = MagicMock()
|
||||
runner._running_agents = {}
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._session_db = None
|
||||
runner._reasoning_config = None
|
||||
runner._provider_routing = {}
|
||||
runner._fallback_model = None
|
||||
runner._show_reasoning = False
|
||||
runner._is_user_authorized = lambda _source: True
|
||||
runner._set_session_env = lambda _context: None
|
||||
runner._run_agent = AsyncMock(
|
||||
return_value={
|
||||
"final_response": "planned",
|
||||
"messages": [],
|
||||
"tools": [],
|
||||
"history_offset": 0,
|
||||
"last_prompt_tokens": 0,
|
||||
}
|
||||
)
|
||||
return runner
|
||||
|
||||
|
||||
def _make_event(text="/plan"):
|
||||
return MessageEvent(
|
||||
text=text,
|
||||
source=SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
user_id="u1",
|
||||
chat_id="c1",
|
||||
user_name="tester",
|
||||
chat_type="dm",
|
||||
),
|
||||
message_id="m1",
|
||||
)
|
||||
|
||||
|
||||
def _make_plan_skill(skills_dir):
|
||||
skill_dir = skills_dir / "plan"
|
||||
skill_dir.mkdir(parents=True, exist_ok=True)
|
||||
(skill_dir / "SKILL.md").write_text(
|
||||
"""---
|
||||
name: plan
|
||||
description: Plan mode skill.
|
||||
---
|
||||
|
||||
# Plan
|
||||
|
||||
Use the current conversation context when no explicit instruction is provided.
|
||||
Save plans under the active workspace's .hermes/plans directory.
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
class TestGatewayPlanCommand:
|
||||
@pytest.mark.asyncio
|
||||
async def test_plan_command_loads_skill_and_runs_agent(self, monkeypatch, tmp_path):
|
||||
import gateway.run as gateway_run
|
||||
|
||||
runner = _make_runner()
|
||||
event = _make_event("/plan Add OAuth login")
|
||||
|
||||
monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})
|
||||
monkeypatch.setattr(
|
||||
"agent.model_metadata.get_model_context_length",
|
||||
lambda *_args, **_kwargs: 100_000,
|
||||
)
|
||||
|
||||
with patch("tools.skills_tool.SKILLS_DIR", tmp_path):
|
||||
_make_plan_skill(tmp_path)
|
||||
scan_skill_commands()
|
||||
result = await runner._handle_message(event)
|
||||
|
||||
assert result == "planned"
|
||||
forwarded = runner._run_agent.call_args.kwargs["message"]
|
||||
assert "Plan mode skill" in forwarded
|
||||
assert "Add OAuth login" in forwarded
|
||||
assert ".hermes/plans" in forwarded
|
||||
assert str(tmp_path / "plans") not in forwarded
|
||||
assert "active workspace/backend cwd" in forwarded
|
||||
assert "Runtime note:" in forwarded
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plan_command_appears_in_help_output_via_skill_listing(self, tmp_path):
|
||||
runner = _make_runner()
|
||||
event = _make_event("/help")
|
||||
|
||||
with patch("tools.skills_tool.SKILLS_DIR", tmp_path):
|
||||
_make_plan_skill(tmp_path)
|
||||
scan_skill_commands()
|
||||
result = await runner._handle_help_command(event)
|
||||
|
||||
assert "/plan" in result
|
||||
412
hermes_code/tests/gateway/test_platform_base.py
Normal file
412
hermes_code/tests/gateway/test_platform_base.py
Normal file
|
|
@ -0,0 +1,412 @@
|
|||
"""Tests for gateway/platforms/base.py — MessageEvent, media extraction, message truncation."""
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
GATEWAY_SECRET_CAPTURE_UNSUPPORTED_MESSAGE,
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
)
|
||||
|
||||
|
||||
class TestSecretCaptureGuidance:
|
||||
def test_gateway_secret_capture_message_points_to_local_setup(self):
|
||||
message = GATEWAY_SECRET_CAPTURE_UNSUPPORTED_MESSAGE
|
||||
assert "local cli" in message.lower()
|
||||
assert "~/.hermes/.env" in message
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MessageEvent — command parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMessageEventIsCommand:
|
||||
def test_slash_command(self):
|
||||
event = MessageEvent(text="/new")
|
||||
assert event.is_command() is True
|
||||
|
||||
def test_regular_text(self):
|
||||
event = MessageEvent(text="hello world")
|
||||
assert event.is_command() is False
|
||||
|
||||
def test_empty_text(self):
|
||||
event = MessageEvent(text="")
|
||||
assert event.is_command() is False
|
||||
|
||||
def test_slash_only(self):
|
||||
event = MessageEvent(text="/")
|
||||
assert event.is_command() is True
|
||||
|
||||
|
||||
class TestMessageEventGetCommand:
|
||||
def test_simple_command(self):
|
||||
event = MessageEvent(text="/new")
|
||||
assert event.get_command() == "new"
|
||||
|
||||
def test_command_with_args(self):
|
||||
event = MessageEvent(text="/reset session")
|
||||
assert event.get_command() == "reset"
|
||||
|
||||
def test_not_a_command(self):
|
||||
event = MessageEvent(text="hello")
|
||||
assert event.get_command() is None
|
||||
|
||||
def test_command_is_lowercased(self):
|
||||
event = MessageEvent(text="/HELP")
|
||||
assert event.get_command() == "help"
|
||||
|
||||
def test_slash_only_returns_empty(self):
|
||||
event = MessageEvent(text="/")
|
||||
assert event.get_command() == ""
|
||||
|
||||
|
||||
class TestMessageEventGetCommandArgs:
|
||||
def test_command_with_args(self):
|
||||
event = MessageEvent(text="/new session id 123")
|
||||
assert event.get_command_args() == "session id 123"
|
||||
|
||||
def test_command_without_args(self):
|
||||
event = MessageEvent(text="/new")
|
||||
assert event.get_command_args() == ""
|
||||
|
||||
def test_not_a_command_returns_full_text(self):
|
||||
event = MessageEvent(text="hello world")
|
||||
assert event.get_command_args() == "hello world"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# extract_images
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExtractImages:
|
||||
def test_no_images(self):
|
||||
images, cleaned = BasePlatformAdapter.extract_images("Just regular text.")
|
||||
assert images == []
|
||||
assert cleaned == "Just regular text."
|
||||
|
||||
def test_markdown_image_with_image_ext(self):
|
||||
content = "Here is a photo: "
|
||||
images, cleaned = BasePlatformAdapter.extract_images(content)
|
||||
assert len(images) == 1
|
||||
assert images[0][0] == "https://example.com/cat.png"
|
||||
assert images[0][1] == "cat"
|
||||
assert "![cat]" not in cleaned
|
||||
|
||||
def test_markdown_image_jpg(self):
|
||||
content = ""
|
||||
images, _ = BasePlatformAdapter.extract_images(content)
|
||||
assert len(images) == 1
|
||||
assert images[0][0] == "https://example.com/photo.jpg"
|
||||
assert images[0][1] == "photo"
|
||||
|
||||
def test_markdown_image_jpeg(self):
|
||||
content = ""
|
||||
images, _ = BasePlatformAdapter.extract_images(content)
|
||||
assert len(images) == 1
|
||||
assert images[0][0] == "https://example.com/photo.jpeg"
|
||||
assert images[0][1] == ""
|
||||
|
||||
def test_markdown_image_gif(self):
|
||||
content = ""
|
||||
images, _ = BasePlatformAdapter.extract_images(content)
|
||||
assert len(images) == 1
|
||||
assert images[0][0] == "https://example.com/anim.gif"
|
||||
assert images[0][1] == "anim"
|
||||
|
||||
def test_markdown_image_webp(self):
|
||||
content = ""
|
||||
images, _ = BasePlatformAdapter.extract_images(content)
|
||||
assert len(images) == 1
|
||||
assert images[0][0] == "https://example.com/img.webp"
|
||||
assert images[0][1] == ""
|
||||
|
||||
def test_fal_media_cdn(self):
|
||||
content = ""
|
||||
images, _ = BasePlatformAdapter.extract_images(content)
|
||||
assert len(images) == 1
|
||||
assert images[0][0] == "https://fal.media/files/abc123/output.png"
|
||||
assert images[0][1] == "gen"
|
||||
|
||||
def test_fal_cdn_url(self):
|
||||
content = ""
|
||||
images, _ = BasePlatformAdapter.extract_images(content)
|
||||
assert len(images) == 1
|
||||
assert images[0][0] == "https://fal-cdn.example.com/result"
|
||||
assert images[0][1] == ""
|
||||
|
||||
def test_replicate_delivery(self):
|
||||
content = ""
|
||||
images, _ = BasePlatformAdapter.extract_images(content)
|
||||
assert len(images) == 1
|
||||
assert images[0][0] == "https://replicate.delivery/pbxt/abc/output"
|
||||
assert images[0][1] == ""
|
||||
|
||||
def test_non_image_ext_not_extracted(self):
|
||||
"""Markdown image with non-image extension should not be extracted."""
|
||||
content = ""
|
||||
images, cleaned = BasePlatformAdapter.extract_images(content)
|
||||
assert images == []
|
||||
assert "![doc]" in cleaned # Should be preserved
|
||||
|
||||
def test_html_img_tag(self):
|
||||
content = 'Check this: <img src="https://example.com/photo.png">'
|
||||
images, cleaned = BasePlatformAdapter.extract_images(content)
|
||||
assert len(images) == 1
|
||||
assert images[0][0] == "https://example.com/photo.png"
|
||||
assert images[0][1] == "" # HTML images have no alt text
|
||||
assert "<img" not in cleaned
|
||||
|
||||
def test_html_img_self_closing(self):
|
||||
content = '<img src="https://example.com/photo.png"/>'
|
||||
images, _ = BasePlatformAdapter.extract_images(content)
|
||||
assert len(images) == 1
|
||||
assert images[0][0] == "https://example.com/photo.png"
|
||||
assert images[0][1] == ""
|
||||
|
||||
def test_html_img_with_closing_tag(self):
|
||||
content = '<img src="https://example.com/photo.png"></img>'
|
||||
images, _ = BasePlatformAdapter.extract_images(content)
|
||||
assert len(images) == 1
|
||||
assert images[0][0] == "https://example.com/photo.png"
|
||||
assert images[0][1] == ""
|
||||
|
||||
def test_multiple_images(self):
|
||||
content = "\n"
|
||||
images, cleaned = BasePlatformAdapter.extract_images(content)
|
||||
assert len(images) == 2
|
||||
assert "![a]" not in cleaned
|
||||
assert "![b]" not in cleaned
|
||||
|
||||
def test_mixed_markdown_and_html(self):
|
||||
content = '\n<img src="https://example.com/dog.jpg">'
|
||||
images, _ = BasePlatformAdapter.extract_images(content)
|
||||
assert len(images) == 2
|
||||
|
||||
def test_cleaned_content_trims_excess_newlines(self):
|
||||
content = "Before\n\n\n\n\n\nAfter"
|
||||
_, cleaned = BasePlatformAdapter.extract_images(content)
|
||||
assert "\n\n\n" not in cleaned
|
||||
|
||||
def test_non_http_url_not_matched(self):
|
||||
content = ""
|
||||
images, _ = BasePlatformAdapter.extract_images(content)
|
||||
assert images == []
|
||||
|
||||
def test_non_image_link_preserved_when_mixed_with_images(self):
|
||||
"""Regression: non-image markdown links must not be silently removed
|
||||
when the response also contains real images."""
|
||||
content = (
|
||||
"Here is the image: \n"
|
||||
"And a doc: "
|
||||
)
|
||||
images, cleaned = BasePlatformAdapter.extract_images(content)
|
||||
assert len(images) == 1
|
||||
assert images[0][0] == "https://fal.media/cat.png"
|
||||
# The PDF link must survive in cleaned content
|
||||
assert "" in cleaned
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# extract_media
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExtractMedia:
|
||||
def test_no_media(self):
|
||||
media, cleaned = BasePlatformAdapter.extract_media("Just text.")
|
||||
assert media == []
|
||||
assert cleaned == "Just text."
|
||||
|
||||
def test_single_media_tag(self):
|
||||
content = "MEDIA:/path/to/audio.ogg"
|
||||
media, cleaned = BasePlatformAdapter.extract_media(content)
|
||||
assert len(media) == 1
|
||||
assert media[0][0] == "/path/to/audio.ogg"
|
||||
assert media[0][1] is False # no voice tag
|
||||
|
||||
def test_media_with_voice_directive(self):
|
||||
content = "[[audio_as_voice]]\nMEDIA:/path/to/voice.ogg"
|
||||
media, cleaned = BasePlatformAdapter.extract_media(content)
|
||||
assert len(media) == 1
|
||||
assert media[0][0] == "/path/to/voice.ogg"
|
||||
assert media[0][1] is True # voice tag present
|
||||
|
||||
def test_multiple_media_tags(self):
|
||||
content = "MEDIA:/a.ogg\nMEDIA:/b.ogg"
|
||||
media, _ = BasePlatformAdapter.extract_media(content)
|
||||
assert len(media) == 2
|
||||
|
||||
def test_voice_directive_removed_from_content(self):
|
||||
content = "[[audio_as_voice]]\nSome text\nMEDIA:/voice.ogg"
|
||||
_, cleaned = BasePlatformAdapter.extract_media(content)
|
||||
assert "[[audio_as_voice]]" not in cleaned
|
||||
assert "MEDIA:" not in cleaned
|
||||
assert "Some text" in cleaned
|
||||
|
||||
def test_media_with_text_before(self):
|
||||
content = "Here is your audio:\nMEDIA:/output.ogg"
|
||||
media, cleaned = BasePlatformAdapter.extract_media(content)
|
||||
assert len(media) == 1
|
||||
assert "Here is your audio" in cleaned
|
||||
|
||||
def test_cleaned_content_trims_excess_newlines(self):
|
||||
content = "Before\n\nMEDIA:/audio.ogg\n\n\n\nAfter"
|
||||
_, cleaned = BasePlatformAdapter.extract_media(content)
|
||||
assert "\n\n\n" not in cleaned
|
||||
|
||||
def test_media_tag_allows_optional_whitespace_after_colon(self):
|
||||
content = "MEDIA: /path/to/audio.ogg"
|
||||
media, cleaned = BasePlatformAdapter.extract_media(content)
|
||||
assert media == [("/path/to/audio.ogg", False)]
|
||||
assert cleaned == ""
|
||||
|
||||
def test_media_tag_strips_wrapping_quotes_and_backticks(self):
|
||||
content = "MEDIA: `/path/to/file.png`\nMEDIA:\"/path/to/file2.png\"\nMEDIA:'/path/to/file3.png'"
|
||||
media, cleaned = BasePlatformAdapter.extract_media(content)
|
||||
assert media == [
|
||||
("/path/to/file.png", False),
|
||||
("/path/to/file2.png", False),
|
||||
("/path/to/file3.png", False),
|
||||
]
|
||||
assert cleaned == ""
|
||||
|
||||
def test_media_tag_supports_quoted_paths_with_spaces(self):
|
||||
content = "Here\nMEDIA: '/tmp/my image.png'\nAfter"
|
||||
media, cleaned = BasePlatformAdapter.extract_media(content)
|
||||
assert media == [("/tmp/my image.png", False)]
|
||||
assert "Here" in cleaned
|
||||
assert "After" in cleaned
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# truncate_message
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTruncateMessage:
|
||||
def _adapter(self):
|
||||
"""Create a minimal adapter instance for testing static/instance methods."""
|
||||
|
||||
class StubAdapter(BasePlatformAdapter):
|
||||
async def connect(self):
|
||||
return True
|
||||
|
||||
async def disconnect(self):
|
||||
pass
|
||||
|
||||
async def send(self, *a, **kw):
|
||||
pass
|
||||
|
||||
async def get_chat_info(self, *a):
|
||||
return {}
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
|
||||
config = PlatformConfig(enabled=True, token="test")
|
||||
return StubAdapter(config=config, platform=Platform.TELEGRAM)
|
||||
|
||||
def test_short_message_single_chunk(self):
|
||||
adapter = self._adapter()
|
||||
chunks = adapter.truncate_message("Hello world", max_length=100)
|
||||
assert chunks == ["Hello world"]
|
||||
|
||||
def test_exact_length_single_chunk(self):
|
||||
adapter = self._adapter()
|
||||
msg = "x" * 100
|
||||
chunks = adapter.truncate_message(msg, max_length=100)
|
||||
assert chunks == [msg]
|
||||
|
||||
def test_long_message_splits(self):
|
||||
adapter = self._adapter()
|
||||
msg = "word " * 200 # ~1000 chars
|
||||
chunks = adapter.truncate_message(msg, max_length=200)
|
||||
assert len(chunks) > 1
|
||||
# Verify all original content is preserved across chunks
|
||||
reassembled = "".join(chunks)
|
||||
# Strip chunk indicators like (1/N) to get raw content
|
||||
for word in msg.strip().split():
|
||||
assert word in reassembled, f"Word '{word}' lost during truncation"
|
||||
|
||||
def test_chunks_have_indicators(self):
|
||||
adapter = self._adapter()
|
||||
msg = "word " * 200
|
||||
chunks = adapter.truncate_message(msg, max_length=200)
|
||||
assert "(1/" in chunks[0]
|
||||
assert f"({len(chunks)}/{len(chunks)})" in chunks[-1]
|
||||
|
||||
def test_code_block_first_chunk_closed(self):
|
||||
adapter = self._adapter()
|
||||
msg = "Before\n```python\n" + "x = 1\n" * 100 + "```\nAfter"
|
||||
chunks = adapter.truncate_message(msg, max_length=300)
|
||||
assert len(chunks) > 1
|
||||
# First chunk must have a closing fence appended (code block was split)
|
||||
first_fences = chunks[0].count("```")
|
||||
assert first_fences == 2, "First chunk should have opening + closing fence"
|
||||
|
||||
def test_code_block_language_tag_carried(self):
|
||||
adapter = self._adapter()
|
||||
msg = "Start\n```javascript\n" + "console.log('x');\n" * 80 + "```\nEnd"
|
||||
chunks = adapter.truncate_message(msg, max_length=300)
|
||||
if len(chunks) > 1:
|
||||
# At least one continuation chunk should reopen with ```javascript
|
||||
reopened_with_lang = any("```javascript" in chunk for chunk in chunks[1:])
|
||||
assert reopened_with_lang, (
|
||||
"No continuation chunk reopened with language tag"
|
||||
)
|
||||
|
||||
def test_continuation_chunks_have_balanced_fences(self):
|
||||
"""Regression: continuation chunks must close reopened code blocks."""
|
||||
adapter = self._adapter()
|
||||
msg = "Before\n```python\n" + "x = 1\n" * 100 + "```\nAfter"
|
||||
chunks = adapter.truncate_message(msg, max_length=300)
|
||||
assert len(chunks) > 1
|
||||
for i, chunk in enumerate(chunks):
|
||||
fence_count = chunk.count("```")
|
||||
assert fence_count % 2 == 0, (
|
||||
f"Chunk {i} has unbalanced fences ({fence_count})"
|
||||
)
|
||||
|
||||
def test_each_chunk_under_max_length(self):
|
||||
adapter = self._adapter()
|
||||
msg = "word " * 500
|
||||
max_len = 200
|
||||
chunks = adapter.truncate_message(msg, max_length=max_len)
|
||||
for i, chunk in enumerate(chunks):
|
||||
assert len(chunk) <= max_len + 20, (
|
||||
f"Chunk {i} too long: {len(chunk)} > {max_len}"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _get_human_delay
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetHumanDelay:
|
||||
def test_off_mode(self):
|
||||
with patch.dict(os.environ, {"HERMES_HUMAN_DELAY_MODE": "off"}):
|
||||
assert BasePlatformAdapter._get_human_delay() == 0.0
|
||||
|
||||
def test_default_is_off(self):
|
||||
with patch.dict(os.environ, {}, clear=False):
|
||||
os.environ.pop("HERMES_HUMAN_DELAY_MODE", None)
|
||||
assert BasePlatformAdapter._get_human_delay() == 0.0
|
||||
|
||||
def test_natural_mode_range(self):
|
||||
with patch.dict(os.environ, {"HERMES_HUMAN_DELAY_MODE": "natural"}):
|
||||
delay = BasePlatformAdapter._get_human_delay()
|
||||
assert 0.8 <= delay <= 2.5
|
||||
|
||||
def test_custom_mode_uses_env_vars(self):
|
||||
env = {
|
||||
"HERMES_HUMAN_DELAY_MODE": "custom",
|
||||
"HERMES_HUMAN_DELAY_MIN_MS": "100",
|
||||
"HERMES_HUMAN_DELAY_MAX_MS": "200",
|
||||
}
|
||||
with patch.dict(os.environ, env):
|
||||
delay = BasePlatformAdapter._get_human_delay()
|
||||
assert 0.1 <= delay <= 0.2
|
||||
401
hermes_code/tests/gateway/test_platform_reconnect.py
Normal file
401
hermes_code/tests/gateway/test_platform_reconnect.py
Normal file
|
|
@ -0,0 +1,401 @@
|
|||
"""Tests for the gateway platform reconnection watcher."""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, SendResult
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
|
||||
class StubAdapter(BasePlatformAdapter):
|
||||
"""Adapter whose connect() result can be controlled."""
|
||||
|
||||
def __init__(self, *, succeed=True, fatal_error=None, fatal_retryable=True):
|
||||
super().__init__(PlatformConfig(enabled=True, token="test"), Platform.TELEGRAM)
|
||||
self._succeed = succeed
|
||||
self._fatal_error = fatal_error
|
||||
self._fatal_retryable = fatal_retryable
|
||||
|
||||
async def connect(self):
|
||||
if self._fatal_error:
|
||||
self._set_fatal_error("test_error", self._fatal_error, retryable=self._fatal_retryable)
|
||||
return False
|
||||
return self._succeed
|
||||
|
||||
async def disconnect(self):
|
||||
return None
|
||||
|
||||
async def send(self, chat_id, content, reply_to=None, metadata=None):
|
||||
return SendResult(success=True, message_id="1")
|
||||
|
||||
async def send_typing(self, chat_id, metadata=None):
|
||||
return None
|
||||
|
||||
async def get_chat_info(self, chat_id):
|
||||
return {"id": chat_id}
|
||||
|
||||
|
||||
def _make_runner():
|
||||
"""Create a minimal GatewayRunner via object.__new__ to skip __init__."""
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.config = GatewayConfig(
|
||||
platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="test")}
|
||||
)
|
||||
runner._running = True
|
||||
runner._shutdown_event = asyncio.Event()
|
||||
runner._exit_reason = None
|
||||
runner._exit_with_failure = False
|
||||
runner._exit_cleanly = False
|
||||
runner._failed_platforms = {}
|
||||
runner.adapters = {}
|
||||
runner.delivery_router = MagicMock()
|
||||
runner._running_agents = {}
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._honcho_managers = {}
|
||||
runner._honcho_configs = {}
|
||||
runner._shutdown_all_gateway_honcho = lambda: None
|
||||
return runner
|
||||
|
||||
|
||||
# --- Startup queueing ---
|
||||
|
||||
class TestStartupFailureQueuing:
|
||||
"""Verify that failed platforms are queued during startup."""
|
||||
|
||||
def test_failed_platform_queued_on_connect_failure(self):
|
||||
"""When adapter.connect() returns False without fatal error, queue for retry."""
|
||||
runner = _make_runner()
|
||||
platform_config = PlatformConfig(enabled=True, token="test")
|
||||
runner._failed_platforms[Platform.TELEGRAM] = {
|
||||
"config": platform_config,
|
||||
"attempts": 1,
|
||||
"next_retry": time.monotonic() + 30,
|
||||
}
|
||||
assert Platform.TELEGRAM in runner._failed_platforms
|
||||
assert runner._failed_platforms[Platform.TELEGRAM]["attempts"] == 1
|
||||
|
||||
def test_failed_platform_not_queued_for_nonretryable(self):
|
||||
"""Non-retryable errors should not be in the retry queue."""
|
||||
runner = _make_runner()
|
||||
# Simulate: adapter had a non-retryable error, wasn't queued
|
||||
assert Platform.TELEGRAM not in runner._failed_platforms
|
||||
|
||||
|
||||
# --- Reconnect watcher ---
|
||||
|
||||
class TestPlatformReconnectWatcher:
|
||||
"""Test the _platform_reconnect_watcher background task."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconnect_succeeds_on_retry(self):
|
||||
"""Watcher should reconnect a failed platform when connect() succeeds."""
|
||||
runner = _make_runner()
|
||||
runner._sync_voice_mode_state_to_adapter = MagicMock()
|
||||
|
||||
platform_config = PlatformConfig(enabled=True, token="test")
|
||||
runner._failed_platforms[Platform.TELEGRAM] = {
|
||||
"config": platform_config,
|
||||
"attempts": 1,
|
||||
"next_retry": time.monotonic() - 1, # Already past retry time
|
||||
}
|
||||
|
||||
succeed_adapter = StubAdapter(succeed=True)
|
||||
real_sleep = asyncio.sleep
|
||||
|
||||
with patch.object(runner, "_create_adapter", return_value=succeed_adapter):
|
||||
with patch("gateway.run.build_channel_directory", create=True):
|
||||
# Run one iteration of the watcher then stop
|
||||
async def run_one_iteration():
|
||||
runner._running = True
|
||||
# Patch the sleep to exit after first check
|
||||
call_count = 0
|
||||
|
||||
async def fake_sleep(n):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count > 1:
|
||||
runner._running = False
|
||||
await real_sleep(0)
|
||||
|
||||
with patch("asyncio.sleep", side_effect=fake_sleep):
|
||||
await runner._platform_reconnect_watcher()
|
||||
|
||||
await run_one_iteration()
|
||||
|
||||
assert Platform.TELEGRAM not in runner._failed_platforms
|
||||
assert Platform.TELEGRAM in runner.adapters
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconnect_nonretryable_removed_from_queue(self):
|
||||
"""Non-retryable errors should remove the platform from the retry queue."""
|
||||
runner = _make_runner()
|
||||
|
||||
platform_config = PlatformConfig(enabled=True, token="test")
|
||||
runner._failed_platforms[Platform.TELEGRAM] = {
|
||||
"config": platform_config,
|
||||
"attempts": 1,
|
||||
"next_retry": time.monotonic() - 1,
|
||||
}
|
||||
|
||||
fail_adapter = StubAdapter(
|
||||
succeed=False, fatal_error="bad token", fatal_retryable=False
|
||||
)
|
||||
|
||||
real_sleep = asyncio.sleep
|
||||
|
||||
with patch.object(runner, "_create_adapter", return_value=fail_adapter):
|
||||
async def run_one_iteration():
|
||||
runner._running = True
|
||||
call_count = 0
|
||||
|
||||
async def fake_sleep(n):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count > 1:
|
||||
runner._running = False
|
||||
await real_sleep(0)
|
||||
|
||||
with patch("asyncio.sleep", side_effect=fake_sleep):
|
||||
await runner._platform_reconnect_watcher()
|
||||
|
||||
await run_one_iteration()
|
||||
|
||||
assert Platform.TELEGRAM not in runner._failed_platforms
|
||||
assert Platform.TELEGRAM not in runner.adapters
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconnect_retryable_stays_in_queue(self):
|
||||
"""Retryable failures should remain in the queue with incremented attempts."""
|
||||
runner = _make_runner()
|
||||
|
||||
platform_config = PlatformConfig(enabled=True, token="test")
|
||||
runner._failed_platforms[Platform.TELEGRAM] = {
|
||||
"config": platform_config,
|
||||
"attempts": 1,
|
||||
"next_retry": time.monotonic() - 1,
|
||||
}
|
||||
|
||||
fail_adapter = StubAdapter(
|
||||
succeed=False, fatal_error="DNS failure", fatal_retryable=True
|
||||
)
|
||||
|
||||
real_sleep = asyncio.sleep
|
||||
|
||||
with patch.object(runner, "_create_adapter", return_value=fail_adapter):
|
||||
async def run_one_iteration():
|
||||
runner._running = True
|
||||
call_count = 0
|
||||
|
||||
async def fake_sleep(n):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count > 1:
|
||||
runner._running = False
|
||||
await real_sleep(0)
|
||||
|
||||
with patch("asyncio.sleep", side_effect=fake_sleep):
|
||||
await runner._platform_reconnect_watcher()
|
||||
|
||||
await run_one_iteration()
|
||||
|
||||
assert Platform.TELEGRAM in runner._failed_platforms
|
||||
assert runner._failed_platforms[Platform.TELEGRAM]["attempts"] == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconnect_gives_up_after_max_attempts(self):
|
||||
"""After max attempts, platform should be removed from retry queue."""
|
||||
runner = _make_runner()
|
||||
|
||||
platform_config = PlatformConfig(enabled=True, token="test")
|
||||
runner._failed_platforms[Platform.TELEGRAM] = {
|
||||
"config": platform_config,
|
||||
"attempts": 20, # At max
|
||||
"next_retry": time.monotonic() - 1,
|
||||
}
|
||||
|
||||
real_sleep = asyncio.sleep
|
||||
|
||||
with patch.object(runner, "_create_adapter") as mock_create:
|
||||
async def run_one_iteration():
|
||||
runner._running = True
|
||||
call_count = 0
|
||||
|
||||
async def fake_sleep(n):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count > 1:
|
||||
runner._running = False
|
||||
await real_sleep(0)
|
||||
|
||||
with patch("asyncio.sleep", side_effect=fake_sleep):
|
||||
await runner._platform_reconnect_watcher()
|
||||
|
||||
await run_one_iteration()
|
||||
|
||||
assert Platform.TELEGRAM not in runner._failed_platforms
|
||||
mock_create.assert_not_called() # Should give up without trying
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconnect_skips_when_not_time_yet(self):
|
||||
"""Watcher should skip platforms whose next_retry is in the future."""
|
||||
runner = _make_runner()
|
||||
|
||||
platform_config = PlatformConfig(enabled=True, token="test")
|
||||
runner._failed_platforms[Platform.TELEGRAM] = {
|
||||
"config": platform_config,
|
||||
"attempts": 1,
|
||||
"next_retry": time.monotonic() + 9999, # Far in the future
|
||||
}
|
||||
|
||||
real_sleep = asyncio.sleep
|
||||
|
||||
with patch.object(runner, "_create_adapter") as mock_create:
|
||||
async def run_one_iteration():
|
||||
runner._running = True
|
||||
call_count = 0
|
||||
|
||||
async def fake_sleep(n):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count > 1:
|
||||
runner._running = False
|
||||
await real_sleep(0)
|
||||
|
||||
with patch("asyncio.sleep", side_effect=fake_sleep):
|
||||
await runner._platform_reconnect_watcher()
|
||||
|
||||
await run_one_iteration()
|
||||
|
||||
assert Platform.TELEGRAM in runner._failed_platforms
|
||||
mock_create.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_failed_platforms_watcher_idles(self):
|
||||
"""When no platforms are failed, watcher should just idle."""
|
||||
runner = _make_runner()
|
||||
# No failed platforms
|
||||
|
||||
real_sleep = asyncio.sleep
|
||||
|
||||
with patch.object(runner, "_create_adapter") as mock_create:
|
||||
async def run_briefly():
|
||||
runner._running = True
|
||||
call_count = 0
|
||||
|
||||
async def fake_sleep(n):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count > 2:
|
||||
runner._running = False
|
||||
await real_sleep(0)
|
||||
|
||||
with patch("asyncio.sleep", side_effect=fake_sleep):
|
||||
await runner._platform_reconnect_watcher()
|
||||
|
||||
await run_briefly()
|
||||
|
||||
mock_create.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_adapter_create_returns_none(self):
|
||||
"""If _create_adapter returns None, remove from queue (missing deps)."""
|
||||
runner = _make_runner()
|
||||
|
||||
platform_config = PlatformConfig(enabled=True, token="test")
|
||||
runner._failed_platforms[Platform.TELEGRAM] = {
|
||||
"config": platform_config,
|
||||
"attempts": 1,
|
||||
"next_retry": time.monotonic() - 1,
|
||||
}
|
||||
|
||||
real_sleep = asyncio.sleep
|
||||
|
||||
with patch.object(runner, "_create_adapter", return_value=None):
|
||||
async def run_one_iteration():
|
||||
runner._running = True
|
||||
call_count = 0
|
||||
|
||||
async def fake_sleep(n):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count > 1:
|
||||
runner._running = False
|
||||
await real_sleep(0)
|
||||
|
||||
with patch("asyncio.sleep", side_effect=fake_sleep):
|
||||
await runner._platform_reconnect_watcher()
|
||||
|
||||
await run_one_iteration()
|
||||
|
||||
assert Platform.TELEGRAM not in runner._failed_platforms
|
||||
|
||||
|
||||
# --- Runtime disconnection queueing ---
|
||||
|
||||
class TestRuntimeDisconnectQueuing:
|
||||
"""Test that _handle_adapter_fatal_error queues retryable disconnections."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retryable_runtime_error_queued_for_reconnect(self):
|
||||
"""Retryable runtime errors should add the platform to _failed_platforms."""
|
||||
runner = _make_runner()
|
||||
|
||||
adapter = StubAdapter(succeed=True)
|
||||
adapter._set_fatal_error("network_error", "DNS failure", retryable=True)
|
||||
runner.adapters[Platform.TELEGRAM] = adapter
|
||||
|
||||
await runner._handle_adapter_fatal_error(adapter)
|
||||
|
||||
assert Platform.TELEGRAM in runner._failed_platforms
|
||||
assert runner._failed_platforms[Platform.TELEGRAM]["attempts"] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nonretryable_runtime_error_not_queued(self):
|
||||
"""Non-retryable runtime errors should not be queued for reconnection."""
|
||||
runner = _make_runner()
|
||||
|
||||
adapter = StubAdapter(succeed=True)
|
||||
adapter._set_fatal_error("auth_error", "bad token", retryable=False)
|
||||
runner.adapters[Platform.TELEGRAM] = adapter
|
||||
|
||||
# Need to prevent stop() from running fully
|
||||
runner.stop = AsyncMock()
|
||||
|
||||
await runner._handle_adapter_fatal_error(adapter)
|
||||
|
||||
assert Platform.TELEGRAM not in runner._failed_platforms
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retryable_error_prevents_shutdown_when_queued(self):
|
||||
"""Gateway should not shut down if failed platforms are queued for reconnection."""
|
||||
runner = _make_runner()
|
||||
runner.stop = AsyncMock()
|
||||
|
||||
adapter = StubAdapter(succeed=True)
|
||||
adapter._set_fatal_error("network_error", "DNS failure", retryable=True)
|
||||
runner.adapters[Platform.TELEGRAM] = adapter
|
||||
|
||||
await runner._handle_adapter_fatal_error(adapter)
|
||||
|
||||
# stop() should NOT have been called since we have platforms queued
|
||||
runner.stop.assert_not_called()
|
||||
assert Platform.TELEGRAM in runner._failed_platforms
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nonretryable_error_triggers_shutdown(self):
|
||||
"""Gateway should shut down when no adapters remain and nothing is queued."""
|
||||
runner = _make_runner()
|
||||
runner.stop = AsyncMock()
|
||||
|
||||
adapter = StubAdapter(succeed=True)
|
||||
adapter._set_fatal_error("auth_error", "bad token", retryable=False)
|
||||
runner.adapters[Platform.TELEGRAM] = adapter
|
||||
|
||||
await runner._handle_adapter_fatal_error(adapter)
|
||||
|
||||
runner.stop.assert_called_once()
|
||||
165
hermes_code/tests/gateway/test_queue_consumption.py
Normal file
165
hermes_code/tests/gateway/test_queue_consumption.py
Normal file
|
|
@ -0,0 +1,165 @@
|
|||
"""Tests for /queue message consumption after normal agent completion.
|
||||
|
||||
Verifies that messages queued via /queue (which store in
|
||||
adapter._pending_messages WITHOUT triggering an interrupt) are consumed
|
||||
after the agent finishes its current task — not silently dropped.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
PlatformConfig,
|
||||
Platform,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Minimal adapter for testing pending message storage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _StubAdapter(BasePlatformAdapter):
|
||||
def __init__(self):
|
||||
super().__init__(PlatformConfig(enabled=True, token="test"), Platform.TELEGRAM)
|
||||
|
||||
async def connect(self) -> bool:
|
||||
return True
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
self._mark_disconnected()
|
||||
|
||||
async def send(self, chat_id, content, reply_to=None, metadata=None):
|
||||
from gateway.platforms.base import SendResult
|
||||
return SendResult(success=True, message_id="msg-1")
|
||||
|
||||
async def get_chat_info(self, chat_id):
|
||||
return {"id": chat_id, "type": "dm"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestQueueMessageStorage:
|
||||
"""Verify /queue stores messages correctly in adapter._pending_messages."""
|
||||
|
||||
def test_queue_stores_message_in_pending(self):
|
||||
adapter = _StubAdapter()
|
||||
session_key = "telegram:user:123"
|
||||
event = MessageEvent(
|
||||
text="do this next",
|
||||
message_type=MessageType.TEXT,
|
||||
source=MagicMock(chat_id="123", platform=Platform.TELEGRAM),
|
||||
message_id="q1",
|
||||
)
|
||||
adapter._pending_messages[session_key] = event
|
||||
|
||||
assert session_key in adapter._pending_messages
|
||||
assert adapter._pending_messages[session_key].text == "do this next"
|
||||
|
||||
def test_get_pending_message_consumes_and_clears(self):
|
||||
adapter = _StubAdapter()
|
||||
session_key = "telegram:user:123"
|
||||
event = MessageEvent(
|
||||
text="queued prompt",
|
||||
message_type=MessageType.TEXT,
|
||||
source=MagicMock(chat_id="123", platform=Platform.TELEGRAM),
|
||||
message_id="q2",
|
||||
)
|
||||
adapter._pending_messages[session_key] = event
|
||||
|
||||
retrieved = adapter.get_pending_message(session_key)
|
||||
assert retrieved is not None
|
||||
assert retrieved.text == "queued prompt"
|
||||
# Should be consumed (cleared)
|
||||
assert adapter.get_pending_message(session_key) is None
|
||||
|
||||
def test_queue_does_not_set_interrupt_event(self):
|
||||
"""The whole point of /queue — no interrupt signal."""
|
||||
adapter = _StubAdapter()
|
||||
session_key = "telegram:user:123"
|
||||
|
||||
# Simulate an active session (agent running)
|
||||
adapter._active_sessions[session_key] = asyncio.Event()
|
||||
|
||||
# Store a queued message (what /queue does)
|
||||
event = MessageEvent(
|
||||
text="queued",
|
||||
message_type=MessageType.TEXT,
|
||||
source=MagicMock(),
|
||||
message_id="q3",
|
||||
)
|
||||
adapter._pending_messages[session_key] = event
|
||||
|
||||
# The interrupt event should NOT be set
|
||||
assert not adapter._active_sessions[session_key].is_set()
|
||||
assert not adapter.has_pending_interrupt(session_key)
|
||||
|
||||
def test_regular_message_sets_interrupt_event(self):
|
||||
"""Contrast: regular messages DO trigger interrupt."""
|
||||
adapter = _StubAdapter()
|
||||
session_key = "telegram:user:123"
|
||||
|
||||
adapter._active_sessions[session_key] = asyncio.Event()
|
||||
|
||||
# Simulate regular message arrival (what handle_message does)
|
||||
event = MessageEvent(
|
||||
text="new message",
|
||||
message_type=MessageType.TEXT,
|
||||
source=MagicMock(),
|
||||
message_id="m1",
|
||||
)
|
||||
adapter._pending_messages[session_key] = event
|
||||
adapter._active_sessions[session_key].set() # this is what handle_message does
|
||||
|
||||
assert adapter.has_pending_interrupt(session_key)
|
||||
|
||||
|
||||
class TestQueueConsumptionAfterCompletion:
|
||||
"""Verify that pending messages are consumed after normal completion."""
|
||||
|
||||
def test_pending_message_available_after_normal_completion(self):
|
||||
"""After agent finishes without interrupt, pending message should
|
||||
still be retrievable from adapter._pending_messages."""
|
||||
adapter = _StubAdapter()
|
||||
session_key = "telegram:user:123"
|
||||
|
||||
# Simulate: agent starts, /queue stores a message, agent finishes
|
||||
adapter._active_sessions[session_key] = asyncio.Event()
|
||||
event = MessageEvent(
|
||||
text="process this after",
|
||||
message_type=MessageType.TEXT,
|
||||
source=MagicMock(),
|
||||
message_id="q4",
|
||||
)
|
||||
adapter._pending_messages[session_key] = event
|
||||
|
||||
# Agent finishes (no interrupt)
|
||||
del adapter._active_sessions[session_key]
|
||||
|
||||
# The queued message should still be retrievable
|
||||
retrieved = adapter.get_pending_message(session_key)
|
||||
assert retrieved is not None
|
||||
assert retrieved.text == "process this after"
|
||||
|
||||
def test_multiple_queues_last_one_wins(self):
|
||||
"""If user /queue's multiple times, last message overwrites."""
|
||||
adapter = _StubAdapter()
|
||||
session_key = "telegram:user:123"
|
||||
|
||||
for text in ["first", "second", "third"]:
|
||||
event = MessageEvent(
|
||||
text=text,
|
||||
message_type=MessageType.TEXT,
|
||||
source=MagicMock(),
|
||||
message_id=f"q-{text}",
|
||||
)
|
||||
adapter._pending_messages[session_key] = event
|
||||
|
||||
retrieved = adapter.get_pending_message(session_key)
|
||||
assert retrieved.text == "third"
|
||||
220
hermes_code/tests/gateway/test_reasoning_command.py
Normal file
220
hermes_code/tests/gateway/test_reasoning_command.py
Normal file
|
|
@ -0,0 +1,220 @@
|
|||
"""Tests for gateway /reasoning command and hot reload behavior."""
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import sys
|
||||
import types
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
import gateway.run as gateway_run
|
||||
from gateway.config import Platform
|
||||
from gateway.platforms.base import MessageEvent
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
def _make_event(text="/reasoning", platform=Platform.TELEGRAM, user_id="12345", chat_id="67890"):
|
||||
"""Build a MessageEvent for testing."""
|
||||
source = SessionSource(
|
||||
platform=platform,
|
||||
user_id=user_id,
|
||||
chat_id=chat_id,
|
||||
user_name="testuser",
|
||||
)
|
||||
return MessageEvent(text=text, source=source)
|
||||
|
||||
|
||||
def _make_runner():
|
||||
"""Create a bare GatewayRunner without calling __init__."""
|
||||
runner = object.__new__(gateway_run.GatewayRunner)
|
||||
runner.adapters = {}
|
||||
runner._ephemeral_system_prompt = ""
|
||||
runner._prefill_messages = []
|
||||
runner._reasoning_config = None
|
||||
runner._show_reasoning = False
|
||||
runner._provider_routing = {}
|
||||
runner._fallback_model = None
|
||||
runner._running_agents = {}
|
||||
runner.hooks = MagicMock()
|
||||
runner.hooks.emit = AsyncMock()
|
||||
runner.hooks.loaded_hooks = []
|
||||
runner._session_db = None
|
||||
runner._get_or_create_gateway_honcho = lambda session_key: (None, None)
|
||||
return runner
|
||||
|
||||
|
||||
class _CapturingAgent:
|
||||
"""Fake agent that records init kwargs for assertions."""
|
||||
|
||||
last_init = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
type(self).last_init = dict(kwargs)
|
||||
self.tools = []
|
||||
|
||||
def run_conversation(self, user_message: str, conversation_history=None, task_id=None):
|
||||
return {
|
||||
"final_response": "ok",
|
||||
"messages": [],
|
||||
"api_calls": 1,
|
||||
}
|
||||
|
||||
|
||||
class TestReasoningCommand:
|
||||
@pytest.mark.asyncio
|
||||
async def test_reasoning_in_help_output(self):
|
||||
runner = _make_runner()
|
||||
event = _make_event(text="/help")
|
||||
|
||||
result = await runner._handle_help_command(event)
|
||||
|
||||
assert "/reasoning [level|show|hide]" in result
|
||||
|
||||
def test_reasoning_is_known_command(self):
|
||||
source = inspect.getsource(gateway_run.GatewayRunner._handle_message)
|
||||
assert '"reasoning"' in source
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reasoning_command_reloads_current_state_from_config(self, tmp_path, monkeypatch):
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
config_path = hermes_home / "config.yaml"
|
||||
config_path.write_text(
|
||||
"agent:\n reasoning_effort: none\ndisplay:\n show_reasoning: true\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", hermes_home)
|
||||
monkeypatch.delenv("HERMES_REASONING_EFFORT", raising=False)
|
||||
|
||||
runner = _make_runner()
|
||||
runner._reasoning_config = {"enabled": True, "effort": "xhigh"}
|
||||
runner._show_reasoning = False
|
||||
|
||||
result = await runner._handle_reasoning_command(_make_event("/reasoning"))
|
||||
|
||||
assert "**Effort:** `none (disabled)`" in result
|
||||
assert "**Display:** on ✓" in result
|
||||
assert runner._reasoning_config == {"enabled": False}
|
||||
assert runner._show_reasoning is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_reasoning_command_updates_config_and_cache(self, tmp_path, monkeypatch):
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
config_path = hermes_home / "config.yaml"
|
||||
config_path.write_text("agent:\n reasoning_effort: medium\n", encoding="utf-8")
|
||||
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", hermes_home)
|
||||
monkeypatch.delenv("HERMES_REASONING_EFFORT", raising=False)
|
||||
|
||||
runner = _make_runner()
|
||||
runner._reasoning_config = {"enabled": True, "effort": "medium"}
|
||||
|
||||
result = await runner._handle_reasoning_command(_make_event("/reasoning low"))
|
||||
|
||||
saved = yaml.safe_load(config_path.read_text(encoding="utf-8"))
|
||||
assert saved["agent"]["reasoning_effort"] == "low"
|
||||
assert runner._reasoning_config == {"enabled": True, "effort": "low"}
|
||||
assert "takes effect on next message" in result
|
||||
|
||||
def test_run_agent_reloads_reasoning_config_per_message(self, tmp_path, monkeypatch):
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
(hermes_home / "config.yaml").write_text("agent:\n reasoning_effort: low\n", encoding="utf-8")
|
||||
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", hermes_home)
|
||||
monkeypatch.setattr(gateway_run, "_env_path", hermes_home / ".env")
|
||||
monkeypatch.setattr(gateway_run, "load_dotenv", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(
|
||||
gateway_run,
|
||||
"_resolve_runtime_agent_kwargs",
|
||||
lambda: {
|
||||
"provider": "openrouter",
|
||||
"api_mode": "chat_completions",
|
||||
"base_url": "https://openrouter.ai/api/v1",
|
||||
"api_key": "test-key",
|
||||
},
|
||||
)
|
||||
monkeypatch.delenv("HERMES_REASONING_EFFORT", raising=False)
|
||||
fake_run_agent = types.ModuleType("run_agent")
|
||||
fake_run_agent.AIAgent = _CapturingAgent
|
||||
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
|
||||
|
||||
_CapturingAgent.last_init = None
|
||||
runner = _make_runner()
|
||||
runner._reasoning_config = {"enabled": True, "effort": "xhigh"}
|
||||
|
||||
source = SessionSource(
|
||||
platform=Platform.LOCAL,
|
||||
chat_id="cli",
|
||||
chat_name="CLI",
|
||||
chat_type="dm",
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
result = asyncio.run(
|
||||
runner._run_agent(
|
||||
message="ping",
|
||||
context_prompt="",
|
||||
history=[],
|
||||
source=source,
|
||||
session_id="session-1",
|
||||
session_key="agent:main:local:dm",
|
||||
)
|
||||
)
|
||||
|
||||
assert result["final_response"] == "ok"
|
||||
assert _CapturingAgent.last_init is not None
|
||||
assert _CapturingAgent.last_init["reasoning_config"] == {"enabled": True, "effort": "low"}
|
||||
|
||||
def test_run_agent_prefers_config_over_stale_reasoning_env(self, tmp_path, monkeypatch):
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
(hermes_home / "config.yaml").write_text("agent:\n reasoning_effort: none\n", encoding="utf-8")
|
||||
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", hermes_home)
|
||||
monkeypatch.setattr(gateway_run, "_env_path", hermes_home / ".env")
|
||||
monkeypatch.setattr(gateway_run, "load_dotenv", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(
|
||||
gateway_run,
|
||||
"_resolve_runtime_agent_kwargs",
|
||||
lambda: {
|
||||
"provider": "openrouter",
|
||||
"api_mode": "chat_completions",
|
||||
"base_url": "https://openrouter.ai/api/v1",
|
||||
"api_key": "test-key",
|
||||
},
|
||||
)
|
||||
monkeypatch.setenv("HERMES_REASONING_EFFORT", "low")
|
||||
fake_run_agent = types.ModuleType("run_agent")
|
||||
fake_run_agent.AIAgent = _CapturingAgent
|
||||
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
|
||||
|
||||
_CapturingAgent.last_init = None
|
||||
runner = _make_runner()
|
||||
|
||||
source = SessionSource(
|
||||
platform=Platform.LOCAL,
|
||||
chat_id="cli",
|
||||
chat_name="CLI",
|
||||
chat_type="dm",
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
result = asyncio.run(
|
||||
runner._run_agent(
|
||||
message="ping",
|
||||
context_prompt="",
|
||||
history=[],
|
||||
source=source,
|
||||
session_id="session-1",
|
||||
session_key="agent:main:local:dm",
|
||||
)
|
||||
)
|
||||
|
||||
assert result["final_response"] == "ok"
|
||||
assert _CapturingAgent.last_init is not None
|
||||
assert _CapturingAgent.last_init["reasoning_config"] == {"enabled": False}
|
||||
226
hermes_code/tests/gateway/test_resume_command.py
Normal file
226
hermes_code/tests/gateway/test_resume_command.py
Normal file
|
|
@ -0,0 +1,226 @@
|
|||
"""Tests for /resume gateway slash command.
|
||||
|
||||
Tests the _handle_resume_command handler (switch to a previously-named session)
|
||||
across gateway messenger platforms.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform
|
||||
from gateway.platforms.base import MessageEvent
|
||||
from gateway.session import SessionSource, build_session_key
|
||||
|
||||
|
||||
def _make_event(text="/resume", platform=Platform.TELEGRAM,
|
||||
user_id="12345", chat_id="67890"):
|
||||
"""Build a MessageEvent for testing."""
|
||||
source = SessionSource(
|
||||
platform=platform,
|
||||
user_id=user_id,
|
||||
chat_id=chat_id,
|
||||
user_name="testuser",
|
||||
)
|
||||
return MessageEvent(text=text, source=source)
|
||||
|
||||
|
||||
def _session_key_for_event(event):
|
||||
"""Get the session key that build_session_key produces for an event."""
|
||||
return build_session_key(event.source)
|
||||
|
||||
|
||||
def _make_runner(session_db=None, current_session_id="current_session_001",
|
||||
event=None):
|
||||
"""Create a bare GatewayRunner with a mock session_store and optional session_db."""
|
||||
from gateway.run import GatewayRunner
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.adapters = {}
|
||||
runner._voice_mode = {}
|
||||
runner._session_db = session_db
|
||||
runner._running_agents = {}
|
||||
|
||||
# Compute the real session key if an event is provided
|
||||
session_key = build_session_key(event.source) if event else "agent:main:telegram:dm"
|
||||
|
||||
# Mock session_store that returns a session entry with a known session_id
|
||||
mock_session_entry = MagicMock()
|
||||
mock_session_entry.session_id = current_session_id
|
||||
mock_session_entry.session_key = session_key
|
||||
mock_store = MagicMock()
|
||||
mock_store.get_or_create_session.return_value = mock_session_entry
|
||||
mock_store.load_transcript.return_value = []
|
||||
mock_store.switch_session.return_value = mock_session_entry
|
||||
runner.session_store = mock_store
|
||||
|
||||
# Stub out memory flushing
|
||||
runner._async_flush_memories = AsyncMock()
|
||||
|
||||
return runner
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _handle_resume_command
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHandleResumeCommand:
|
||||
"""Tests for GatewayRunner._handle_resume_command."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_session_db(self):
|
||||
"""Returns error when session database is unavailable."""
|
||||
runner = _make_runner(session_db=None)
|
||||
event = _make_event(text="/resume My Project")
|
||||
result = await runner._handle_resume_command(event)
|
||||
assert "not available" in result.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_named_sessions_when_no_arg(self, tmp_path):
|
||||
"""With no argument, lists recently titled sessions."""
|
||||
from hermes_state import SessionDB
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
db.create_session("sess_001", "telegram")
|
||||
db.create_session("sess_002", "telegram")
|
||||
db.set_session_title("sess_001", "Research")
|
||||
db.set_session_title("sess_002", "Coding")
|
||||
|
||||
event = _make_event(text="/resume")
|
||||
runner = _make_runner(session_db=db, event=event)
|
||||
result = await runner._handle_resume_command(event)
|
||||
assert "Research" in result
|
||||
assert "Coding" in result
|
||||
assert "Named Sessions" in result
|
||||
db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_shows_usage_when_no_titled(self, tmp_path):
|
||||
"""With no arg and no titled sessions, shows instructions."""
|
||||
from hermes_state import SessionDB
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
db.create_session("sess_001", "telegram") # No title
|
||||
|
||||
event = _make_event(text="/resume")
|
||||
runner = _make_runner(session_db=db, event=event)
|
||||
result = await runner._handle_resume_command(event)
|
||||
assert "No named sessions" in result
|
||||
assert "/title" in result
|
||||
db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_by_name(self, tmp_path):
|
||||
"""Resolves a title and switches to that session."""
|
||||
from hermes_state import SessionDB
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
db.create_session("old_session_abc", "telegram")
|
||||
db.set_session_title("old_session_abc", "My Project")
|
||||
db.create_session("current_session_001", "telegram")
|
||||
|
||||
event = _make_event(text="/resume My Project")
|
||||
runner = _make_runner(session_db=db, current_session_id="current_session_001",
|
||||
event=event)
|
||||
result = await runner._handle_resume_command(event)
|
||||
|
||||
assert "Resumed" in result
|
||||
assert "My Project" in result
|
||||
# Verify switch_session was called with the old session ID
|
||||
runner.session_store.switch_session.assert_called_once()
|
||||
call_args = runner.session_store.switch_session.call_args
|
||||
assert call_args[0][1] == "old_session_abc"
|
||||
db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_nonexistent_name(self, tmp_path):
|
||||
"""Returns error for unknown session name."""
|
||||
from hermes_state import SessionDB
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
db.create_session("current_session_001", "telegram")
|
||||
|
||||
event = _make_event(text="/resume Nonexistent Session")
|
||||
runner = _make_runner(session_db=db, event=event)
|
||||
result = await runner._handle_resume_command(event)
|
||||
assert "No session found" in result
|
||||
db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_already_on_session(self, tmp_path):
|
||||
"""Returns friendly message when already on the requested session."""
|
||||
from hermes_state import SessionDB
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
db.create_session("current_session_001", "telegram")
|
||||
db.set_session_title("current_session_001", "Active Project")
|
||||
|
||||
event = _make_event(text="/resume Active Project")
|
||||
runner = _make_runner(session_db=db, current_session_id="current_session_001",
|
||||
event=event)
|
||||
result = await runner._handle_resume_command(event)
|
||||
assert "Already on session" in result
|
||||
db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_auto_lineage(self, tmp_path):
|
||||
"""Asking for 'My Project' when 'My Project #2' exists gets the latest."""
|
||||
from hermes_state import SessionDB
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
db.create_session("sess_v1", "telegram")
|
||||
db.set_session_title("sess_v1", "My Project")
|
||||
db.create_session("sess_v2", "telegram")
|
||||
db.set_session_title("sess_v2", "My Project #2")
|
||||
db.create_session("current_session_001", "telegram")
|
||||
|
||||
event = _make_event(text="/resume My Project")
|
||||
runner = _make_runner(session_db=db, current_session_id="current_session_001",
|
||||
event=event)
|
||||
result = await runner._handle_resume_command(event)
|
||||
|
||||
assert "Resumed" in result
|
||||
# Should resolve to #2 (latest in lineage)
|
||||
call_args = runner.session_store.switch_session.call_args
|
||||
assert call_args[0][1] == "sess_v2"
|
||||
db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_clears_running_agent(self, tmp_path):
|
||||
"""Switching sessions clears any cached running agent."""
|
||||
from hermes_state import SessionDB
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
db.create_session("old_session", "telegram")
|
||||
db.set_session_title("old_session", "Old Work")
|
||||
db.create_session("current_session_001", "telegram")
|
||||
|
||||
event = _make_event(text="/resume Old Work")
|
||||
runner = _make_runner(session_db=db, current_session_id="current_session_001",
|
||||
event=event)
|
||||
# Simulate a running agent using the real session key
|
||||
real_key = _session_key_for_event(event)
|
||||
runner._running_agents[real_key] = MagicMock()
|
||||
|
||||
await runner._handle_resume_command(event)
|
||||
|
||||
assert real_key not in runner._running_agents
|
||||
db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_flushes_memories_with_gateway_session_key(self, tmp_path):
|
||||
"""Resume should preserve the gateway session key for Honcho flushes."""
|
||||
from hermes_state import SessionDB
|
||||
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
db.create_session("old_session", "telegram")
|
||||
db.set_session_title("old_session", "Old Work")
|
||||
db.create_session("current_session_001", "telegram")
|
||||
|
||||
event = _make_event(text="/resume Old Work")
|
||||
runner = _make_runner(
|
||||
session_db=db,
|
||||
current_session_id="current_session_001",
|
||||
event=event,
|
||||
)
|
||||
|
||||
await runner._handle_resume_command(event)
|
||||
|
||||
runner._async_flush_memories.assert_called_once_with(
|
||||
"current_session_001",
|
||||
_session_key_for_event(event),
|
||||
)
|
||||
db.close()
|
||||
97
hermes_code/tests/gateway/test_retry_replacement.py
Normal file
97
hermes_code/tests/gateway/test_retry_replacement.py
Normal file
|
|
@ -0,0 +1,97 @@
|
|||
"""Regression tests for /retry replacement semantics."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig
|
||||
from gateway.platforms.base import MessageEvent, MessageType
|
||||
from gateway.run import GatewayRunner
|
||||
from gateway.session import SessionStore
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gateway_retry_replaces_last_user_turn_in_transcript(tmp_path):
|
||||
config = GatewayConfig()
|
||||
with patch("gateway.session.SessionStore._ensure_loaded"):
|
||||
store = SessionStore(sessions_dir=tmp_path, config=config)
|
||||
store._db = None
|
||||
store._loaded = True
|
||||
|
||||
session_id = "retry_session"
|
||||
for msg in [
|
||||
{"role": "session_meta", "tools": []},
|
||||
{"role": "user", "content": "first question"},
|
||||
{"role": "assistant", "content": "first answer"},
|
||||
{"role": "user", "content": "retry me"},
|
||||
{"role": "assistant", "content": "old answer"},
|
||||
]:
|
||||
store.append_to_transcript(session_id, msg)
|
||||
|
||||
gw = GatewayRunner.__new__(GatewayRunner)
|
||||
gw.config = config
|
||||
gw.session_store = store
|
||||
|
||||
session_entry = MagicMock(session_id=session_id)
|
||||
session_entry.last_prompt_tokens = 111
|
||||
gw.session_store.get_or_create_session = MagicMock(return_value=session_entry)
|
||||
|
||||
async def fake_handle_message(event):
|
||||
assert event.text == "retry me"
|
||||
transcript_before = store.load_transcript(session_id)
|
||||
assert [m.get("content") for m in transcript_before if m.get("role") == "user"] == [
|
||||
"first question"
|
||||
]
|
||||
store.append_to_transcript(session_id, {"role": "user", "content": event.text})
|
||||
store.append_to_transcript(session_id, {"role": "assistant", "content": "new answer"})
|
||||
return "new answer"
|
||||
|
||||
gw._handle_message = AsyncMock(side_effect=fake_handle_message)
|
||||
|
||||
result = await gw._handle_retry_command(
|
||||
MessageEvent(text="/retry", message_type=MessageType.TEXT, source=MagicMock())
|
||||
)
|
||||
|
||||
assert result == "new answer"
|
||||
transcript_after = store.load_transcript(session_id)
|
||||
assert [m.get("content") for m in transcript_after if m.get("role") == "user"] == [
|
||||
"first question",
|
||||
"retry me",
|
||||
]
|
||||
assert [m.get("content") for m in transcript_after if m.get("role") == "assistant"] == [
|
||||
"first answer",
|
||||
"new answer",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gateway_retry_replays_original_text_not_retry_command(tmp_path):
|
||||
config = MagicMock()
|
||||
config.sessions_dir = tmp_path
|
||||
config.max_context_messages = 20
|
||||
gw = GatewayRunner.__new__(GatewayRunner)
|
||||
gw.config = config
|
||||
gw.session_store = MagicMock()
|
||||
|
||||
session_entry = MagicMock(session_id="test-session")
|
||||
session_entry.last_prompt_tokens = 55
|
||||
gw.session_store.get_or_create_session.return_value = session_entry
|
||||
gw.session_store.load_transcript.return_value = [
|
||||
{"role": "user", "content": "real message"},
|
||||
{"role": "assistant", "content": "answer"},
|
||||
]
|
||||
gw.session_store.rewrite_transcript = MagicMock()
|
||||
|
||||
captured = {}
|
||||
|
||||
async def fake_handle_message(event):
|
||||
captured["text"] = event.text
|
||||
return "ok"
|
||||
|
||||
gw._handle_message = AsyncMock(side_effect=fake_handle_message)
|
||||
|
||||
await gw._handle_retry_command(
|
||||
MessageEvent(text="/retry", message_type=MessageType.TEXT, source=MagicMock())
|
||||
)
|
||||
|
||||
assert captured["text"] == "real message"
|
||||
60
hermes_code/tests/gateway/test_retry_response.py
Normal file
60
hermes_code/tests/gateway/test_retry_response.py
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
"""Regression test: /retry must return the agent response, not None.
|
||||
|
||||
Before the fix in PR #441, _handle_retry_command() called
|
||||
_handle_message(retry_event) but discarded its return value with `return None`,
|
||||
so users never received the final response.
|
||||
"""
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from gateway.run import GatewayRunner
|
||||
from gateway.platforms.base import MessageEvent, MessageType
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gateway(tmp_path):
|
||||
config = MagicMock()
|
||||
config.sessions_dir = tmp_path
|
||||
config.max_context_messages = 20
|
||||
gw = GatewayRunner.__new__(GatewayRunner)
|
||||
gw.config = config
|
||||
gw.session_store = MagicMock()
|
||||
return gw
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retry_returns_response_not_none(gateway):
|
||||
"""_handle_retry_command must return the inner handler response, not None."""
|
||||
gateway.session_store.get_or_create_session.return_value = MagicMock(
|
||||
session_id="test-session"
|
||||
)
|
||||
gateway.session_store.load_transcript.return_value = [
|
||||
{"role": "user", "content": "Hello Hermes"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
]
|
||||
gateway.session_store.rewrite_transcript = MagicMock()
|
||||
expected_response = "Hi there! (retried)"
|
||||
gateway._handle_message = AsyncMock(return_value=expected_response)
|
||||
event = MessageEvent(
|
||||
text="/retry",
|
||||
message_type=MessageType.TEXT,
|
||||
source=MagicMock(),
|
||||
)
|
||||
result = await gateway._handle_retry_command(event)
|
||||
assert result is not None, "/retry must not return None"
|
||||
assert result == expected_response
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retry_no_previous_message(gateway):
|
||||
"""If there is no previous user message, return early with a message."""
|
||||
gateway.session_store.get_or_create_session.return_value = MagicMock(
|
||||
session_id="test-session"
|
||||
)
|
||||
gateway.session_store.load_transcript.return_value = []
|
||||
event = MessageEvent(
|
||||
text="/retry",
|
||||
message_type=MessageType.TEXT,
|
||||
source=MagicMock(),
|
||||
)
|
||||
result = await gateway._handle_retry_command(event)
|
||||
assert result == "No previous message to retry."
|
||||
135
hermes_code/tests/gateway/test_run_progress_topics.py
Normal file
135
hermes_code/tests/gateway/test_run_progress_topics.py
Normal file
|
|
@ -0,0 +1,135 @@
|
|||
"""Tests for topic-aware gateway progress updates."""
|
||||
|
||||
import importlib
|
||||
import sys
|
||||
import time
|
||||
import types
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import BasePlatformAdapter, SendResult
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
class ProgressCaptureAdapter(BasePlatformAdapter):
|
||||
def __init__(self):
|
||||
super().__init__(PlatformConfig(enabled=True, token="fake-token"), Platform.TELEGRAM)
|
||||
self.sent = []
|
||||
self.edits = []
|
||||
self.typing = []
|
||||
|
||||
async def connect(self) -> bool:
|
||||
return True
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
return None
|
||||
|
||||
async def send(self, chat_id, content, reply_to=None, metadata=None) -> SendResult:
|
||||
self.sent.append(
|
||||
{
|
||||
"chat_id": chat_id,
|
||||
"content": content,
|
||||
"reply_to": reply_to,
|
||||
"metadata": metadata,
|
||||
}
|
||||
)
|
||||
return SendResult(success=True, message_id="progress-1")
|
||||
|
||||
async def edit_message(self, chat_id, message_id, content) -> SendResult:
|
||||
self.edits.append(
|
||||
{
|
||||
"chat_id": chat_id,
|
||||
"message_id": message_id,
|
||||
"content": content,
|
||||
}
|
||||
)
|
||||
return SendResult(success=True, message_id=message_id)
|
||||
|
||||
async def send_typing(self, chat_id, metadata=None) -> None:
|
||||
self.typing.append({"chat_id": chat_id, "metadata": metadata})
|
||||
|
||||
async def get_chat_info(self, chat_id: str):
|
||||
return {"id": chat_id}
|
||||
|
||||
|
||||
class FakeAgent:
|
||||
def __init__(self, **kwargs):
|
||||
self.tool_progress_callback = kwargs.get("tool_progress_callback")
|
||||
self.tools = []
|
||||
|
||||
def run_conversation(self, message, conversation_history=None, task_id=None):
|
||||
self.tool_progress_callback("terminal", "pwd")
|
||||
time.sleep(0.35)
|
||||
self.tool_progress_callback("browser_navigate", "https://example.com")
|
||||
time.sleep(0.35)
|
||||
return {
|
||||
"final_response": "done",
|
||||
"messages": [],
|
||||
"api_calls": 1,
|
||||
}
|
||||
|
||||
|
||||
def _make_runner(adapter):
|
||||
gateway_run = importlib.import_module("gateway.run")
|
||||
GatewayRunner = gateway_run.GatewayRunner
|
||||
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.adapters = {Platform.TELEGRAM: adapter}
|
||||
runner._voice_mode = {}
|
||||
runner._prefill_messages = []
|
||||
runner._ephemeral_system_prompt = ""
|
||||
runner._reasoning_config = None
|
||||
runner._provider_routing = {}
|
||||
runner._fallback_model = None
|
||||
runner._session_db = None
|
||||
runner._running_agents = {}
|
||||
runner.hooks = SimpleNamespace(loaded_hooks=False)
|
||||
return runner
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_progress_stays_in_originating_topic(monkeypatch, tmp_path):
|
||||
monkeypatch.setenv("HERMES_TOOL_PROGRESS_MODE", "all")
|
||||
|
||||
fake_dotenv = types.ModuleType("dotenv")
|
||||
fake_dotenv.load_dotenv = lambda *args, **kwargs: None
|
||||
monkeypatch.setitem(sys.modules, "dotenv", fake_dotenv)
|
||||
|
||||
fake_run_agent = types.ModuleType("run_agent")
|
||||
fake_run_agent.AIAgent = FakeAgent
|
||||
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
|
||||
|
||||
adapter = ProgressCaptureAdapter()
|
||||
runner = _make_runner(adapter)
|
||||
gateway_run = importlib.import_module("gateway.run")
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||
monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "fake"})
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="-1001",
|
||||
chat_type="group",
|
||||
thread_id="17585",
|
||||
)
|
||||
|
||||
result = await runner._run_agent(
|
||||
message="hello",
|
||||
context_prompt="",
|
||||
history=[],
|
||||
source=source,
|
||||
session_id="sess-1",
|
||||
session_key="agent:main:telegram:group:-1001:17585",
|
||||
)
|
||||
|
||||
assert result["final_response"] == "done"
|
||||
assert adapter.sent == [
|
||||
{
|
||||
"chat_id": "-1001",
|
||||
"content": '💻 terminal: "pwd"',
|
||||
"reply_to": None,
|
||||
"metadata": {"thread_id": "17585"},
|
||||
}
|
||||
]
|
||||
assert adapter.edits
|
||||
assert all(call["metadata"] == {"thread_id": "17585"} for call in adapter.typing)
|
||||
95
hermes_code/tests/gateway/test_runner_fatal_adapter.py
Normal file
95
hermes_code/tests/gateway/test_runner_fatal_adapter.py
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import BasePlatformAdapter
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
|
||||
class _FatalAdapter(BasePlatformAdapter):
|
||||
def __init__(self):
|
||||
super().__init__(PlatformConfig(enabled=True, token="token"), Platform.TELEGRAM)
|
||||
|
||||
async def connect(self) -> bool:
|
||||
self._set_fatal_error(
|
||||
"telegram_token_lock",
|
||||
"Another local Hermes gateway is already using this Telegram bot token.",
|
||||
retryable=False,
|
||||
)
|
||||
return False
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
self._mark_disconnected()
|
||||
|
||||
async def send(self, chat_id, content, reply_to=None, metadata=None):
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_chat_info(self, chat_id):
|
||||
return {"id": chat_id}
|
||||
|
||||
|
||||
class _RuntimeRetryableAdapter(BasePlatformAdapter):
|
||||
def __init__(self):
|
||||
super().__init__(PlatformConfig(enabled=True, token="token"), Platform.WHATSAPP)
|
||||
|
||||
async def connect(self) -> bool:
|
||||
return True
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
self._mark_disconnected()
|
||||
|
||||
async def send(self, chat_id, content, reply_to=None, metadata=None):
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_chat_info(self, chat_id):
|
||||
return {"id": chat_id}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_requests_clean_exit_for_nonretryable_startup_conflict(monkeypatch, tmp_path):
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.TELEGRAM: PlatformConfig(enabled=True, token="token")
|
||||
},
|
||||
sessions_dir=tmp_path / "sessions",
|
||||
)
|
||||
runner = GatewayRunner(config)
|
||||
|
||||
monkeypatch.setattr(runner, "_create_adapter", lambda platform, platform_config: _FatalAdapter())
|
||||
|
||||
ok = await runner.start()
|
||||
|
||||
assert ok is True
|
||||
assert runner.should_exit_cleanly is True
|
||||
assert "already using this Telegram bot token" in runner.exit_reason
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_queues_retryable_runtime_fatal_for_reconnection(monkeypatch, tmp_path):
|
||||
"""Retryable runtime fatal errors queue the platform for reconnection
|
||||
instead of shutting down the gateway."""
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.WHATSAPP: PlatformConfig(enabled=True, token="token")
|
||||
},
|
||||
sessions_dir=tmp_path / "sessions",
|
||||
)
|
||||
runner = GatewayRunner(config)
|
||||
adapter = _RuntimeRetryableAdapter()
|
||||
adapter._set_fatal_error(
|
||||
"whatsapp_bridge_exited",
|
||||
"WhatsApp bridge process exited unexpectedly (code 1).",
|
||||
retryable=True,
|
||||
)
|
||||
|
||||
runner.adapters = {Platform.WHATSAPP: adapter}
|
||||
runner.delivery_router.adapters = runner.adapters
|
||||
runner.stop = AsyncMock()
|
||||
|
||||
await runner._handle_adapter_fatal_error(adapter)
|
||||
|
||||
# Should NOT shut down — platform is queued for reconnection
|
||||
runner.stop.assert_not_awaited()
|
||||
assert Platform.WHATSAPP in runner._failed_platforms
|
||||
assert runner._failed_platforms[Platform.WHATSAPP]["attempts"] == 0
|
||||
89
hermes_code/tests/gateway/test_runner_startup_failures.py
Normal file
89
hermes_code/tests/gateway/test_runner_startup_failures.py
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import BasePlatformAdapter
|
||||
from gateway.run import GatewayRunner
|
||||
from gateway.status import read_runtime_status
|
||||
|
||||
|
||||
class _RetryableFailureAdapter(BasePlatformAdapter):
|
||||
def __init__(self):
|
||||
super().__init__(PlatformConfig(enabled=True, token="***"), Platform.TELEGRAM)
|
||||
|
||||
async def connect(self) -> bool:
|
||||
self._set_fatal_error(
|
||||
"telegram_connect_error",
|
||||
"Telegram startup failed: temporary DNS resolution failure.",
|
||||
retryable=True,
|
||||
)
|
||||
return False
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
self._mark_disconnected()
|
||||
|
||||
async def send(self, chat_id, content, reply_to=None, metadata=None):
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_chat_info(self, chat_id):
|
||||
return {"id": chat_id}
|
||||
|
||||
|
||||
class _DisabledAdapter(BasePlatformAdapter):
|
||||
def __init__(self):
|
||||
super().__init__(PlatformConfig(enabled=False, token="***"), Platform.TELEGRAM)
|
||||
|
||||
async def connect(self) -> bool:
|
||||
raise AssertionError("connect should not be called for disabled platforms")
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
self._mark_disconnected()
|
||||
|
||||
async def send(self, chat_id, content, reply_to=None, metadata=None):
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_chat_info(self, chat_id):
|
||||
return {"id": chat_id}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_returns_failure_for_retryable_startup_errors(monkeypatch, tmp_path):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")
|
||||
},
|
||||
sessions_dir=tmp_path / "sessions",
|
||||
)
|
||||
runner = GatewayRunner(config)
|
||||
|
||||
monkeypatch.setattr(runner, "_create_adapter", lambda platform, platform_config: _RetryableFailureAdapter())
|
||||
|
||||
ok = await runner.start()
|
||||
|
||||
assert ok is False
|
||||
assert runner.should_exit_cleanly is False
|
||||
state = read_runtime_status()
|
||||
assert state["gateway_state"] == "startup_failed"
|
||||
assert "temporary DNS resolution failure" in state["exit_reason"]
|
||||
assert state["platforms"]["telegram"]["state"] == "fatal"
|
||||
assert state["platforms"]["telegram"]["error_code"] == "telegram_connect_error"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_allows_cron_only_mode_when_no_platforms_are_enabled(monkeypatch, tmp_path):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.TELEGRAM: PlatformConfig(enabled=False, token="***")
|
||||
},
|
||||
sessions_dir=tmp_path / "sessions",
|
||||
)
|
||||
runner = GatewayRunner(config)
|
||||
|
||||
ok = await runner.start()
|
||||
|
||||
assert ok is True
|
||||
assert runner.should_exit_cleanly is False
|
||||
assert runner.adapters == {}
|
||||
state = read_runtime_status()
|
||||
assert state["gateway_state"] == "running"
|
||||
437
hermes_code/tests/gateway/test_send_image_file.py
Normal file
437
hermes_code/tests/gateway/test_send_image_file.py
Normal file
|
|
@ -0,0 +1,437 @@
|
|||
"""
|
||||
Tests for send_image_file() on Telegram, Discord, and Slack platforms,
|
||||
and MEDIA: .png extraction/routing in the base platform adapter.
|
||||
|
||||
Covers: local image file sending, file-not-found handling, fallback on error,
|
||||
MEDIA: tag extraction for image extensions, and routing to send_image_file.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.platforms.base import BasePlatformAdapter, SendResult
|
||||
|
||||
|
||||
def _run(coro):
|
||||
"""Run a coroutine in a fresh event loop for sync-style tests."""
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MEDIA: extraction tests for image files
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExtractMediaImages:
|
||||
"""Test that MEDIA: tags with image extensions are correctly extracted."""
|
||||
|
||||
def test_png_image_extracted(self):
|
||||
content = "Here is the screenshot:\nMEDIA:/home/user/.hermes/browser_screenshots/shot.png"
|
||||
media, cleaned = BasePlatformAdapter.extract_media(content)
|
||||
assert len(media) == 1
|
||||
assert media[0][0] == "/home/user/.hermes/browser_screenshots/shot.png"
|
||||
assert "MEDIA:" not in cleaned
|
||||
assert "Here is the screenshot" in cleaned
|
||||
|
||||
def test_jpg_image_extracted(self):
|
||||
content = "MEDIA:/tmp/photo.jpg"
|
||||
media, cleaned = BasePlatformAdapter.extract_media(content)
|
||||
assert len(media) == 1
|
||||
assert media[0][0] == "/tmp/photo.jpg"
|
||||
|
||||
def test_webp_image_extracted(self):
|
||||
content = "MEDIA:/tmp/image.webp"
|
||||
media, _ = BasePlatformAdapter.extract_media(content)
|
||||
assert len(media) == 1
|
||||
|
||||
def test_mixed_audio_and_image(self):
|
||||
content = "MEDIA:/audio.ogg\nMEDIA:/screenshot.png"
|
||||
media, _ = BasePlatformAdapter.extract_media(content)
|
||||
assert len(media) == 2
|
||||
paths = [m[0] for m in media]
|
||||
assert "/audio.ogg" in paths
|
||||
assert "/screenshot.png" in paths
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Telegram send_image_file tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _ensure_telegram_mock():
|
||||
"""Install mock telegram modules so TelegramAdapter can be imported."""
|
||||
if "telegram" in sys.modules and hasattr(sys.modules["telegram"], "__file__"):
|
||||
return
|
||||
|
||||
telegram_mod = MagicMock()
|
||||
telegram_mod.ext.ContextTypes.DEFAULT_TYPE = type(None)
|
||||
telegram_mod.constants.ParseMode.MARKDOWN_V2 = "MarkdownV2"
|
||||
telegram_mod.constants.ChatType.GROUP = "group"
|
||||
telegram_mod.constants.ChatType.SUPERGROUP = "supergroup"
|
||||
telegram_mod.constants.ChatType.CHANNEL = "channel"
|
||||
telegram_mod.constants.ChatType.PRIVATE = "private"
|
||||
|
||||
for name in ("telegram", "telegram.ext", "telegram.constants"):
|
||||
sys.modules.setdefault(name, telegram_mod)
|
||||
|
||||
|
||||
_ensure_telegram_mock()
|
||||
|
||||
from gateway.platforms.telegram import TelegramAdapter # noqa: E402
|
||||
|
||||
|
||||
class TestTelegramSendImageFile:
|
||||
@pytest.fixture
|
||||
def adapter(self):
|
||||
config = PlatformConfig(enabled=True, token="fake-token")
|
||||
a = TelegramAdapter(config)
|
||||
a._bot = MagicMock()
|
||||
return a
|
||||
|
||||
def test_sends_local_image_as_photo(self, adapter, tmp_path):
|
||||
"""send_image_file should call bot.send_photo with the opened file."""
|
||||
img = tmp_path / "screenshot.png"
|
||||
img.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 100) # Minimal PNG-like
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.message_id = 42
|
||||
adapter._bot.send_photo = AsyncMock(return_value=mock_msg)
|
||||
|
||||
result = _run(
|
||||
adapter.send_image_file(chat_id="12345", image_path=str(img))
|
||||
)
|
||||
assert result.success
|
||||
assert result.message_id == "42"
|
||||
adapter._bot.send_photo.assert_awaited_once()
|
||||
|
||||
# Verify photo arg was a file object (opened in rb mode)
|
||||
call_kwargs = adapter._bot.send_photo.call_args
|
||||
assert call_kwargs.kwargs["chat_id"] == 12345
|
||||
|
||||
def test_returns_error_when_file_missing(self, adapter):
|
||||
"""send_image_file should return error for nonexistent file."""
|
||||
result = _run(
|
||||
adapter.send_image_file(chat_id="12345", image_path="/nonexistent/image.png")
|
||||
)
|
||||
assert not result.success
|
||||
assert "not found" in result.error
|
||||
|
||||
def test_returns_error_when_not_connected(self, adapter):
|
||||
"""send_image_file should return error when bot is None."""
|
||||
adapter._bot = None
|
||||
result = _run(
|
||||
adapter.send_image_file(chat_id="12345", image_path="/tmp/img.png")
|
||||
)
|
||||
assert not result.success
|
||||
assert "Not connected" in result.error
|
||||
|
||||
def test_caption_truncated_to_1024(self, adapter, tmp_path):
|
||||
"""Telegram captions have a 1024 char limit."""
|
||||
img = tmp_path / "shot.png"
|
||||
img.write_bytes(b"\x89PNG" + b"\x00" * 50)
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.message_id = 1
|
||||
adapter._bot.send_photo = AsyncMock(return_value=mock_msg)
|
||||
|
||||
long_caption = "A" * 2000
|
||||
_run(
|
||||
adapter.send_image_file(chat_id="12345", image_path=str(img), caption=long_caption)
|
||||
)
|
||||
|
||||
call_kwargs = adapter._bot.send_photo.call_args.kwargs
|
||||
assert len(call_kwargs["caption"]) == 1024
|
||||
|
||||
def test_thread_id_forwarded(self, adapter, tmp_path):
|
||||
"""metadata thread_id is forwarded as message_thread_id (required for Telegram forum groups)."""
|
||||
img = tmp_path / "shot.png"
|
||||
img.write_bytes(b"\x89PNG" + b"\x00" * 50)
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.message_id = 43
|
||||
adapter._bot.send_photo = AsyncMock(return_value=mock_msg)
|
||||
|
||||
_run(
|
||||
adapter.send_image_file(
|
||||
chat_id="12345",
|
||||
image_path=str(img),
|
||||
metadata={"thread_id": "789"},
|
||||
)
|
||||
)
|
||||
|
||||
call_kwargs = adapter._bot.send_photo.call_args.kwargs
|
||||
assert call_kwargs["message_thread_id"] == 789
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Discord send_image_file tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _ensure_discord_mock():
|
||||
"""Install mock discord module so DiscordAdapter can be imported."""
|
||||
if "discord" in sys.modules and hasattr(sys.modules["discord"], "__file__"):
|
||||
return
|
||||
|
||||
discord_mod = MagicMock()
|
||||
discord_mod.Intents.default.return_value = MagicMock()
|
||||
discord_mod.Client = MagicMock
|
||||
discord_mod.File = MagicMock
|
||||
|
||||
for name in ("discord", "discord.ext", "discord.ext.commands"):
|
||||
sys.modules.setdefault(name, discord_mod)
|
||||
|
||||
|
||||
_ensure_discord_mock()
|
||||
|
||||
import discord as discord_mod_ref # noqa: E402
|
||||
from gateway.platforms.discord import DiscordAdapter # noqa: E402
|
||||
|
||||
|
||||
class TestDiscordSendImageFile:
|
||||
@pytest.fixture
|
||||
def adapter(self):
|
||||
config = PlatformConfig(enabled=True, token="fake-token")
|
||||
a = DiscordAdapter(config)
|
||||
a._client = MagicMock()
|
||||
return a
|
||||
|
||||
def test_sends_local_image_as_attachment(self, adapter, tmp_path):
|
||||
"""send_image_file should create discord.File and send to channel."""
|
||||
img = tmp_path / "screenshot.png"
|
||||
img.write_bytes(b"\x89PNG" + b"\x00" * 50)
|
||||
|
||||
mock_channel = MagicMock()
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.id = 99
|
||||
mock_channel.send = AsyncMock(return_value=mock_msg)
|
||||
adapter._client.get_channel = MagicMock(return_value=mock_channel)
|
||||
|
||||
result = _run(
|
||||
adapter.send_image_file(chat_id="67890", image_path=str(img))
|
||||
)
|
||||
assert result.success
|
||||
assert result.message_id == "99"
|
||||
mock_channel.send.assert_awaited_once()
|
||||
|
||||
def test_send_document_uploads_file_attachment(self, adapter, tmp_path):
|
||||
"""send_document should upload a native Discord attachment."""
|
||||
pdf = tmp_path / "sample.pdf"
|
||||
pdf.write_bytes(b"%PDF-1.4\n%\xe2\xe3\xcf\xd3\n")
|
||||
|
||||
mock_channel = MagicMock()
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.id = 100
|
||||
mock_channel.send = AsyncMock(return_value=mock_msg)
|
||||
adapter._client.get_channel = MagicMock(return_value=mock_channel)
|
||||
|
||||
with patch.object(discord_mod_ref, "File", MagicMock()) as file_cls:
|
||||
result = _run(
|
||||
adapter.send_document(
|
||||
chat_id="67890",
|
||||
file_path=str(pdf),
|
||||
file_name="renamed.pdf",
|
||||
metadata={"thread_id": "123"},
|
||||
)
|
||||
)
|
||||
|
||||
assert result.success
|
||||
assert result.message_id == "100"
|
||||
assert "file" in mock_channel.send.call_args.kwargs
|
||||
assert file_cls.call_args.kwargs["filename"] == "renamed.pdf"
|
||||
|
||||
def test_send_video_uploads_file_attachment(self, adapter, tmp_path):
|
||||
"""send_video should upload a native Discord attachment."""
|
||||
video = tmp_path / "clip.mp4"
|
||||
video.write_bytes(b"\x00\x00\x00\x18ftypmp42" + b"\x00" * 50)
|
||||
|
||||
mock_channel = MagicMock()
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.id = 101
|
||||
mock_channel.send = AsyncMock(return_value=mock_msg)
|
||||
adapter._client.get_channel = MagicMock(return_value=mock_channel)
|
||||
|
||||
with patch.object(discord_mod_ref, "File", MagicMock()) as file_cls:
|
||||
result = _run(
|
||||
adapter.send_video(
|
||||
chat_id="67890",
|
||||
video_path=str(video),
|
||||
metadata={"thread_id": "123"},
|
||||
)
|
||||
)
|
||||
|
||||
assert result.success
|
||||
assert result.message_id == "101"
|
||||
assert "file" in mock_channel.send.call_args.kwargs
|
||||
assert file_cls.call_args.kwargs["filename"] == "clip.mp4"
|
||||
|
||||
def test_returns_error_when_file_missing(self, adapter):
|
||||
result = _run(
|
||||
adapter.send_image_file(chat_id="67890", image_path="/nonexistent.png")
|
||||
)
|
||||
assert not result.success
|
||||
assert "not found" in result.error
|
||||
|
||||
def test_returns_error_when_not_connected(self, adapter):
|
||||
adapter._client = None
|
||||
result = _run(
|
||||
adapter.send_image_file(chat_id="67890", image_path="/tmp/img.png")
|
||||
)
|
||||
assert not result.success
|
||||
assert "Not connected" in result.error
|
||||
|
||||
def test_handles_missing_channel(self, adapter):
|
||||
adapter._client.get_channel = MagicMock(return_value=None)
|
||||
adapter._client.fetch_channel = AsyncMock(return_value=None)
|
||||
|
||||
result = _run(
|
||||
adapter.send_image_file(chat_id="99999", image_path="/tmp/img.png")
|
||||
)
|
||||
assert not result.success
|
||||
assert "not found" in result.error
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slack send_image_file tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _ensure_slack_mock():
|
||||
"""Install mock slack_bolt module so SlackAdapter can be imported."""
|
||||
if "slack_bolt" in sys.modules and hasattr(sys.modules["slack_bolt"], "__file__"):
|
||||
return
|
||||
|
||||
slack_mod = MagicMock()
|
||||
for name in ("slack_bolt", "slack_bolt.async_app", "slack_sdk", "slack_sdk.web.async_client"):
|
||||
sys.modules.setdefault(name, slack_mod)
|
||||
|
||||
|
||||
_ensure_slack_mock()
|
||||
|
||||
from gateway.platforms.slack import SlackAdapter # noqa: E402
|
||||
|
||||
|
||||
class TestSlackSendImageFile:
|
||||
@pytest.fixture
|
||||
def adapter(self):
|
||||
config = PlatformConfig(enabled=True, token="xoxb-fake")
|
||||
a = SlackAdapter(config)
|
||||
a._app = MagicMock()
|
||||
return a
|
||||
|
||||
def test_sends_local_image_via_upload(self, adapter, tmp_path):
|
||||
"""send_image_file should call files_upload_v2 with the local path."""
|
||||
img = tmp_path / "screenshot.png"
|
||||
img.write_bytes(b"\x89PNG" + b"\x00" * 50)
|
||||
|
||||
mock_result = MagicMock()
|
||||
adapter._app.client.files_upload_v2 = AsyncMock(return_value=mock_result)
|
||||
|
||||
result = _run(
|
||||
adapter.send_image_file(chat_id="C12345", image_path=str(img))
|
||||
)
|
||||
assert result.success
|
||||
adapter._app.client.files_upload_v2.assert_awaited_once()
|
||||
|
||||
call_kwargs = adapter._app.client.files_upload_v2.call_args.kwargs
|
||||
assert call_kwargs["file"] == str(img)
|
||||
assert call_kwargs["filename"] == "screenshot.png"
|
||||
assert call_kwargs["channel"] == "C12345"
|
||||
|
||||
def test_returns_error_when_file_missing(self, adapter):
|
||||
result = _run(
|
||||
adapter.send_image_file(chat_id="C12345", image_path="/nonexistent.png")
|
||||
)
|
||||
assert not result.success
|
||||
assert "not found" in result.error
|
||||
|
||||
def test_returns_error_when_not_connected(self, adapter):
|
||||
adapter._app = None
|
||||
result = _run(
|
||||
adapter.send_image_file(chat_id="C12345", image_path="/tmp/img.png")
|
||||
)
|
||||
assert not result.success
|
||||
assert "Not connected" in result.error
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# browser_vision screenshot cleanup tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestScreenshotCleanup:
|
||||
def test_cleanup_removes_old_screenshots(self, tmp_path):
|
||||
"""_cleanup_old_screenshots should remove files older than max_age_hours."""
|
||||
import time
|
||||
from tools.browser_tool import _cleanup_old_screenshots, _last_screenshot_cleanup_by_dir
|
||||
|
||||
_last_screenshot_cleanup_by_dir.clear()
|
||||
|
||||
# Create a "fresh" file
|
||||
fresh = tmp_path / "browser_screenshot_fresh.png"
|
||||
fresh.write_bytes(b"new")
|
||||
|
||||
# Create an "old" file and backdate its mtime
|
||||
old = tmp_path / "browser_screenshot_old.png"
|
||||
old.write_bytes(b"old")
|
||||
old_time = time.time() - (25 * 3600) # 25 hours ago
|
||||
os.utime(str(old), (old_time, old_time))
|
||||
|
||||
_cleanup_old_screenshots(tmp_path, max_age_hours=24)
|
||||
|
||||
assert fresh.exists(), "Fresh screenshot should not be removed"
|
||||
assert not old.exists(), "Old screenshot should be removed"
|
||||
|
||||
def test_cleanup_is_throttled_per_directory(self, tmp_path):
|
||||
import time
|
||||
from tools.browser_tool import _cleanup_old_screenshots, _last_screenshot_cleanup_by_dir
|
||||
|
||||
_last_screenshot_cleanup_by_dir.clear()
|
||||
|
||||
old = tmp_path / "browser_screenshot_old.png"
|
||||
old.write_bytes(b"old")
|
||||
old_time = time.time() - (25 * 3600)
|
||||
os.utime(str(old), (old_time, old_time))
|
||||
|
||||
_cleanup_old_screenshots(tmp_path, max_age_hours=24)
|
||||
assert not old.exists()
|
||||
|
||||
old.write_bytes(b"old-again")
|
||||
os.utime(str(old), (old_time, old_time))
|
||||
_cleanup_old_screenshots(tmp_path, max_age_hours=24)
|
||||
|
||||
assert old.exists(), "Repeated cleanup should be skipped while throttled"
|
||||
|
||||
def test_cleanup_ignores_non_screenshot_files(self, tmp_path):
|
||||
"""Only files matching browser_screenshot_*.png should be cleaned."""
|
||||
import time
|
||||
from tools.browser_tool import _cleanup_old_screenshots, _last_screenshot_cleanup_by_dir
|
||||
|
||||
_last_screenshot_cleanup_by_dir.clear()
|
||||
|
||||
other_file = tmp_path / "important_data.txt"
|
||||
other_file.write_bytes(b"keep me")
|
||||
old_time = time.time() - (48 * 3600)
|
||||
os.utime(str(other_file), (old_time, old_time))
|
||||
|
||||
_cleanup_old_screenshots(tmp_path, max_age_hours=24)
|
||||
|
||||
assert other_file.exists(), "Non-screenshot files should not be touched"
|
||||
|
||||
def test_cleanup_handles_empty_dir(self, tmp_path):
|
||||
"""Cleanup should not fail on empty directory."""
|
||||
from tools.browser_tool import _cleanup_old_screenshots, _last_screenshot_cleanup_by_dir
|
||||
_last_screenshot_cleanup_by_dir.clear()
|
||||
_cleanup_old_screenshots(tmp_path, max_age_hours=24) # Should not raise
|
||||
|
||||
def test_cleanup_handles_nonexistent_dir(self):
|
||||
"""Cleanup should not fail if directory doesn't exist."""
|
||||
from pathlib import Path
|
||||
from tools.browser_tool import _cleanup_old_screenshots, _last_screenshot_cleanup_by_dir
|
||||
_last_screenshot_cleanup_by_dir.clear()
|
||||
_cleanup_old_screenshots(Path("/nonexistent/dir"), max_age_hours=24) # Should not raise
|
||||
767
hermes_code/tests/gateway/test_session.py
Normal file
767
hermes_code/tests/gateway/test_session.py
Normal file
|
|
@ -0,0 +1,767 @@
|
|||
"""Tests for gateway session management."""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock
|
||||
from gateway.config import Platform, HomeChannel, GatewayConfig, PlatformConfig
|
||||
from gateway.session import (
|
||||
SessionSource,
|
||||
SessionStore,
|
||||
build_session_context,
|
||||
build_session_context_prompt,
|
||||
build_session_key,
|
||||
)
|
||||
|
||||
|
||||
class TestSessionSourceRoundtrip:
|
||||
def test_full_roundtrip(self):
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="12345",
|
||||
chat_name="My Group",
|
||||
chat_type="group",
|
||||
user_id="99",
|
||||
user_name="alice",
|
||||
thread_id="t1",
|
||||
)
|
||||
d = source.to_dict()
|
||||
restored = SessionSource.from_dict(d)
|
||||
|
||||
assert restored.platform == Platform.TELEGRAM
|
||||
assert restored.chat_id == "12345"
|
||||
assert restored.chat_name == "My Group"
|
||||
assert restored.chat_type == "group"
|
||||
assert restored.user_id == "99"
|
||||
assert restored.user_name == "alice"
|
||||
assert restored.thread_id == "t1"
|
||||
|
||||
def test_full_roundtrip_with_chat_topic(self):
|
||||
"""chat_topic should survive to_dict/from_dict roundtrip."""
|
||||
source = SessionSource(
|
||||
platform=Platform.DISCORD,
|
||||
chat_id="789",
|
||||
chat_name="Server / #project-planning",
|
||||
chat_type="group",
|
||||
user_id="42",
|
||||
user_name="bob",
|
||||
chat_topic="Planning and coordination for Project X",
|
||||
)
|
||||
d = source.to_dict()
|
||||
assert d["chat_topic"] == "Planning and coordination for Project X"
|
||||
|
||||
restored = SessionSource.from_dict(d)
|
||||
assert restored.chat_topic == "Planning and coordination for Project X"
|
||||
assert restored.chat_name == "Server / #project-planning"
|
||||
|
||||
def test_minimal_roundtrip(self):
|
||||
source = SessionSource(platform=Platform.LOCAL, chat_id="cli")
|
||||
d = source.to_dict()
|
||||
restored = SessionSource.from_dict(d)
|
||||
assert restored.platform == Platform.LOCAL
|
||||
assert restored.chat_id == "cli"
|
||||
assert restored.chat_type == "dm" # default value preserved
|
||||
|
||||
def test_chat_id_coerced_to_string(self):
|
||||
"""from_dict should handle numeric chat_id (common from Telegram)."""
|
||||
restored = SessionSource.from_dict({
|
||||
"platform": "telegram",
|
||||
"chat_id": 12345,
|
||||
})
|
||||
assert restored.chat_id == "12345"
|
||||
assert isinstance(restored.chat_id, str)
|
||||
|
||||
def test_missing_optional_fields(self):
|
||||
restored = SessionSource.from_dict({
|
||||
"platform": "discord",
|
||||
"chat_id": "abc",
|
||||
})
|
||||
assert restored.chat_name is None
|
||||
assert restored.user_id is None
|
||||
assert restored.user_name is None
|
||||
assert restored.thread_id is None
|
||||
assert restored.chat_topic is None
|
||||
assert restored.chat_type == "dm"
|
||||
|
||||
def test_invalid_platform_raises(self):
|
||||
with pytest.raises((ValueError, KeyError)):
|
||||
SessionSource.from_dict({"platform": "nonexistent", "chat_id": "1"})
|
||||
|
||||
|
||||
class TestSessionSourceDescription:
|
||||
def test_local_cli(self):
|
||||
source = SessionSource.local_cli()
|
||||
assert source.description == "CLI terminal"
|
||||
|
||||
def test_dm_with_username(self):
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM, chat_id="123",
|
||||
chat_type="dm", user_name="bob",
|
||||
)
|
||||
assert "DM" in source.description
|
||||
assert "bob" in source.description
|
||||
|
||||
def test_dm_without_username_falls_back_to_user_id(self):
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM, chat_id="123",
|
||||
chat_type="dm", user_id="456",
|
||||
)
|
||||
assert "456" in source.description
|
||||
|
||||
def test_group_shows_chat_name(self):
|
||||
source = SessionSource(
|
||||
platform=Platform.DISCORD, chat_id="789",
|
||||
chat_type="group", chat_name="Dev Chat",
|
||||
)
|
||||
assert "group" in source.description
|
||||
assert "Dev Chat" in source.description
|
||||
|
||||
def test_channel_type(self):
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM, chat_id="100",
|
||||
chat_type="channel", chat_name="Announcements",
|
||||
)
|
||||
assert "channel" in source.description
|
||||
assert "Announcements" in source.description
|
||||
|
||||
def test_thread_id_appended(self):
|
||||
source = SessionSource(
|
||||
platform=Platform.DISCORD, chat_id="789",
|
||||
chat_type="group", chat_name="General",
|
||||
thread_id="thread-42",
|
||||
)
|
||||
assert "thread" in source.description
|
||||
assert "thread-42" in source.description
|
||||
|
||||
def test_unknown_chat_type_uses_name(self):
|
||||
source = SessionSource(
|
||||
platform=Platform.SLACK, chat_id="C01",
|
||||
chat_type="forum", chat_name="Questions",
|
||||
)
|
||||
assert "Questions" in source.description
|
||||
|
||||
|
||||
class TestLocalCliFactory:
|
||||
def test_local_cli_defaults(self):
|
||||
source = SessionSource.local_cli()
|
||||
assert source.platform == Platform.LOCAL
|
||||
assert source.chat_id == "cli"
|
||||
assert source.chat_type == "dm"
|
||||
assert source.chat_name == "CLI terminal"
|
||||
|
||||
|
||||
class TestBuildSessionContextPrompt:
|
||||
def test_telegram_prompt_contains_platform_and_chat(self):
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.TELEGRAM: PlatformConfig(
|
||||
enabled=True,
|
||||
token="fake-token",
|
||||
home_channel=HomeChannel(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="111",
|
||||
name="Home Chat",
|
||||
),
|
||||
),
|
||||
},
|
||||
)
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="111",
|
||||
chat_name="Home Chat",
|
||||
chat_type="dm",
|
||||
)
|
||||
ctx = build_session_context(source, config)
|
||||
prompt = build_session_context_prompt(ctx)
|
||||
|
||||
assert "Telegram" in prompt
|
||||
assert "Home Chat" in prompt
|
||||
|
||||
def test_discord_prompt(self):
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.DISCORD: PlatformConfig(
|
||||
enabled=True,
|
||||
token="fake-d...oken",
|
||||
),
|
||||
},
|
||||
)
|
||||
source = SessionSource(
|
||||
platform=Platform.DISCORD,
|
||||
chat_id="guild-123",
|
||||
chat_name="Server",
|
||||
chat_type="group",
|
||||
user_name="alice",
|
||||
)
|
||||
ctx = build_session_context(source, config)
|
||||
prompt = build_session_context_prompt(ctx)
|
||||
|
||||
assert "Discord" in prompt
|
||||
assert "cannot search" in prompt.lower() or "do not have access" in prompt.lower()
|
||||
|
||||
def test_slack_prompt_includes_platform_notes(self):
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.SLACK: PlatformConfig(enabled=True, token="fake"),
|
||||
},
|
||||
)
|
||||
source = SessionSource(
|
||||
platform=Platform.SLACK,
|
||||
chat_id="C123",
|
||||
chat_name="general",
|
||||
chat_type="group",
|
||||
user_name="bob",
|
||||
)
|
||||
ctx = build_session_context(source, config)
|
||||
prompt = build_session_context_prompt(ctx)
|
||||
|
||||
assert "Slack" in prompt
|
||||
assert "cannot search" in prompt.lower()
|
||||
assert "pin" in prompt.lower()
|
||||
|
||||
def test_discord_prompt_with_channel_topic(self):
|
||||
"""Channel topic should appear in the session context prompt."""
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.DISCORD: PlatformConfig(
|
||||
enabled=True,
|
||||
token="fake-discord-token",
|
||||
),
|
||||
},
|
||||
)
|
||||
source = SessionSource(
|
||||
platform=Platform.DISCORD,
|
||||
chat_id="guild-123",
|
||||
chat_name="Server / #project-planning",
|
||||
chat_type="group",
|
||||
user_name="alice",
|
||||
chat_topic="Planning and coordination for Project X",
|
||||
)
|
||||
ctx = build_session_context(source, config)
|
||||
prompt = build_session_context_prompt(ctx)
|
||||
|
||||
assert "Discord" in prompt
|
||||
assert "**Channel Topic:** Planning and coordination for Project X" in prompt
|
||||
|
||||
def test_prompt_omits_channel_topic_when_none(self):
|
||||
"""Channel Topic line should NOT appear when chat_topic is None."""
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.DISCORD: PlatformConfig(
|
||||
enabled=True,
|
||||
token="fake-discord-token",
|
||||
),
|
||||
},
|
||||
)
|
||||
source = SessionSource(
|
||||
platform=Platform.DISCORD,
|
||||
chat_id="guild-123",
|
||||
chat_name="Server / #general",
|
||||
chat_type="group",
|
||||
user_name="alice",
|
||||
)
|
||||
ctx = build_session_context(source, config)
|
||||
prompt = build_session_context_prompt(ctx)
|
||||
|
||||
assert "Channel Topic" not in prompt
|
||||
|
||||
def test_local_prompt_mentions_machine(self):
|
||||
config = GatewayConfig()
|
||||
source = SessionSource.local_cli()
|
||||
ctx = build_session_context(source, config)
|
||||
prompt = build_session_context_prompt(ctx)
|
||||
|
||||
assert "Local" in prompt
|
||||
assert "machine running this agent" in prompt
|
||||
|
||||
def test_whatsapp_prompt(self):
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.WHATSAPP: PlatformConfig(enabled=True, token=""),
|
||||
},
|
||||
)
|
||||
source = SessionSource(
|
||||
platform=Platform.WHATSAPP,
|
||||
chat_id="15551234567@s.whatsapp.net",
|
||||
chat_type="dm",
|
||||
user_name="Phone User",
|
||||
)
|
||||
ctx = build_session_context(source, config)
|
||||
prompt = build_session_context_prompt(ctx)
|
||||
|
||||
assert "WhatsApp" in prompt or "whatsapp" in prompt.lower()
|
||||
|
||||
|
||||
class TestSessionStoreRewriteTranscript:
|
||||
"""Regression: /retry and /undo must persist truncated history to disk."""
|
||||
|
||||
@pytest.fixture()
|
||||
def store(self, tmp_path):
|
||||
config = GatewayConfig()
|
||||
with patch("gateway.session.SessionStore._ensure_loaded"):
|
||||
s = SessionStore(sessions_dir=tmp_path, config=config)
|
||||
s._db = None # no SQLite for these tests
|
||||
s._loaded = True
|
||||
return s
|
||||
|
||||
def test_rewrite_replaces_jsonl(self, store, tmp_path):
|
||||
session_id = "test_session_1"
|
||||
# Write initial transcript
|
||||
for msg in [
|
||||
{"role": "user", "content": "hello"},
|
||||
{"role": "assistant", "content": "hi"},
|
||||
{"role": "user", "content": "undo this"},
|
||||
{"role": "assistant", "content": "ok"},
|
||||
]:
|
||||
store.append_to_transcript(session_id, msg)
|
||||
|
||||
# Rewrite with truncated history
|
||||
store.rewrite_transcript(session_id, [
|
||||
{"role": "user", "content": "hello"},
|
||||
{"role": "assistant", "content": "hi"},
|
||||
])
|
||||
|
||||
reloaded = store.load_transcript(session_id)
|
||||
assert len(reloaded) == 2
|
||||
assert reloaded[0]["content"] == "hello"
|
||||
assert reloaded[1]["content"] == "hi"
|
||||
|
||||
def test_rewrite_with_empty_list(self, store):
|
||||
session_id = "test_session_2"
|
||||
store.append_to_transcript(session_id, {"role": "user", "content": "hi"})
|
||||
|
||||
store.rewrite_transcript(session_id, [])
|
||||
|
||||
reloaded = store.load_transcript(session_id)
|
||||
assert reloaded == []
|
||||
|
||||
|
||||
class TestLoadTranscriptCorruptLines:
|
||||
"""Regression: corrupt JSONL lines (e.g. from mid-write crash) must be
|
||||
skipped instead of crashing the entire transcript load. GH-1193."""
|
||||
|
||||
@pytest.fixture()
|
||||
def store(self, tmp_path):
|
||||
config = GatewayConfig()
|
||||
with patch("gateway.session.SessionStore._ensure_loaded"):
|
||||
s = SessionStore(sessions_dir=tmp_path, config=config)
|
||||
s._db = None
|
||||
s._loaded = True
|
||||
return s
|
||||
|
||||
def test_corrupt_line_skipped(self, store, tmp_path):
|
||||
session_id = "corrupt_test"
|
||||
transcript_path = store.get_transcript_path(session_id)
|
||||
transcript_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(transcript_path, "w") as f:
|
||||
f.write('{"role": "user", "content": "hello"}\n')
|
||||
f.write('{"role": "assistant", "content": "hi th') # truncated
|
||||
f.write("\n")
|
||||
f.write('{"role": "user", "content": "goodbye"}\n')
|
||||
|
||||
messages = store.load_transcript(session_id)
|
||||
assert len(messages) == 2
|
||||
assert messages[0]["content"] == "hello"
|
||||
assert messages[1]["content"] == "goodbye"
|
||||
|
||||
def test_all_lines_corrupt_returns_empty(self, store, tmp_path):
|
||||
session_id = "all_corrupt"
|
||||
transcript_path = store.get_transcript_path(session_id)
|
||||
transcript_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(transcript_path, "w") as f:
|
||||
f.write("not json at all\n")
|
||||
f.write("{truncated\n")
|
||||
|
||||
messages = store.load_transcript(session_id)
|
||||
assert messages == []
|
||||
|
||||
def test_valid_transcript_unaffected(self, store, tmp_path):
|
||||
session_id = "valid_test"
|
||||
store.append_to_transcript(session_id, {"role": "user", "content": "a"})
|
||||
store.append_to_transcript(session_id, {"role": "assistant", "content": "b"})
|
||||
|
||||
messages = store.load_transcript(session_id)
|
||||
assert len(messages) == 2
|
||||
assert messages[0]["content"] == "a"
|
||||
assert messages[1]["content"] == "b"
|
||||
|
||||
|
||||
class TestWhatsAppDMSessionKeyConsistency:
|
||||
"""Regression: all session-key construction must go through build_session_key
|
||||
so DMs are isolated by chat_id across platforms."""
|
||||
|
||||
@pytest.fixture()
|
||||
def store(self, tmp_path):
|
||||
config = GatewayConfig()
|
||||
with patch("gateway.session.SessionStore._ensure_loaded"):
|
||||
s = SessionStore(sessions_dir=tmp_path, config=config)
|
||||
s._db = None
|
||||
s._loaded = True
|
||||
return s
|
||||
|
||||
def test_whatsapp_dm_includes_chat_id(self):
|
||||
source = SessionSource(
|
||||
platform=Platform.WHATSAPP,
|
||||
chat_id="15551234567@s.whatsapp.net",
|
||||
chat_type="dm",
|
||||
user_name="Phone User",
|
||||
)
|
||||
key = build_session_key(source)
|
||||
assert key == "agent:main:whatsapp:dm:15551234567@s.whatsapp.net"
|
||||
|
||||
def test_store_delegates_to_build_session_key(self, store):
|
||||
"""SessionStore._generate_session_key must produce the same result."""
|
||||
source = SessionSource(
|
||||
platform=Platform.WHATSAPP,
|
||||
chat_id="15551234567@s.whatsapp.net",
|
||||
chat_type="dm",
|
||||
user_name="Phone User",
|
||||
)
|
||||
assert store._generate_session_key(source) == build_session_key(source)
|
||||
|
||||
def test_store_creates_distinct_group_sessions_per_user(self, store):
|
||||
first = SessionSource(
|
||||
platform=Platform.DISCORD,
|
||||
chat_id="guild-123",
|
||||
chat_type="group",
|
||||
user_id="alice",
|
||||
user_name="Alice",
|
||||
)
|
||||
second = SessionSource(
|
||||
platform=Platform.DISCORD,
|
||||
chat_id="guild-123",
|
||||
chat_type="group",
|
||||
user_id="bob",
|
||||
user_name="Bob",
|
||||
)
|
||||
|
||||
first_entry = store.get_or_create_session(first)
|
||||
second_entry = store.get_or_create_session(second)
|
||||
|
||||
assert first_entry.session_key == "agent:main:discord:group:guild-123:alice"
|
||||
assert second_entry.session_key == "agent:main:discord:group:guild-123:bob"
|
||||
assert first_entry.session_id != second_entry.session_id
|
||||
|
||||
def test_store_shares_group_sessions_when_disabled_in_config(self, store):
|
||||
store.config.group_sessions_per_user = False
|
||||
|
||||
first = SessionSource(
|
||||
platform=Platform.DISCORD,
|
||||
chat_id="guild-123",
|
||||
chat_type="group",
|
||||
user_id="alice",
|
||||
user_name="Alice",
|
||||
)
|
||||
second = SessionSource(
|
||||
platform=Platform.DISCORD,
|
||||
chat_id="guild-123",
|
||||
chat_type="group",
|
||||
user_id="bob",
|
||||
user_name="Bob",
|
||||
)
|
||||
|
||||
first_entry = store.get_or_create_session(first)
|
||||
second_entry = store.get_or_create_session(second)
|
||||
|
||||
assert first_entry.session_key == "agent:main:discord:group:guild-123"
|
||||
assert second_entry.session_key == "agent:main:discord:group:guild-123"
|
||||
assert first_entry.session_id == second_entry.session_id
|
||||
|
||||
def test_telegram_dm_includes_chat_id(self):
|
||||
"""Non-WhatsApp DMs should also include chat_id to separate users."""
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="99",
|
||||
chat_type="dm",
|
||||
)
|
||||
key = build_session_key(source)
|
||||
assert key == "agent:main:telegram:dm:99"
|
||||
|
||||
def test_distinct_dm_chat_ids_get_distinct_session_keys(self):
|
||||
"""Different DM chats must not collapse into one shared session."""
|
||||
first = SessionSource(platform=Platform.TELEGRAM, chat_id="99", chat_type="dm")
|
||||
second = SessionSource(platform=Platform.TELEGRAM, chat_id="100", chat_type="dm")
|
||||
|
||||
assert build_session_key(first) == "agent:main:telegram:dm:99"
|
||||
assert build_session_key(second) == "agent:main:telegram:dm:100"
|
||||
assert build_session_key(first) != build_session_key(second)
|
||||
|
||||
def test_discord_group_includes_chat_id(self):
|
||||
"""Group/channel keys include chat_type and chat_id."""
|
||||
source = SessionSource(
|
||||
platform=Platform.DISCORD,
|
||||
chat_id="guild-123",
|
||||
chat_type="group",
|
||||
)
|
||||
key = build_session_key(source)
|
||||
assert key == "agent:main:discord:group:guild-123"
|
||||
|
||||
def test_group_sessions_are_isolated_per_user_when_user_id_present(self):
|
||||
first = SessionSource(
|
||||
platform=Platform.DISCORD,
|
||||
chat_id="guild-123",
|
||||
chat_type="group",
|
||||
user_id="alice",
|
||||
)
|
||||
second = SessionSource(
|
||||
platform=Platform.DISCORD,
|
||||
chat_id="guild-123",
|
||||
chat_type="group",
|
||||
user_id="bob",
|
||||
)
|
||||
|
||||
assert build_session_key(first) == "agent:main:discord:group:guild-123:alice"
|
||||
assert build_session_key(second) == "agent:main:discord:group:guild-123:bob"
|
||||
assert build_session_key(first) != build_session_key(second)
|
||||
|
||||
def test_group_sessions_can_be_shared_when_isolation_disabled(self):
|
||||
first = SessionSource(
|
||||
platform=Platform.DISCORD,
|
||||
chat_id="guild-123",
|
||||
chat_type="group",
|
||||
user_id="alice",
|
||||
)
|
||||
second = SessionSource(
|
||||
platform=Platform.DISCORD,
|
||||
chat_id="guild-123",
|
||||
chat_type="group",
|
||||
user_id="bob",
|
||||
)
|
||||
|
||||
assert build_session_key(first, group_sessions_per_user=False) == "agent:main:discord:group:guild-123"
|
||||
assert build_session_key(second, group_sessions_per_user=False) == "agent:main:discord:group:guild-123"
|
||||
|
||||
def test_group_thread_includes_thread_id(self):
|
||||
"""Forum-style threads need a distinct session key within one group."""
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="-1002285219667",
|
||||
chat_type="group",
|
||||
thread_id="17585",
|
||||
)
|
||||
key = build_session_key(source)
|
||||
assert key == "agent:main:telegram:group:-1002285219667:17585"
|
||||
|
||||
def test_group_thread_sessions_are_isolated_per_user(self):
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="-1002285219667",
|
||||
chat_type="group",
|
||||
thread_id="17585",
|
||||
user_id="42",
|
||||
)
|
||||
key = build_session_key(source)
|
||||
assert key == "agent:main:telegram:group:-1002285219667:17585:42"
|
||||
|
||||
|
||||
class TestSessionStoreEntriesAttribute:
|
||||
"""Regression: /reset must access _entries, not _sessions."""
|
||||
|
||||
def test_entries_attribute_exists(self):
|
||||
config = GatewayConfig()
|
||||
with patch("gateway.session.SessionStore._ensure_loaded"):
|
||||
store = SessionStore(sessions_dir=Path("/tmp"), config=config)
|
||||
store._loaded = True
|
||||
assert hasattr(store, "_entries")
|
||||
assert not hasattr(store, "_sessions")
|
||||
|
||||
|
||||
class TestHasAnySessions:
|
||||
"""Tests for has_any_sessions() fix (issue #351)."""
|
||||
|
||||
@pytest.fixture
|
||||
def store_with_mock_db(self, tmp_path):
|
||||
"""SessionStore with a mocked database."""
|
||||
config = GatewayConfig()
|
||||
with patch("gateway.session.SessionStore._ensure_loaded"):
|
||||
s = SessionStore(sessions_dir=tmp_path, config=config)
|
||||
s._loaded = True
|
||||
s._entries = {}
|
||||
s._db = MagicMock()
|
||||
return s
|
||||
|
||||
def test_uses_database_count_when_available(self, store_with_mock_db):
|
||||
"""has_any_sessions should use database session_count, not len(_entries)."""
|
||||
store = store_with_mock_db
|
||||
# Simulate single-platform user with only 1 entry in memory
|
||||
store._entries = {"telegram:12345": MagicMock()}
|
||||
# But database has 3 sessions (current + 2 previous resets)
|
||||
store._db.session_count.return_value = 3
|
||||
|
||||
assert store.has_any_sessions() is True
|
||||
store._db.session_count.assert_called_once()
|
||||
|
||||
def test_first_session_ever_returns_false(self, store_with_mock_db):
|
||||
"""First session ever should return False (only current session in DB)."""
|
||||
store = store_with_mock_db
|
||||
store._entries = {"telegram:12345": MagicMock()}
|
||||
# Database has exactly 1 session (the current one just created)
|
||||
store._db.session_count.return_value = 1
|
||||
|
||||
assert store.has_any_sessions() is False
|
||||
|
||||
def test_fallback_without_database(self, tmp_path):
|
||||
"""Should fall back to len(_entries) when DB is not available."""
|
||||
config = GatewayConfig()
|
||||
with patch("gateway.session.SessionStore._ensure_loaded"):
|
||||
store = SessionStore(sessions_dir=tmp_path, config=config)
|
||||
store._loaded = True
|
||||
store._db = None
|
||||
store._entries = {"key1": MagicMock(), "key2": MagicMock()}
|
||||
|
||||
# > 1 entries means has sessions
|
||||
assert store.has_any_sessions() is True
|
||||
|
||||
store._entries = {"key1": MagicMock()}
|
||||
assert store.has_any_sessions() is False
|
||||
|
||||
|
||||
class TestLastPromptTokens:
|
||||
"""Tests for the last_prompt_tokens field — actual API token tracking."""
|
||||
|
||||
def test_session_entry_default(self):
|
||||
"""New sessions should have last_prompt_tokens=0."""
|
||||
from gateway.session import SessionEntry
|
||||
from datetime import datetime
|
||||
entry = SessionEntry(
|
||||
session_key="test",
|
||||
session_id="s1",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
)
|
||||
assert entry.last_prompt_tokens == 0
|
||||
|
||||
def test_session_entry_roundtrip(self):
|
||||
"""last_prompt_tokens should survive serialization/deserialization."""
|
||||
from gateway.session import SessionEntry
|
||||
from datetime import datetime
|
||||
entry = SessionEntry(
|
||||
session_key="test",
|
||||
session_id="s1",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
last_prompt_tokens=42000,
|
||||
)
|
||||
d = entry.to_dict()
|
||||
assert d["last_prompt_tokens"] == 42000
|
||||
restored = SessionEntry.from_dict(d)
|
||||
assert restored.last_prompt_tokens == 42000
|
||||
|
||||
def test_session_entry_from_old_data(self):
|
||||
"""Old session data without last_prompt_tokens should default to 0."""
|
||||
from gateway.session import SessionEntry
|
||||
data = {
|
||||
"session_key": "test",
|
||||
"session_id": "s1",
|
||||
"created_at": "2025-01-01T00:00:00",
|
||||
"updated_at": "2025-01-01T00:00:00",
|
||||
"input_tokens": 100,
|
||||
"output_tokens": 50,
|
||||
"total_tokens": 150,
|
||||
# No last_prompt_tokens — old format
|
||||
}
|
||||
entry = SessionEntry.from_dict(data)
|
||||
assert entry.last_prompt_tokens == 0
|
||||
|
||||
def test_update_session_sets_last_prompt_tokens(self, tmp_path):
|
||||
"""update_session should store the actual prompt token count."""
|
||||
config = GatewayConfig()
|
||||
with patch("gateway.session.SessionStore._ensure_loaded"):
|
||||
store = SessionStore(sessions_dir=tmp_path, config=config)
|
||||
store._loaded = True
|
||||
store._db = None
|
||||
store._save = MagicMock()
|
||||
|
||||
from gateway.session import SessionEntry
|
||||
from datetime import datetime
|
||||
entry = SessionEntry(
|
||||
session_key="k1",
|
||||
session_id="s1",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
)
|
||||
store._entries = {"k1": entry}
|
||||
|
||||
store.update_session("k1", last_prompt_tokens=85000)
|
||||
assert entry.last_prompt_tokens == 85000
|
||||
|
||||
def test_update_session_none_does_not_change(self, tmp_path):
|
||||
"""update_session with default (None) should not change last_prompt_tokens."""
|
||||
config = GatewayConfig()
|
||||
with patch("gateway.session.SessionStore._ensure_loaded"):
|
||||
store = SessionStore(sessions_dir=tmp_path, config=config)
|
||||
store._loaded = True
|
||||
store._db = None
|
||||
store._save = MagicMock()
|
||||
|
||||
from gateway.session import SessionEntry
|
||||
from datetime import datetime
|
||||
entry = SessionEntry(
|
||||
session_key="k1",
|
||||
session_id="s1",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
last_prompt_tokens=50000,
|
||||
)
|
||||
store._entries = {"k1": entry}
|
||||
|
||||
store.update_session("k1") # No last_prompt_tokens arg
|
||||
assert entry.last_prompt_tokens == 50000 # unchanged
|
||||
|
||||
def test_update_session_zero_resets(self, tmp_path):
|
||||
"""update_session with last_prompt_tokens=0 should reset the field."""
|
||||
config = GatewayConfig()
|
||||
with patch("gateway.session.SessionStore._ensure_loaded"):
|
||||
store = SessionStore(sessions_dir=tmp_path, config=config)
|
||||
store._loaded = True
|
||||
store._db = None
|
||||
store._save = MagicMock()
|
||||
|
||||
from gateway.session import SessionEntry
|
||||
from datetime import datetime
|
||||
entry = SessionEntry(
|
||||
session_key="k1",
|
||||
session_id="s1",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
last_prompt_tokens=85000,
|
||||
)
|
||||
store._entries = {"k1": entry}
|
||||
|
||||
store.update_session("k1", last_prompt_tokens=0)
|
||||
assert entry.last_prompt_tokens == 0
|
||||
|
||||
def test_update_session_passes_model_to_db(self, tmp_path):
|
||||
"""Gateway session updates should forward the resolved model to SQLite."""
|
||||
config = GatewayConfig()
|
||||
with patch("gateway.session.SessionStore._ensure_loaded"):
|
||||
store = SessionStore(sessions_dir=tmp_path, config=config)
|
||||
store._loaded = True
|
||||
store._save = MagicMock()
|
||||
store._db = MagicMock()
|
||||
|
||||
from gateway.session import SessionEntry
|
||||
from datetime import datetime
|
||||
entry = SessionEntry(
|
||||
session_key="k1",
|
||||
session_id="s1",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
)
|
||||
store._entries = {"k1": entry}
|
||||
|
||||
store.update_session("k1", model="openai/gpt-5.4")
|
||||
|
||||
store._db.update_token_counts.assert_called_once_with(
|
||||
"s1",
|
||||
input_tokens=0,
|
||||
output_tokens=0,
|
||||
cache_read_tokens=0,
|
||||
cache_write_tokens=0,
|
||||
estimated_cost_usd=None,
|
||||
cost_status=None,
|
||||
cost_source=None,
|
||||
billing_provider=None,
|
||||
billing_base_url=None,
|
||||
model="openai/gpt-5.4",
|
||||
)
|
||||
45
hermes_code/tests/gateway/test_session_env.py
Normal file
45
hermes_code/tests/gateway/test_session_env.py
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
import os
|
||||
|
||||
from gateway.config import Platform
|
||||
from gateway.run import GatewayRunner
|
||||
from gateway.session import SessionContext, SessionSource
|
||||
|
||||
|
||||
def test_set_session_env_includes_thread_id(monkeypatch):
|
||||
runner = object.__new__(GatewayRunner)
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="-1001",
|
||||
chat_name="Group",
|
||||
chat_type="group",
|
||||
thread_id="17585",
|
||||
)
|
||||
context = SessionContext(source=source, connected_platforms=[], home_channels={})
|
||||
|
||||
monkeypatch.delenv("HERMES_SESSION_PLATFORM", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_CHAT_ID", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_CHAT_NAME", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_THREAD_ID", raising=False)
|
||||
|
||||
runner._set_session_env(context)
|
||||
|
||||
assert os.getenv("HERMES_SESSION_PLATFORM") == "telegram"
|
||||
assert os.getenv("HERMES_SESSION_CHAT_ID") == "-1001"
|
||||
assert os.getenv("HERMES_SESSION_CHAT_NAME") == "Group"
|
||||
assert os.getenv("HERMES_SESSION_THREAD_ID") == "17585"
|
||||
|
||||
|
||||
def test_clear_session_env_removes_thread_id(monkeypatch):
|
||||
runner = object.__new__(GatewayRunner)
|
||||
|
||||
monkeypatch.setenv("HERMES_SESSION_PLATFORM", "telegram")
|
||||
monkeypatch.setenv("HERMES_SESSION_CHAT_ID", "-1001")
|
||||
monkeypatch.setenv("HERMES_SESSION_CHAT_NAME", "Group")
|
||||
monkeypatch.setenv("HERMES_SESSION_THREAD_ID", "17585")
|
||||
|
||||
runner._clear_session_env()
|
||||
|
||||
assert os.getenv("HERMES_SESSION_PLATFORM") is None
|
||||
assert os.getenv("HERMES_SESSION_CHAT_ID") is None
|
||||
assert os.getenv("HERMES_SESSION_CHAT_NAME") is None
|
||||
assert os.getenv("HERMES_SESSION_THREAD_ID") is None
|
||||
383
hermes_code/tests/gateway/test_session_hygiene.py
Normal file
383
hermes_code/tests/gateway/test_session_hygiene.py
Normal file
|
|
@ -0,0 +1,383 @@
|
|||
"""Tests for gateway session hygiene — auto-compression of large sessions.
|
||||
|
||||
Verifies that the gateway detects pathologically large transcripts and
|
||||
triggers auto-compression before running the agent. (#628)
|
||||
|
||||
The hygiene system uses the SAME compression config as the agent:
|
||||
compression.threshold × model context length
|
||||
so CLI and messaging platforms behave identically.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import sys
|
||||
import types
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from agent.model_metadata import estimate_messages_tokens_rough
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, SendResult
|
||||
from gateway.session import SessionEntry, SessionSource
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_history(n_messages: int, content_size: int = 100) -> list:
|
||||
"""Build a fake transcript with n_messages user/assistant pairs."""
|
||||
history = []
|
||||
content = "x" * content_size
|
||||
for i in range(n_messages):
|
||||
role = "user" if i % 2 == 0 else "assistant"
|
||||
history.append({"role": role, "content": content, "timestamp": f"t{i}"})
|
||||
return history
|
||||
|
||||
|
||||
def _make_large_history_tokens(target_tokens: int) -> list:
|
||||
"""Build a history that estimates to roughly target_tokens tokens."""
|
||||
# estimate_messages_tokens_rough counts total chars in str(msg) // 4
|
||||
# Each msg dict has ~60 chars of overhead + content chars
|
||||
# So for N tokens we need roughly N * 4 total chars across all messages
|
||||
target_chars = target_tokens * 4
|
||||
# Each message as a dict string is roughly len(content) + 60 chars
|
||||
msg_overhead = 60
|
||||
# Use 50 messages with appropriately sized content
|
||||
n_msgs = 50
|
||||
content_size = max(10, (target_chars // n_msgs) - msg_overhead)
|
||||
return _make_history(n_msgs, content_size=content_size)
|
||||
|
||||
|
||||
class HygieneCaptureAdapter(BasePlatformAdapter):
|
||||
def __init__(self):
|
||||
super().__init__(PlatformConfig(enabled=True, token="fake-token"), Platform.TELEGRAM)
|
||||
self.sent = []
|
||||
|
||||
async def connect(self) -> bool:
|
||||
return True
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
return None
|
||||
|
||||
async def send(self, chat_id, content, reply_to=None, metadata=None) -> SendResult:
|
||||
self.sent.append(
|
||||
{
|
||||
"chat_id": chat_id,
|
||||
"content": content,
|
||||
"reply_to": reply_to,
|
||||
"metadata": metadata,
|
||||
}
|
||||
)
|
||||
return SendResult(success=True, message_id="hygiene-1")
|
||||
|
||||
async def get_chat_info(self, chat_id: str):
|
||||
return {"id": chat_id}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Detection threshold tests (model-aware, unified with compression config)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSessionHygieneThresholds:
|
||||
"""Test that the threshold logic correctly identifies large sessions.
|
||||
|
||||
Thresholds are derived from model context length × compression threshold,
|
||||
matching what the agent's ContextCompressor uses.
|
||||
"""
|
||||
|
||||
def test_small_session_below_thresholds(self):
|
||||
"""A 10-message session should not trigger compression."""
|
||||
history = _make_history(10)
|
||||
approx_tokens = estimate_messages_tokens_rough(history)
|
||||
|
||||
# For a 200k-context model at 85% threshold = 170k
|
||||
context_length = 200_000
|
||||
threshold_pct = 0.85
|
||||
compress_token_threshold = int(context_length * threshold_pct)
|
||||
|
||||
needs_compress = approx_tokens >= compress_token_threshold
|
||||
assert not needs_compress
|
||||
|
||||
def test_large_token_count_triggers(self):
|
||||
"""High token count should trigger compression when exceeding model threshold."""
|
||||
# Build a history that exceeds 85% of a 200k model (170k tokens)
|
||||
history = _make_large_history_tokens(180_000)
|
||||
approx_tokens = estimate_messages_tokens_rough(history)
|
||||
|
||||
context_length = 200_000
|
||||
threshold_pct = 0.85
|
||||
compress_token_threshold = int(context_length * threshold_pct)
|
||||
|
||||
needs_compress = approx_tokens >= compress_token_threshold
|
||||
assert needs_compress
|
||||
|
||||
def test_under_threshold_no_trigger(self):
|
||||
"""Session under threshold should not trigger, even with many messages."""
|
||||
# 250 short messages — lots of messages but well under token threshold
|
||||
history = _make_history(250, content_size=10)
|
||||
approx_tokens = estimate_messages_tokens_rough(history)
|
||||
|
||||
# 200k model at 85% = 170k token threshold
|
||||
context_length = 200_000
|
||||
threshold_pct = 0.85
|
||||
compress_token_threshold = int(context_length * threshold_pct)
|
||||
|
||||
needs_compress = approx_tokens >= compress_token_threshold
|
||||
assert not needs_compress, (
|
||||
f"250 short messages (~{approx_tokens} tokens) should NOT trigger "
|
||||
f"compression at {compress_token_threshold} token threshold"
|
||||
)
|
||||
|
||||
def test_message_count_alone_does_not_trigger(self):
|
||||
"""Message count alone should NOT trigger — only token count matters.
|
||||
|
||||
The old system used an OR of token-count and message-count thresholds,
|
||||
which caused premature compression in tool-heavy sessions with 200+
|
||||
messages but low total tokens.
|
||||
"""
|
||||
# 300 very short messages — old system would compress, new should not
|
||||
history = _make_history(300, content_size=10)
|
||||
approx_tokens = estimate_messages_tokens_rough(history)
|
||||
|
||||
context_length = 200_000
|
||||
threshold_pct = 0.85
|
||||
compress_token_threshold = int(context_length * threshold_pct)
|
||||
|
||||
# Token-based check only
|
||||
needs_compress = approx_tokens >= compress_token_threshold
|
||||
assert not needs_compress
|
||||
|
||||
def test_threshold_scales_with_model(self):
|
||||
"""Different models should have different compression thresholds."""
|
||||
# 128k model at 85% = 108,800 tokens
|
||||
small_model_threshold = int(128_000 * 0.85)
|
||||
# 200k model at 85% = 170,000 tokens
|
||||
large_model_threshold = int(200_000 * 0.85)
|
||||
# 1M model at 85% = 850,000 tokens
|
||||
huge_model_threshold = int(1_000_000 * 0.85)
|
||||
|
||||
# A session at ~120k tokens:
|
||||
history = _make_large_history_tokens(120_000)
|
||||
approx_tokens = estimate_messages_tokens_rough(history)
|
||||
|
||||
# Should trigger for 128k model
|
||||
assert approx_tokens >= small_model_threshold
|
||||
# Should NOT trigger for 200k model
|
||||
assert approx_tokens < large_model_threshold
|
||||
# Should NOT trigger for 1M model
|
||||
assert approx_tokens < huge_model_threshold
|
||||
|
||||
def test_custom_threshold_percentage(self):
|
||||
"""Custom threshold percentage from config should be respected."""
|
||||
context_length = 200_000
|
||||
|
||||
# At 50% threshold = 100k
|
||||
low_threshold = int(context_length * 0.50)
|
||||
# At 90% threshold = 180k
|
||||
high_threshold = int(context_length * 0.90)
|
||||
|
||||
history = _make_large_history_tokens(150_000)
|
||||
approx_tokens = estimate_messages_tokens_rough(history)
|
||||
|
||||
# Should trigger at 50% but not at 90%
|
||||
assert approx_tokens >= low_threshold
|
||||
assert approx_tokens < high_threshold
|
||||
|
||||
def test_minimum_message_guard(self):
|
||||
"""Sessions with fewer than 4 messages should never trigger."""
|
||||
history = _make_history(3, content_size=100_000)
|
||||
# Even with enormous content, < 4 messages should be skipped
|
||||
# (the gateway code checks `len(history) >= 4` before evaluating)
|
||||
assert len(history) < 4
|
||||
|
||||
|
||||
class TestSessionHygieneWarnThreshold:
|
||||
"""Test the post-compression warning threshold (95% of context)."""
|
||||
|
||||
def test_warn_when_still_large(self):
|
||||
"""If compressed result is still above 95% of context, should warn."""
|
||||
context_length = 200_000
|
||||
warn_threshold = int(context_length * 0.95) # 190k
|
||||
post_compress_tokens = 195_000
|
||||
assert post_compress_tokens >= warn_threshold
|
||||
|
||||
def test_no_warn_when_under(self):
|
||||
"""If compressed result is under 95% of context, no warning."""
|
||||
context_length = 200_000
|
||||
warn_threshold = int(context_length * 0.95) # 190k
|
||||
post_compress_tokens = 150_000
|
||||
assert post_compress_tokens < warn_threshold
|
||||
|
||||
|
||||
class TestEstimatedTokenThreshold:
|
||||
"""Verify that hygiene thresholds are always below the model's context
|
||||
limit — for both actual and estimated token counts.
|
||||
|
||||
Regression: a previous 1.4x multiplier on rough estimates pushed the
|
||||
threshold to 85% * 1.4 = 119% of context, which exceeded the model's
|
||||
limit and prevented hygiene from ever firing for ~200K models (GLM-5).
|
||||
The fix removed the multiplier entirely — the 85% threshold already
|
||||
provides ample headroom over the agent's 50% compressor.
|
||||
"""
|
||||
|
||||
def test_threshold_below_context_for_200k_model(self):
|
||||
"""Hygiene threshold must always be below model context."""
|
||||
context_length = 200_000
|
||||
threshold = int(context_length * 0.85)
|
||||
assert threshold < context_length
|
||||
|
||||
def test_threshold_below_context_for_128k_model(self):
|
||||
context_length = 128_000
|
||||
threshold = int(context_length * 0.85)
|
||||
assert threshold < context_length
|
||||
|
||||
def test_no_multiplier_means_same_threshold_for_estimated_and_actual(self):
|
||||
"""Without the 1.4x, estimated and actual token paths use the same threshold."""
|
||||
context_length = 200_000
|
||||
threshold_pct = 0.85
|
||||
threshold = int(context_length * threshold_pct)
|
||||
# Both paths should use 170K — no inflation
|
||||
assert threshold == 170_000
|
||||
|
||||
def test_warn_threshold_below_context(self):
|
||||
"""Warn threshold (95%) must be below context length."""
|
||||
for ctx in (128_000, 200_000, 1_000_000):
|
||||
warn = int(ctx * 0.95)
|
||||
assert warn < ctx
|
||||
|
||||
def test_overestimate_fires_early_but_safely(self):
|
||||
"""If rough estimate is 50% inflated, hygiene fires at ~57% actual usage.
|
||||
|
||||
That's between the agent's 50% threshold and the model's limit —
|
||||
safe and harmless.
|
||||
"""
|
||||
context_length = 200_000
|
||||
threshold = int(context_length * 0.85) # 170K
|
||||
# If actual tokens = 113K, rough estimate = 113K * 1.5 = 170K
|
||||
# Hygiene fires when estimate hits 170K, actual is ~113K = 57% of ctx
|
||||
actual_when_fires = threshold / 1.5
|
||||
assert actual_when_fires > context_length * 0.50, (
|
||||
"Early fire should still be above agent's 50% threshold"
|
||||
)
|
||||
assert actual_when_fires < context_length, (
|
||||
"Early fire must be well below model limit"
|
||||
)
|
||||
|
||||
|
||||
class TestTokenEstimation:
|
||||
"""Verify rough token estimation works as expected for hygiene checks."""
|
||||
|
||||
def test_empty_history(self):
|
||||
assert estimate_messages_tokens_rough([]) == 0
|
||||
|
||||
def test_proportional_to_content(self):
|
||||
small = _make_history(10, content_size=100)
|
||||
large = _make_history(10, content_size=10_000)
|
||||
assert estimate_messages_tokens_rough(large) > estimate_messages_tokens_rough(small)
|
||||
|
||||
def test_proportional_to_count(self):
|
||||
few = _make_history(10, content_size=1000)
|
||||
many = _make_history(100, content_size=1000)
|
||||
assert estimate_messages_tokens_rough(many) > estimate_messages_tokens_rough(few)
|
||||
|
||||
def test_pathological_session_detected(self):
|
||||
"""The reported pathological case: 648 messages, ~299K tokens.
|
||||
|
||||
With a 200k model at 85% threshold (170k), this should trigger.
|
||||
"""
|
||||
history = _make_history(648, content_size=1800)
|
||||
tokens = estimate_messages_tokens_rough(history)
|
||||
# Should be well above the 170K threshold for a 200k model
|
||||
threshold = int(200_000 * 0.85)
|
||||
assert tokens > threshold
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_hygiene_messages_stay_in_originating_topic(monkeypatch, tmp_path):
|
||||
fake_dotenv = types.ModuleType("dotenv")
|
||||
fake_dotenv.load_dotenv = lambda *args, **kwargs: None
|
||||
monkeypatch.setitem(sys.modules, "dotenv", fake_dotenv)
|
||||
|
||||
class FakeCompressAgent:
|
||||
def __init__(self, **kwargs):
|
||||
self.model = kwargs.get("model")
|
||||
|
||||
def _compress_context(self, messages, *_args, **_kwargs):
|
||||
return ([{"role": "assistant", "content": "compressed"}], None)
|
||||
|
||||
fake_run_agent = types.ModuleType("run_agent")
|
||||
fake_run_agent.AIAgent = FakeCompressAgent
|
||||
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
|
||||
|
||||
gateway_run = importlib.import_module("gateway.run")
|
||||
GatewayRunner = gateway_run.GatewayRunner
|
||||
|
||||
adapter = HygieneCaptureAdapter()
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.config = GatewayConfig(
|
||||
platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="fake-token")}
|
||||
)
|
||||
runner.adapters = {Platform.TELEGRAM: adapter}
|
||||
runner._voice_mode = {}
|
||||
runner.hooks = SimpleNamespace(emit=AsyncMock(), loaded_hooks=False)
|
||||
runner.session_store = MagicMock()
|
||||
runner.session_store.get_or_create_session.return_value = SessionEntry(
|
||||
session_key="agent:main:telegram:group:-1001:17585",
|
||||
session_id="sess-1",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_type="group",
|
||||
)
|
||||
runner.session_store.load_transcript.return_value = _make_history(6, content_size=400)
|
||||
runner.session_store.has_any_sessions.return_value = True
|
||||
runner.session_store.rewrite_transcript = MagicMock()
|
||||
runner.session_store.append_to_transcript = MagicMock()
|
||||
runner._running_agents = {}
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._session_db = None
|
||||
runner._is_user_authorized = lambda _source: True
|
||||
runner._set_session_env = lambda _context: None
|
||||
runner._run_agent = AsyncMock(
|
||||
return_value={
|
||||
"final_response": "ok",
|
||||
"messages": [],
|
||||
"tools": [],
|
||||
"history_offset": 0,
|
||||
"last_prompt_tokens": 0,
|
||||
}
|
||||
)
|
||||
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||
monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "fake"})
|
||||
monkeypatch.setattr(
|
||||
"agent.model_metadata.get_model_context_length",
|
||||
lambda *_args, **_kwargs: 100,
|
||||
)
|
||||
monkeypatch.setenv("TELEGRAM_HOME_CHANNEL", "795544298")
|
||||
|
||||
event = MessageEvent(
|
||||
text="hello",
|
||||
source=SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="-1001",
|
||||
chat_type="group",
|
||||
thread_id="17585",
|
||||
),
|
||||
message_id="1",
|
||||
)
|
||||
|
||||
result = await runner._handle_message(event)
|
||||
|
||||
assert result == "ok"
|
||||
assert len(adapter.sent) == 2
|
||||
assert adapter.sent[0]["chat_id"] == "-1001"
|
||||
assert "Session is large" in adapter.sent[0]["content"]
|
||||
assert adapter.sent[0]["metadata"] == {"thread_id": "17585"}
|
||||
assert adapter.sent[1]["chat_id"] == "-1001"
|
||||
assert "Compressed:" in adapter.sent[1]["content"]
|
||||
assert adapter.sent[1]["metadata"] == {"thread_id": "17585"}
|
||||
267
hermes_code/tests/gateway/test_session_race_guard.py
Normal file
267
hermes_code/tests/gateway/test_session_race_guard.py
Normal file
|
|
@ -0,0 +1,267 @@
|
|||
"""Tests for the session race guard that prevents concurrent agent runs.
|
||||
|
||||
The sentinel-based guard ensures that when _handle_message passes the
|
||||
"is an agent already running?" check and proceeds to the slow async
|
||||
setup path (vision enrichment, STT, hooks, session hygiene), a second
|
||||
message for the same session is correctly recognized as "already running"
|
||||
and routed through the interrupt/queue path instead of spawning a
|
||||
duplicate agent.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import MessageEvent, MessageType
|
||||
from gateway.run import GatewayRunner, _AGENT_PENDING_SENTINEL
|
||||
from gateway.session import SessionSource, build_session_key
|
||||
|
||||
|
||||
class _FakeAdapter:
|
||||
"""Minimal adapter stub for testing."""
|
||||
|
||||
def __init__(self):
|
||||
self._pending_messages = {}
|
||||
|
||||
async def send(self, chat_id, text, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
def _make_runner():
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.config = GatewayConfig(
|
||||
platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")}
|
||||
)
|
||||
runner.adapters = {Platform.TELEGRAM: _FakeAdapter()}
|
||||
runner._running_agents = {}
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._voice_mode = {}
|
||||
runner._is_user_authorized = lambda _source: True
|
||||
return runner
|
||||
|
||||
|
||||
def _make_event(text="hello", chat_id="12345"):
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM, chat_id=chat_id, chat_type="dm"
|
||||
)
|
||||
return MessageEvent(text=text, message_type=MessageType.TEXT, source=source)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Test 1: Sentinel is placed before _handle_message_with_agent runs
|
||||
# ------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_sentinel_placed_before_agent_setup():
|
||||
"""After passing the 'not running' guard, the sentinel must be
|
||||
written into _running_agents *before* any await, so that a
|
||||
concurrent message sees the session as occupied."""
|
||||
runner = _make_runner()
|
||||
event = _make_event()
|
||||
session_key = build_session_key(event.source)
|
||||
|
||||
# Patch _handle_message_with_agent to capture state at entry
|
||||
sentinel_was_set = False
|
||||
|
||||
async def mock_inner(self_inner, ev, src, qk):
|
||||
nonlocal sentinel_was_set
|
||||
sentinel_was_set = runner._running_agents.get(qk) is _AGENT_PENDING_SENTINEL
|
||||
return "ok"
|
||||
|
||||
with patch.object(GatewayRunner, "_handle_message_with_agent", mock_inner):
|
||||
await runner._handle_message(event)
|
||||
|
||||
assert sentinel_was_set, (
|
||||
"Sentinel must be in _running_agents when _handle_message_with_agent starts"
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Test 2: Sentinel is cleaned up after _handle_message_with_agent
|
||||
# ------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_sentinel_cleaned_up_after_handler_returns():
|
||||
"""If _handle_message_with_agent returns normally, the sentinel
|
||||
must be removed so the session is not permanently locked."""
|
||||
runner = _make_runner()
|
||||
event = _make_event()
|
||||
session_key = build_session_key(event.source)
|
||||
|
||||
async def mock_inner(self_inner, ev, src, qk):
|
||||
return "ok"
|
||||
|
||||
with patch.object(GatewayRunner, "_handle_message_with_agent", mock_inner):
|
||||
await runner._handle_message(event)
|
||||
|
||||
assert session_key not in runner._running_agents, (
|
||||
"Sentinel must be removed after handler completes"
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Test 3: Sentinel cleaned up on exception
|
||||
# ------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_sentinel_cleaned_up_on_exception():
|
||||
"""If _handle_message_with_agent raises, the sentinel must still
|
||||
be cleaned up so the session is not permanently locked."""
|
||||
runner = _make_runner()
|
||||
event = _make_event()
|
||||
session_key = build_session_key(event.source)
|
||||
|
||||
async def mock_inner(self_inner, ev, src, qk):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
with patch.object(GatewayRunner, "_handle_message_with_agent", mock_inner):
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
await runner._handle_message(event)
|
||||
|
||||
assert session_key not in runner._running_agents, (
|
||||
"Sentinel must be removed even if handler raises"
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Test 4: Second message during sentinel sees "already running"
|
||||
# ------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_second_message_during_sentinel_queued_not_duplicate():
|
||||
"""While the sentinel is set (agent setup in progress), a second
|
||||
message for the same session must hit the 'already running' branch
|
||||
and be queued — not start a second agent."""
|
||||
runner = _make_runner()
|
||||
event1 = _make_event(text="first message")
|
||||
event2 = _make_event(text="second message")
|
||||
session_key = build_session_key(event1.source)
|
||||
|
||||
barrier = asyncio.Event()
|
||||
|
||||
async def slow_inner(self_inner, ev, src, qk):
|
||||
# Simulate slow setup — wait until test tells us to proceed
|
||||
await barrier.wait()
|
||||
return "ok"
|
||||
|
||||
with patch.object(GatewayRunner, "_handle_message_with_agent", slow_inner):
|
||||
# Start first message (will block at barrier)
|
||||
task1 = asyncio.create_task(runner._handle_message(event1))
|
||||
# Yield so task1 enters slow_inner and sentinel is set
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# Verify sentinel is set
|
||||
assert runner._running_agents.get(session_key) is _AGENT_PENDING_SENTINEL
|
||||
|
||||
# Second message should see "already running" and be queued
|
||||
result2 = await runner._handle_message(event2)
|
||||
assert result2 is None, "Second message should return None (queued)"
|
||||
|
||||
# The second message should have been queued in adapter pending
|
||||
adapter = runner.adapters[Platform.TELEGRAM]
|
||||
assert session_key in adapter._pending_messages, (
|
||||
"Second message should be queued as pending"
|
||||
)
|
||||
assert adapter._pending_messages[session_key] is event2
|
||||
|
||||
# Let first message complete
|
||||
barrier.set()
|
||||
await task1
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Test 5: Sentinel not placed for command messages
|
||||
# ------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_command_messages_do_not_leave_sentinel():
|
||||
"""Slash commands (/help, /status, etc.) return early from
|
||||
_handle_message. They must NOT leave a sentinel behind."""
|
||||
runner = _make_runner()
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM, chat_id="12345", chat_type="dm"
|
||||
)
|
||||
event = MessageEvent(
|
||||
text="/help", message_type=MessageType.TEXT, source=source
|
||||
)
|
||||
session_key = build_session_key(source)
|
||||
|
||||
# Mock the help handler to avoid needing full runner setup
|
||||
runner._handle_help_command = AsyncMock(return_value="Help text")
|
||||
# Need hooks for command emission
|
||||
runner.hooks = MagicMock()
|
||||
runner.hooks.emit = AsyncMock()
|
||||
|
||||
await runner._handle_message(event)
|
||||
|
||||
assert session_key not in runner._running_agents, (
|
||||
"Command handlers must not leave sentinel in _running_agents"
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Test 6: /stop during sentinel returns helpful message
|
||||
# ------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_during_sentinel_returns_message():
|
||||
"""If /stop arrives while the sentinel is set (agent still starting),
|
||||
it should return a helpful message instead of crashing or queuing."""
|
||||
runner = _make_runner()
|
||||
event1 = _make_event(text="hello")
|
||||
session_key = build_session_key(event1.source)
|
||||
|
||||
barrier = asyncio.Event()
|
||||
|
||||
async def slow_inner(self_inner, ev, src, qk):
|
||||
await barrier.wait()
|
||||
return "ok"
|
||||
|
||||
with patch.object(GatewayRunner, "_handle_message_with_agent", slow_inner):
|
||||
task1 = asyncio.create_task(runner._handle_message(event1))
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# Sentinel should be set
|
||||
assert runner._running_agents.get(session_key) is _AGENT_PENDING_SENTINEL
|
||||
|
||||
# Send /stop — should get a message, not crash
|
||||
stop_event = _make_event(text="/stop")
|
||||
result = await runner._handle_message(stop_event)
|
||||
assert result is not None, "/stop during sentinel should return a message"
|
||||
assert "starting up" in result.lower()
|
||||
|
||||
# Should NOT be queued as pending
|
||||
adapter = runner.adapters[Platform.TELEGRAM]
|
||||
assert session_key not in adapter._pending_messages
|
||||
|
||||
barrier.set()
|
||||
await task1
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Test 7: Shutdown skips sentinel entries
|
||||
# ------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_shutdown_skips_sentinel():
|
||||
"""During gateway shutdown, sentinel entries in _running_agents
|
||||
should be skipped without raising AttributeError."""
|
||||
runner = _make_runner()
|
||||
session_key = "telegram:dm:99999"
|
||||
|
||||
# Simulate a sentinel in _running_agents
|
||||
runner._running_agents[session_key] = _AGENT_PENDING_SENTINEL
|
||||
|
||||
# Also add a real agent mock to verify it still gets interrupted
|
||||
real_agent = MagicMock()
|
||||
runner._running_agents["telegram:dm:88888"] = real_agent
|
||||
|
||||
runner.adapters = {} # No adapters to disconnect
|
||||
runner._running = True
|
||||
runner._shutdown_event = asyncio.Event()
|
||||
runner._exit_reason = None
|
||||
runner._shutdown_all_gateway_honcho = lambda: None
|
||||
|
||||
with patch("gateway.status.remove_pid_file"), \
|
||||
patch("gateway.status.write_runtime_status"):
|
||||
await runner.stop()
|
||||
|
||||
# Real agent should have been interrupted
|
||||
real_agent.interrupt.assert_called_once()
|
||||
# Should not have raised on the sentinel
|
||||
207
hermes_code/tests/gateway/test_session_reset_notify.py
Normal file
207
hermes_code/tests/gateway/test_session_reset_notify.py
Normal file
|
|
@ -0,0 +1,207 @@
|
|||
"""Tests for session auto-reset notifications.
|
||||
|
||||
Verifies that:
|
||||
- _should_reset() returns a reason string ("idle" or "daily") instead of bool
|
||||
- SessionEntry captures auto_reset_reason
|
||||
- SessionResetPolicy.notify controls whether notifications are sent
|
||||
- notify_exclude_platforms skips notifications for excluded platforms
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import (
|
||||
GatewayConfig,
|
||||
Platform,
|
||||
PlatformConfig,
|
||||
SessionResetPolicy,
|
||||
)
|
||||
from gateway.session import SessionEntry, SessionSource, SessionStore
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_source(platform=Platform.TELEGRAM, chat_id="123", user_id="u1"):
|
||||
return SessionSource(
|
||||
platform=platform,
|
||||
chat_id=chat_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
|
||||
def _make_store(policy=None, tmp_path=None):
|
||||
config = GatewayConfig()
|
||||
if policy:
|
||||
config.default_reset_policy = policy
|
||||
store = SessionStore(sessions_dir=tmp_path or "/tmp/test-sessions", config=config)
|
||||
return store
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _should_reset returns reason string
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestShouldResetReason:
|
||||
def test_returns_none_when_not_expired(self, tmp_path):
|
||||
store = _make_store(
|
||||
SessionResetPolicy(mode="both", idle_minutes=60, at_hour=4),
|
||||
tmp_path,
|
||||
)
|
||||
entry = SessionEntry(
|
||||
session_key="test",
|
||||
session_id="s1",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(), # just updated
|
||||
)
|
||||
source = _make_source()
|
||||
assert store._should_reset(entry, source) is None
|
||||
|
||||
def test_returns_idle_when_idle_expired(self, tmp_path):
|
||||
store = _make_store(
|
||||
SessionResetPolicy(mode="idle", idle_minutes=30),
|
||||
tmp_path,
|
||||
)
|
||||
entry = SessionEntry(
|
||||
session_key="test",
|
||||
session_id="s1",
|
||||
created_at=datetime.now() - timedelta(hours=2),
|
||||
updated_at=datetime.now() - timedelta(hours=1), # 60min ago > 30min threshold
|
||||
)
|
||||
source = _make_source()
|
||||
assert store._should_reset(entry, source) == "idle"
|
||||
|
||||
def test_returns_daily_when_daily_boundary_crossed(self, tmp_path):
|
||||
now = datetime.now()
|
||||
store = _make_store(
|
||||
SessionResetPolicy(mode="daily", at_hour=now.hour),
|
||||
tmp_path,
|
||||
)
|
||||
entry = SessionEntry(
|
||||
session_key="test",
|
||||
session_id="s1",
|
||||
created_at=now - timedelta(days=2),
|
||||
updated_at=now - timedelta(days=1), # last active yesterday
|
||||
)
|
||||
source = _make_source()
|
||||
assert store._should_reset(entry, source) == "daily"
|
||||
|
||||
def test_returns_none_when_mode_is_none(self, tmp_path):
|
||||
store = _make_store(
|
||||
SessionResetPolicy(mode="none"),
|
||||
tmp_path,
|
||||
)
|
||||
entry = SessionEntry(
|
||||
session_key="test",
|
||||
session_id="s1",
|
||||
created_at=datetime.now() - timedelta(days=30),
|
||||
updated_at=datetime.now() - timedelta(days=30),
|
||||
)
|
||||
source = _make_source()
|
||||
assert store._should_reset(entry, source) is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SessionEntry captures reason
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSessionEntryReason:
|
||||
def test_auto_reset_reason_stored(self, tmp_path):
|
||||
store = _make_store(
|
||||
SessionResetPolicy(mode="idle", idle_minutes=1),
|
||||
tmp_path,
|
||||
)
|
||||
source = _make_source()
|
||||
|
||||
# Create initial session
|
||||
entry1 = store.get_or_create_session(source)
|
||||
assert not entry1.was_auto_reset
|
||||
|
||||
# Age it past the idle threshold
|
||||
entry1.updated_at = datetime.now() - timedelta(minutes=5)
|
||||
store._save()
|
||||
|
||||
# Next call should create a new session with reason
|
||||
entry2 = store.get_or_create_session(source)
|
||||
assert entry2.was_auto_reset is True
|
||||
assert entry2.auto_reset_reason == "idle"
|
||||
assert entry2.session_id != entry1.session_id
|
||||
|
||||
def test_reset_had_activity_false_when_no_tokens(self, tmp_path):
|
||||
"""Expired session with no tokens → reset_had_activity=False."""
|
||||
store = _make_store(
|
||||
SessionResetPolicy(mode="idle", idle_minutes=1),
|
||||
tmp_path,
|
||||
)
|
||||
source = _make_source()
|
||||
|
||||
entry1 = store.get_or_create_session(source)
|
||||
# No tokens used — session was idle with no conversation
|
||||
entry1.updated_at = datetime.now() - timedelta(minutes=5)
|
||||
store._save()
|
||||
|
||||
entry2 = store.get_or_create_session(source)
|
||||
assert entry2.was_auto_reset is True
|
||||
assert entry2.reset_had_activity is False
|
||||
|
||||
def test_reset_had_activity_true_when_tokens_used(self, tmp_path):
|
||||
"""Expired session with tokens → reset_had_activity=True."""
|
||||
store = _make_store(
|
||||
SessionResetPolicy(mode="idle", idle_minutes=1),
|
||||
tmp_path,
|
||||
)
|
||||
source = _make_source()
|
||||
|
||||
entry1 = store.get_or_create_session(source)
|
||||
# Simulate some conversation happened
|
||||
entry1.total_tokens = 5000
|
||||
entry1.updated_at = datetime.now() - timedelta(minutes=5)
|
||||
store._save()
|
||||
|
||||
entry2 = store.get_or_create_session(source)
|
||||
assert entry2.was_auto_reset is True
|
||||
assert entry2.reset_had_activity is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SessionResetPolicy notify config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestResetPolicyNotify:
|
||||
def test_notify_defaults_true(self):
|
||||
policy = SessionResetPolicy()
|
||||
assert policy.notify is True
|
||||
|
||||
def test_notify_exclude_defaults(self):
|
||||
policy = SessionResetPolicy()
|
||||
assert "api_server" in policy.notify_exclude_platforms
|
||||
assert "webhook" in policy.notify_exclude_platforms
|
||||
|
||||
def test_from_dict_with_notify_false(self):
|
||||
policy = SessionResetPolicy.from_dict({"notify": False})
|
||||
assert policy.notify is False
|
||||
|
||||
def test_from_dict_with_custom_excludes(self):
|
||||
policy = SessionResetPolicy.from_dict({
|
||||
"notify_exclude_platforms": ["api_server", "webhook", "homeassistant"],
|
||||
})
|
||||
assert "homeassistant" in policy.notify_exclude_platforms
|
||||
|
||||
def test_from_dict_preserves_defaults_on_missing_keys(self):
|
||||
policy = SessionResetPolicy.from_dict({})
|
||||
assert policy.notify is True
|
||||
assert "api_server" in policy.notify_exclude_platforms
|
||||
|
||||
def test_to_dict_roundtrip(self):
|
||||
original = SessionResetPolicy(
|
||||
mode="idle",
|
||||
notify=False,
|
||||
notify_exclude_platforms=("api_server",),
|
||||
)
|
||||
restored = SessionResetPolicy.from_dict(original.to_dict())
|
||||
assert restored.notify == original.notify
|
||||
assert restored.notify_exclude_platforms == original.notify_exclude_platforms
|
||||
assert restored.mode == original.mode
|
||||
298
hermes_code/tests/gateway/test_signal.py
Normal file
298
hermes_code/tests/gateway/test_signal.py
Normal file
|
|
@ -0,0 +1,298 @@
|
|||
"""Tests for Signal messenger platform adapter."""
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Platform & Config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSignalPlatformEnum:
|
||||
def test_signal_enum_exists(self):
|
||||
assert Platform.SIGNAL.value == "signal"
|
||||
|
||||
def test_signal_in_platform_list(self):
|
||||
platforms = [p.value for p in Platform]
|
||||
assert "signal" in platforms
|
||||
|
||||
|
||||
class TestSignalConfigLoading:
|
||||
def test_apply_env_overrides_signal(self, monkeypatch):
|
||||
monkeypatch.setenv("SIGNAL_HTTP_URL", "http://localhost:9090")
|
||||
monkeypatch.setenv("SIGNAL_ACCOUNT", "+15551234567")
|
||||
|
||||
from gateway.config import GatewayConfig, _apply_env_overrides
|
||||
config = GatewayConfig()
|
||||
_apply_env_overrides(config)
|
||||
|
||||
assert Platform.SIGNAL in config.platforms
|
||||
sc = config.platforms[Platform.SIGNAL]
|
||||
assert sc.enabled is True
|
||||
assert sc.extra["http_url"] == "http://localhost:9090"
|
||||
assert sc.extra["account"] == "+15551234567"
|
||||
|
||||
def test_signal_not_loaded_without_both_vars(self, monkeypatch):
|
||||
monkeypatch.setenv("SIGNAL_HTTP_URL", "http://localhost:9090")
|
||||
# No SIGNAL_ACCOUNT
|
||||
|
||||
from gateway.config import GatewayConfig, _apply_env_overrides
|
||||
config = GatewayConfig()
|
||||
_apply_env_overrides(config)
|
||||
|
||||
assert Platform.SIGNAL not in config.platforms
|
||||
|
||||
def test_connected_platforms_includes_signal(self, monkeypatch):
|
||||
monkeypatch.setenv("SIGNAL_HTTP_URL", "http://localhost:8080")
|
||||
monkeypatch.setenv("SIGNAL_ACCOUNT", "+15551234567")
|
||||
|
||||
from gateway.config import GatewayConfig, _apply_env_overrides
|
||||
config = GatewayConfig()
|
||||
_apply_env_overrides(config)
|
||||
|
||||
connected = config.get_connected_platforms()
|
||||
assert Platform.SIGNAL in connected
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Adapter Init & Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSignalAdapterInit:
|
||||
def _make_config(self, **extra):
|
||||
config = PlatformConfig()
|
||||
config.enabled = True
|
||||
config.extra = {
|
||||
"http_url": "http://localhost:8080",
|
||||
"account": "+15551234567",
|
||||
**extra,
|
||||
}
|
||||
return config
|
||||
|
||||
def test_init_parses_config(self, monkeypatch):
|
||||
monkeypatch.setenv("SIGNAL_GROUP_ALLOWED_USERS", "group123,group456")
|
||||
|
||||
from gateway.platforms.signal import SignalAdapter
|
||||
adapter = SignalAdapter(self._make_config())
|
||||
|
||||
assert adapter.http_url == "http://localhost:8080"
|
||||
assert adapter.account == "+15551234567"
|
||||
assert "group123" in adapter.group_allow_from
|
||||
|
||||
def test_init_empty_allowlist(self, monkeypatch):
|
||||
monkeypatch.setenv("SIGNAL_GROUP_ALLOWED_USERS", "")
|
||||
|
||||
from gateway.platforms.signal import SignalAdapter
|
||||
adapter = SignalAdapter(self._make_config())
|
||||
|
||||
assert len(adapter.group_allow_from) == 0
|
||||
|
||||
def test_init_strips_trailing_slash(self, monkeypatch):
|
||||
monkeypatch.setenv("SIGNAL_GROUP_ALLOWED_USERS", "")
|
||||
|
||||
from gateway.platforms.signal import SignalAdapter
|
||||
adapter = SignalAdapter(self._make_config(http_url="http://localhost:8080/"))
|
||||
|
||||
assert adapter.http_url == "http://localhost:8080"
|
||||
|
||||
def test_self_message_filtering(self, monkeypatch):
|
||||
monkeypatch.setenv("SIGNAL_GROUP_ALLOWED_USERS", "")
|
||||
|
||||
from gateway.platforms.signal import SignalAdapter
|
||||
adapter = SignalAdapter(self._make_config())
|
||||
|
||||
assert adapter._account_normalized == "+15551234567"
|
||||
|
||||
|
||||
class TestSignalHelpers:
|
||||
def test_redact_phone_long(self):
|
||||
from gateway.platforms.signal import _redact_phone
|
||||
assert _redact_phone("+15551234567") == "+155****4567"
|
||||
|
||||
def test_redact_phone_short(self):
|
||||
from gateway.platforms.signal import _redact_phone
|
||||
assert _redact_phone("+12345") == "+1****45"
|
||||
|
||||
def test_redact_phone_empty(self):
|
||||
from gateway.platforms.signal import _redact_phone
|
||||
assert _redact_phone("") == "<none>"
|
||||
|
||||
def test_parse_comma_list(self):
|
||||
from gateway.platforms.signal import _parse_comma_list
|
||||
assert _parse_comma_list("+1234, +5678 , +9012") == ["+1234", "+5678", "+9012"]
|
||||
assert _parse_comma_list("") == []
|
||||
assert _parse_comma_list(" , , ") == []
|
||||
|
||||
def test_guess_extension_png(self):
|
||||
from gateway.platforms.signal import _guess_extension
|
||||
assert _guess_extension(b"\x89PNG\r\n\x1a\n" + b"\x00" * 100) == ".png"
|
||||
|
||||
def test_guess_extension_jpeg(self):
|
||||
from gateway.platforms.signal import _guess_extension
|
||||
assert _guess_extension(b"\xff\xd8\xff\xe0" + b"\x00" * 100) == ".jpg"
|
||||
|
||||
def test_guess_extension_pdf(self):
|
||||
from gateway.platforms.signal import _guess_extension
|
||||
assert _guess_extension(b"%PDF-1.4" + b"\x00" * 100) == ".pdf"
|
||||
|
||||
def test_guess_extension_zip(self):
|
||||
from gateway.platforms.signal import _guess_extension
|
||||
assert _guess_extension(b"PK\x03\x04" + b"\x00" * 100) == ".zip"
|
||||
|
||||
def test_guess_extension_mp4(self):
|
||||
from gateway.platforms.signal import _guess_extension
|
||||
assert _guess_extension(b"\x00\x00\x00\x18ftypisom" + b"\x00" * 100) == ".mp4"
|
||||
|
||||
def test_guess_extension_unknown(self):
|
||||
from gateway.platforms.signal import _guess_extension
|
||||
assert _guess_extension(b"\x00\x01\x02\x03" * 10) == ".bin"
|
||||
|
||||
def test_is_image_ext(self):
|
||||
from gateway.platforms.signal import _is_image_ext
|
||||
assert _is_image_ext(".png") is True
|
||||
assert _is_image_ext(".jpg") is True
|
||||
assert _is_image_ext(".gif") is True
|
||||
assert _is_image_ext(".pdf") is False
|
||||
|
||||
def test_is_audio_ext(self):
|
||||
from gateway.platforms.signal import _is_audio_ext
|
||||
assert _is_audio_ext(".mp3") is True
|
||||
assert _is_audio_ext(".ogg") is True
|
||||
assert _is_audio_ext(".png") is False
|
||||
|
||||
def test_check_requirements(self, monkeypatch):
|
||||
from gateway.platforms.signal import check_signal_requirements
|
||||
monkeypatch.setenv("SIGNAL_HTTP_URL", "http://localhost:8080")
|
||||
monkeypatch.setenv("SIGNAL_ACCOUNT", "+15551234567")
|
||||
assert check_signal_requirements() is True
|
||||
|
||||
def test_render_mentions(self):
|
||||
from gateway.platforms.signal import _render_mentions
|
||||
text = "Hello \uFFFC, how are you?"
|
||||
mentions = [{"start": 6, "length": 1, "number": "+15559999999"}]
|
||||
result = _render_mentions(text, mentions)
|
||||
assert "@+15559999999" in result
|
||||
assert "\uFFFC" not in result
|
||||
|
||||
def test_render_mentions_no_mentions(self):
|
||||
from gateway.platforms.signal import _render_mentions
|
||||
text = "Hello world"
|
||||
result = _render_mentions(text, [])
|
||||
assert result == "Hello world"
|
||||
|
||||
def test_check_requirements_missing(self, monkeypatch):
|
||||
from gateway.platforms.signal import check_signal_requirements
|
||||
monkeypatch.delenv("SIGNAL_HTTP_URL", raising=False)
|
||||
monkeypatch.delenv("SIGNAL_ACCOUNT", raising=False)
|
||||
assert check_signal_requirements() is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Session Source
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSignalSessionSource:
|
||||
def test_session_source_alt_fields(self):
|
||||
from gateway.session import SessionSource
|
||||
source = SessionSource(
|
||||
platform=Platform.SIGNAL,
|
||||
chat_id="+15551234567",
|
||||
user_id="+15551234567",
|
||||
user_id_alt="uuid:abc-123",
|
||||
chat_id_alt=None,
|
||||
)
|
||||
d = source.to_dict()
|
||||
assert d["user_id_alt"] == "uuid:abc-123"
|
||||
assert "chat_id_alt" not in d # None fields excluded
|
||||
|
||||
def test_session_source_roundtrip(self):
|
||||
from gateway.session import SessionSource
|
||||
source = SessionSource(
|
||||
platform=Platform.SIGNAL,
|
||||
chat_id="group:xyz",
|
||||
chat_type="group",
|
||||
user_id="+15551234567",
|
||||
user_id_alt="uuid:abc",
|
||||
chat_id_alt="xyz",
|
||||
)
|
||||
d = source.to_dict()
|
||||
restored = SessionSource.from_dict(d)
|
||||
assert restored.user_id_alt == "uuid:abc"
|
||||
assert restored.chat_id_alt == "xyz"
|
||||
assert restored.platform == Platform.SIGNAL
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Phone Redaction in agent/redact.py
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSignalPhoneRedaction:
|
||||
@pytest.fixture(autouse=True)
|
||||
def _ensure_redaction_enabled(self, monkeypatch):
|
||||
monkeypatch.delenv("HERMES_REDACT_SECRETS", raising=False)
|
||||
|
||||
def test_us_number(self):
|
||||
from agent.redact import redact_sensitive_text
|
||||
result = redact_sensitive_text("Call +15551234567 now")
|
||||
assert "+15551234567" not in result
|
||||
assert "+155" in result # Prefix preserved
|
||||
assert "4567" in result # Suffix preserved
|
||||
|
||||
def test_uk_number(self):
|
||||
from agent.redact import redact_sensitive_text
|
||||
result = redact_sensitive_text("UK: +442071838750")
|
||||
assert "+442071838750" not in result
|
||||
assert "****" in result
|
||||
|
||||
def test_multiple_numbers(self):
|
||||
from agent.redact import redact_sensitive_text
|
||||
text = "From +15551234567 to +442071838750"
|
||||
result = redact_sensitive_text(text)
|
||||
assert "+15551234567" not in result
|
||||
assert "+442071838750" not in result
|
||||
|
||||
def test_short_number_not_matched(self):
|
||||
from agent.redact import redact_sensitive_text
|
||||
result = redact_sensitive_text("Code: +12345")
|
||||
# 5 digits after + is below the 7-digit minimum
|
||||
assert "+12345" in result # Too short to redact
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Authorization in run.py
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSignalAuthorization:
|
||||
def test_signal_in_allowlist_maps(self):
|
||||
"""Signal should be in the platform auth maps."""
|
||||
from gateway.run import GatewayRunner
|
||||
from gateway.config import GatewayConfig
|
||||
|
||||
gw = GatewayRunner.__new__(GatewayRunner)
|
||||
gw.config = GatewayConfig()
|
||||
gw.pairing_store = MagicMock()
|
||||
gw.pairing_store.is_approved.return_value = False
|
||||
|
||||
source = MagicMock()
|
||||
source.platform = Platform.SIGNAL
|
||||
source.user_id = "+15559999999"
|
||||
|
||||
# No allowlists set — should check GATEWAY_ALLOW_ALL_USERS
|
||||
with patch.dict("os.environ", {}, clear=True):
|
||||
result = gw._is_user_authorized(source)
|
||||
assert result is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Send Message Tool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSignalSendMessage:
|
||||
def test_signal_in_platform_map(self):
|
||||
"""Signal should be in the send_message tool's platform map."""
|
||||
from tools.send_message_tool import send_message_tool
|
||||
# Just verify the import works and Signal is a valid platform
|
||||
from gateway.config import Platform
|
||||
assert Platform.SIGNAL.value == "signal"
|
||||
948
hermes_code/tests/gateway/test_slack.py
Normal file
948
hermes_code/tests/gateway/test_slack.py
Normal file
|
|
@ -0,0 +1,948 @@
|
|||
"""
|
||||
Tests for Slack platform adapter.
|
||||
|
||||
Covers: app_mention handler, send_document, send_video,
|
||||
incoming document handling, message routing.
|
||||
|
||||
Note: slack-bolt may not be installed in the test environment.
|
||||
We mock the slack modules at import time to avoid collection errors.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import (
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
SendResult,
|
||||
SUPPORTED_DOCUMENT_TYPES,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mock the slack-bolt package if it's not installed
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _ensure_slack_mock():
|
||||
"""Install mock slack modules so SlackAdapter can be imported."""
|
||||
if "slack_bolt" in sys.modules and hasattr(sys.modules["slack_bolt"], "__file__"):
|
||||
return # Real library installed
|
||||
|
||||
slack_bolt = MagicMock()
|
||||
slack_bolt.async_app.AsyncApp = MagicMock
|
||||
slack_bolt.adapter.socket_mode.async_handler.AsyncSocketModeHandler = MagicMock
|
||||
|
||||
slack_sdk = MagicMock()
|
||||
slack_sdk.web.async_client.AsyncWebClient = MagicMock
|
||||
|
||||
for name, mod in [
|
||||
("slack_bolt", slack_bolt),
|
||||
("slack_bolt.async_app", slack_bolt.async_app),
|
||||
("slack_bolt.adapter", slack_bolt.adapter),
|
||||
("slack_bolt.adapter.socket_mode", slack_bolt.adapter.socket_mode),
|
||||
("slack_bolt.adapter.socket_mode.async_handler", slack_bolt.adapter.socket_mode.async_handler),
|
||||
("slack_sdk", slack_sdk),
|
||||
("slack_sdk.web", slack_sdk.web),
|
||||
("slack_sdk.web.async_client", slack_sdk.web.async_client),
|
||||
]:
|
||||
sys.modules.setdefault(name, mod)
|
||||
|
||||
|
||||
_ensure_slack_mock()
|
||||
|
||||
# Patch SLACK_AVAILABLE before importing the adapter
|
||||
import gateway.platforms.slack as _slack_mod
|
||||
_slack_mod.SLACK_AVAILABLE = True
|
||||
|
||||
from gateway.platforms.slack import SlackAdapter # noqa: E402
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture()
|
||||
def adapter():
|
||||
config = PlatformConfig(enabled=True, token="xoxb-fake-token")
|
||||
a = SlackAdapter(config)
|
||||
# Mock the Slack app client
|
||||
a._app = MagicMock()
|
||||
a._app.client = AsyncMock()
|
||||
a._bot_user_id = "U_BOT"
|
||||
a._running = True
|
||||
# Capture events instead of processing them
|
||||
a.handle_message = AsyncMock()
|
||||
return a
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _redirect_cache(tmp_path, monkeypatch):
|
||||
"""Point document cache to tmp_path so tests don't touch ~/.hermes."""
|
||||
monkeypatch.setattr(
|
||||
"gateway.platforms.base.DOCUMENT_CACHE_DIR", tmp_path / "doc_cache"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestAppMentionHandler
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestAppMentionHandler:
|
||||
"""Verify that the app_mention event handler is registered."""
|
||||
|
||||
def test_app_mention_registered_on_connect(self):
|
||||
"""connect() should register both 'message' and 'app_mention' handlers."""
|
||||
config = PlatformConfig(enabled=True, token="xoxb-fake")
|
||||
adapter = SlackAdapter(config)
|
||||
|
||||
# Track which events get registered
|
||||
registered_events = []
|
||||
registered_commands = []
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
def mock_event(event_type):
|
||||
def decorator(fn):
|
||||
registered_events.append(event_type)
|
||||
return fn
|
||||
return decorator
|
||||
|
||||
def mock_command(cmd):
|
||||
def decorator(fn):
|
||||
registered_commands.append(cmd)
|
||||
return fn
|
||||
return decorator
|
||||
|
||||
mock_app.event = mock_event
|
||||
mock_app.command = mock_command
|
||||
mock_app.client = AsyncMock()
|
||||
mock_app.client.auth_test = AsyncMock(return_value={
|
||||
"user_id": "U_BOT",
|
||||
"user": "testbot",
|
||||
})
|
||||
|
||||
with patch.object(_slack_mod, "AsyncApp", return_value=mock_app), \
|
||||
patch.object(_slack_mod, "AsyncSocketModeHandler", return_value=MagicMock()), \
|
||||
patch.dict(os.environ, {"SLACK_APP_TOKEN": "xapp-fake"}), \
|
||||
patch("asyncio.create_task"):
|
||||
asyncio.run(adapter.connect())
|
||||
|
||||
assert "message" in registered_events
|
||||
assert "app_mention" in registered_events
|
||||
assert "/hermes" in registered_commands
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestSendDocument
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSendDocument:
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_document_success(self, adapter, tmp_path):
|
||||
test_file = tmp_path / "report.pdf"
|
||||
test_file.write_bytes(b"%PDF-1.4 fake content")
|
||||
|
||||
adapter._app.client.files_upload_v2 = AsyncMock(return_value={"ok": True})
|
||||
|
||||
result = await adapter.send_document(
|
||||
chat_id="C123",
|
||||
file_path=str(test_file),
|
||||
caption="Here's the report",
|
||||
)
|
||||
|
||||
assert result.success
|
||||
adapter._app.client.files_upload_v2.assert_called_once()
|
||||
call_kwargs = adapter._app.client.files_upload_v2.call_args[1]
|
||||
assert call_kwargs["channel"] == "C123"
|
||||
assert call_kwargs["file"] == str(test_file)
|
||||
assert call_kwargs["filename"] == "report.pdf"
|
||||
assert call_kwargs["initial_comment"] == "Here's the report"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_document_custom_name(self, adapter, tmp_path):
|
||||
test_file = tmp_path / "data.csv"
|
||||
test_file.write_bytes(b"a,b,c\n1,2,3")
|
||||
|
||||
adapter._app.client.files_upload_v2 = AsyncMock(return_value={"ok": True})
|
||||
|
||||
result = await adapter.send_document(
|
||||
chat_id="C123",
|
||||
file_path=str(test_file),
|
||||
file_name="quarterly-report.csv",
|
||||
)
|
||||
|
||||
assert result.success
|
||||
call_kwargs = adapter._app.client.files_upload_v2.call_args[1]
|
||||
assert call_kwargs["filename"] == "quarterly-report.csv"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_document_missing_file(self, adapter):
|
||||
result = await adapter.send_document(
|
||||
chat_id="C123",
|
||||
file_path="/nonexistent/file.pdf",
|
||||
)
|
||||
|
||||
assert not result.success
|
||||
assert "not found" in result.error.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_document_not_connected(self, adapter):
|
||||
adapter._app = None
|
||||
result = await adapter.send_document(
|
||||
chat_id="C123",
|
||||
file_path="/some/file.pdf",
|
||||
)
|
||||
|
||||
assert not result.success
|
||||
assert "Not connected" in result.error
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_document_api_error_falls_back(self, adapter, tmp_path):
|
||||
test_file = tmp_path / "doc.pdf"
|
||||
test_file.write_bytes(b"content")
|
||||
|
||||
adapter._app.client.files_upload_v2 = AsyncMock(
|
||||
side_effect=RuntimeError("Slack API error")
|
||||
)
|
||||
|
||||
# Should fall back to base class (text message)
|
||||
result = await adapter.send_document(
|
||||
chat_id="C123",
|
||||
file_path=str(test_file),
|
||||
)
|
||||
|
||||
# Base class send() is also mocked, so check it was attempted
|
||||
adapter._app.client.chat_postMessage.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_document_with_thread(self, adapter, tmp_path):
|
||||
test_file = tmp_path / "notes.txt"
|
||||
test_file.write_bytes(b"some notes")
|
||||
|
||||
adapter._app.client.files_upload_v2 = AsyncMock(return_value={"ok": True})
|
||||
|
||||
result = await adapter.send_document(
|
||||
chat_id="C123",
|
||||
file_path=str(test_file),
|
||||
reply_to="1234567890.123456",
|
||||
)
|
||||
|
||||
assert result.success
|
||||
call_kwargs = adapter._app.client.files_upload_v2.call_args[1]
|
||||
assert call_kwargs["thread_ts"] == "1234567890.123456"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestSendVideo
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSendVideo:
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_video_success(self, adapter, tmp_path):
|
||||
video = tmp_path / "clip.mp4"
|
||||
video.write_bytes(b"fake video data")
|
||||
|
||||
adapter._app.client.files_upload_v2 = AsyncMock(return_value={"ok": True})
|
||||
|
||||
result = await adapter.send_video(
|
||||
chat_id="C123",
|
||||
video_path=str(video),
|
||||
caption="Check this out",
|
||||
)
|
||||
|
||||
assert result.success
|
||||
call_kwargs = adapter._app.client.files_upload_v2.call_args[1]
|
||||
assert call_kwargs["filename"] == "clip.mp4"
|
||||
assert call_kwargs["initial_comment"] == "Check this out"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_video_missing_file(self, adapter):
|
||||
result = await adapter.send_video(
|
||||
chat_id="C123",
|
||||
video_path="/nonexistent/video.mp4",
|
||||
)
|
||||
|
||||
assert not result.success
|
||||
assert "not found" in result.error.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_video_not_connected(self, adapter):
|
||||
adapter._app = None
|
||||
result = await adapter.send_video(
|
||||
chat_id="C123",
|
||||
video_path="/some/video.mp4",
|
||||
)
|
||||
|
||||
assert not result.success
|
||||
assert "Not connected" in result.error
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_video_api_error_falls_back(self, adapter, tmp_path):
|
||||
video = tmp_path / "clip.mp4"
|
||||
video.write_bytes(b"fake video")
|
||||
|
||||
adapter._app.client.files_upload_v2 = AsyncMock(
|
||||
side_effect=RuntimeError("Slack API error")
|
||||
)
|
||||
|
||||
# Should fall back to base class (text message)
|
||||
result = await adapter.send_video(
|
||||
chat_id="C123",
|
||||
video_path=str(video),
|
||||
)
|
||||
|
||||
adapter._app.client.chat_postMessage.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestIncomingDocumentHandling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestIncomingDocumentHandling:
|
||||
def _make_event(self, files=None, text="hello", channel_type="im"):
|
||||
"""Build a mock Slack message event with file attachments."""
|
||||
return {
|
||||
"text": text,
|
||||
"user": "U_USER",
|
||||
"channel": "C123",
|
||||
"channel_type": channel_type,
|
||||
"ts": "1234567890.000001",
|
||||
"files": files or [],
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pdf_document_cached(self, adapter):
|
||||
"""A PDF attachment should be downloaded, cached, and set as DOCUMENT type."""
|
||||
pdf_bytes = b"%PDF-1.4 fake content"
|
||||
|
||||
with patch.object(adapter, "_download_slack_file_bytes", new_callable=AsyncMock) as dl:
|
||||
dl.return_value = pdf_bytes
|
||||
event = self._make_event(files=[{
|
||||
"mimetype": "application/pdf",
|
||||
"name": "report.pdf",
|
||||
"url_private_download": "https://files.slack.com/report.pdf",
|
||||
"size": len(pdf_bytes),
|
||||
}])
|
||||
await adapter._handle_slack_message(event)
|
||||
|
||||
msg_event = adapter.handle_message.call_args[0][0]
|
||||
assert msg_event.message_type == MessageType.DOCUMENT
|
||||
assert len(msg_event.media_urls) == 1
|
||||
assert os.path.exists(msg_event.media_urls[0])
|
||||
assert msg_event.media_types == ["application/pdf"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_txt_document_injects_content(self, adapter):
|
||||
"""A .txt file under 100KB should have its content injected into event text."""
|
||||
content = b"Hello from a text file"
|
||||
|
||||
with patch.object(adapter, "_download_slack_file_bytes", new_callable=AsyncMock) as dl:
|
||||
dl.return_value = content
|
||||
event = self._make_event(
|
||||
text="summarize this",
|
||||
files=[{
|
||||
"mimetype": "text/plain",
|
||||
"name": "notes.txt",
|
||||
"url_private_download": "https://files.slack.com/notes.txt",
|
||||
"size": len(content),
|
||||
}],
|
||||
)
|
||||
await adapter._handle_slack_message(event)
|
||||
|
||||
msg_event = adapter.handle_message.call_args[0][0]
|
||||
assert "Hello from a text file" in msg_event.text
|
||||
assert "[Content of notes.txt]" in msg_event.text
|
||||
assert "summarize this" in msg_event.text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_md_document_injects_content(self, adapter):
|
||||
"""A .md file under 100KB should have its content injected."""
|
||||
content = b"# Title\nSome markdown content"
|
||||
|
||||
with patch.object(adapter, "_download_slack_file_bytes", new_callable=AsyncMock) as dl:
|
||||
dl.return_value = content
|
||||
event = self._make_event(files=[{
|
||||
"mimetype": "text/markdown",
|
||||
"name": "readme.md",
|
||||
"url_private_download": "https://files.slack.com/readme.md",
|
||||
"size": len(content),
|
||||
}], text="")
|
||||
await adapter._handle_slack_message(event)
|
||||
|
||||
msg_event = adapter.handle_message.call_args[0][0]
|
||||
assert "# Title" in msg_event.text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_large_txt_not_injected(self, adapter):
|
||||
"""A .txt file over 100KB should be cached but NOT injected."""
|
||||
content = b"x" * (200 * 1024)
|
||||
|
||||
with patch.object(adapter, "_download_slack_file_bytes", new_callable=AsyncMock) as dl:
|
||||
dl.return_value = content
|
||||
event = self._make_event(files=[{
|
||||
"mimetype": "text/plain",
|
||||
"name": "big.txt",
|
||||
"url_private_download": "https://files.slack.com/big.txt",
|
||||
"size": len(content),
|
||||
}], text="")
|
||||
await adapter._handle_slack_message(event)
|
||||
|
||||
msg_event = adapter.handle_message.call_args[0][0]
|
||||
assert len(msg_event.media_urls) == 1
|
||||
assert "[Content of" not in (msg_event.text or "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unsupported_file_type_skipped(self, adapter):
|
||||
"""A .zip file should be silently skipped."""
|
||||
event = self._make_event(files=[{
|
||||
"mimetype": "application/zip",
|
||||
"name": "archive.zip",
|
||||
"url_private_download": "https://files.slack.com/archive.zip",
|
||||
"size": 1024,
|
||||
}])
|
||||
await adapter._handle_slack_message(event)
|
||||
|
||||
msg_event = adapter.handle_message.call_args[0][0]
|
||||
assert msg_event.message_type == MessageType.TEXT
|
||||
assert len(msg_event.media_urls) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oversized_document_skipped(self, adapter):
|
||||
"""A document over 20MB should be skipped."""
|
||||
event = self._make_event(files=[{
|
||||
"mimetype": "application/pdf",
|
||||
"name": "huge.pdf",
|
||||
"url_private_download": "https://files.slack.com/huge.pdf",
|
||||
"size": 25 * 1024 * 1024,
|
||||
}])
|
||||
await adapter._handle_slack_message(event)
|
||||
|
||||
msg_event = adapter.handle_message.call_args[0][0]
|
||||
assert len(msg_event.media_urls) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_download_error_handled(self, adapter):
|
||||
"""If document download fails, handler should not crash."""
|
||||
with patch.object(adapter, "_download_slack_file_bytes", new_callable=AsyncMock) as dl:
|
||||
dl.side_effect = RuntimeError("download failed")
|
||||
event = self._make_event(files=[{
|
||||
"mimetype": "application/pdf",
|
||||
"name": "report.pdf",
|
||||
"url_private_download": "https://files.slack.com/report.pdf",
|
||||
"size": 1024,
|
||||
}])
|
||||
await adapter._handle_slack_message(event)
|
||||
|
||||
# Handler should still be called (the exception is caught)
|
||||
adapter.handle_message.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_still_handled(self, adapter):
|
||||
"""Image attachments should still go through the image path, not document."""
|
||||
with patch.object(adapter, "_download_slack_file", new_callable=AsyncMock) as dl:
|
||||
dl.return_value = "/tmp/cached_image.jpg"
|
||||
event = self._make_event(files=[{
|
||||
"mimetype": "image/jpeg",
|
||||
"name": "photo.jpg",
|
||||
"url_private_download": "https://files.slack.com/photo.jpg",
|
||||
"size": 1024,
|
||||
}])
|
||||
await adapter._handle_slack_message(event)
|
||||
|
||||
msg_event = adapter.handle_message.call_args[0][0]
|
||||
assert msg_event.message_type == MessageType.PHOTO
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestMessageRouting
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMessageRouting:
|
||||
@pytest.mark.asyncio
|
||||
async def test_dm_processed_without_mention(self, adapter):
|
||||
"""DM messages should be processed without requiring a bot mention."""
|
||||
event = {
|
||||
"text": "hello",
|
||||
"user": "U_USER",
|
||||
"channel": "D123",
|
||||
"channel_type": "im",
|
||||
"ts": "1234567890.000001",
|
||||
}
|
||||
await adapter._handle_slack_message(event)
|
||||
adapter.handle_message.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_channel_message_requires_mention(self, adapter):
|
||||
"""Channel messages without a bot mention should be ignored."""
|
||||
event = {
|
||||
"text": "just talking",
|
||||
"user": "U_USER",
|
||||
"channel": "C123",
|
||||
"channel_type": "channel",
|
||||
"ts": "1234567890.000001",
|
||||
}
|
||||
await adapter._handle_slack_message(event)
|
||||
adapter.handle_message.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_channel_mention_strips_bot_id(self, adapter):
|
||||
"""When mentioned in a channel, the bot mention should be stripped."""
|
||||
event = {
|
||||
"text": "<@U_BOT> what's the weather?",
|
||||
"user": "U_USER",
|
||||
"channel": "C123",
|
||||
"channel_type": "channel",
|
||||
"ts": "1234567890.000001",
|
||||
}
|
||||
await adapter._handle_slack_message(event)
|
||||
msg_event = adapter.handle_message.call_args[0][0]
|
||||
assert msg_event.text == "what's the weather?"
|
||||
assert "<@U_BOT>" not in msg_event.text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bot_messages_ignored(self, adapter):
|
||||
"""Messages from bots should be ignored."""
|
||||
event = {
|
||||
"text": "bot response",
|
||||
"bot_id": "B_OTHER",
|
||||
"channel": "C123",
|
||||
"channel_type": "im",
|
||||
"ts": "1234567890.000001",
|
||||
}
|
||||
await adapter._handle_slack_message(event)
|
||||
adapter.handle_message.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_edits_ignored(self, adapter):
|
||||
"""Message edits should be ignored."""
|
||||
event = {
|
||||
"text": "edited message",
|
||||
"user": "U_USER",
|
||||
"channel": "C123",
|
||||
"channel_type": "im",
|
||||
"ts": "1234567890.000001",
|
||||
"subtype": "message_changed",
|
||||
}
|
||||
await adapter._handle_slack_message(event)
|
||||
adapter.handle_message.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestSendTyping — assistant.threads.setStatus
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSendTyping:
|
||||
"""Test typing indicator via assistant.threads.setStatus."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sets_status_in_thread(self, adapter):
|
||||
adapter._app.client.assistant_threads_setStatus = AsyncMock()
|
||||
await adapter.send_typing("C123", metadata={"thread_id": "parent_ts"})
|
||||
adapter._app.client.assistant_threads_setStatus.assert_called_once_with(
|
||||
channel_id="C123",
|
||||
thread_ts="parent_ts",
|
||||
status="is thinking...",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_noop_without_thread(self, adapter):
|
||||
adapter._app.client.assistant_threads_setStatus = AsyncMock()
|
||||
await adapter.send_typing("C123")
|
||||
adapter._app.client.assistant_threads_setStatus.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_missing_scope_gracefully(self, adapter):
|
||||
adapter._app.client.assistant_threads_setStatus = AsyncMock(
|
||||
side_effect=Exception("missing_scope")
|
||||
)
|
||||
# Should not raise
|
||||
await adapter.send_typing("C123", metadata={"thread_id": "ts1"})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_uses_thread_ts_fallback(self, adapter):
|
||||
adapter._app.client.assistant_threads_setStatus = AsyncMock()
|
||||
await adapter.send_typing("C123", metadata={"thread_ts": "fallback_ts"})
|
||||
adapter._app.client.assistant_threads_setStatus.assert_called_once_with(
|
||||
channel_id="C123",
|
||||
thread_ts="fallback_ts",
|
||||
status="is thinking...",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestFormatMessage — Markdown → mrkdwn conversion
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFormatMessage:
|
||||
"""Test markdown to Slack mrkdwn conversion."""
|
||||
|
||||
def test_bold_conversion(self, adapter):
|
||||
assert adapter.format_message("**hello**") == "*hello*"
|
||||
|
||||
def test_italic_asterisk_conversion(self, adapter):
|
||||
assert adapter.format_message("*hello*") == "_hello_"
|
||||
|
||||
def test_italic_underscore_preserved(self, adapter):
|
||||
assert adapter.format_message("_hello_") == "_hello_"
|
||||
|
||||
def test_header_to_bold(self, adapter):
|
||||
assert adapter.format_message("## Section Title") == "*Section Title*"
|
||||
|
||||
def test_header_with_bold_content(self, adapter):
|
||||
# **bold** inside a header should not double-wrap
|
||||
assert adapter.format_message("## **Title**") == "*Title*"
|
||||
|
||||
def test_link_conversion(self, adapter):
|
||||
result = adapter.format_message("[click here](https://example.com)")
|
||||
assert result == "<https://example.com|click here>"
|
||||
|
||||
def test_strikethrough(self, adapter):
|
||||
assert adapter.format_message("~~deleted~~") == "~deleted~"
|
||||
|
||||
def test_code_block_preserved(self, adapter):
|
||||
code = "```python\nx = **not bold**\n```"
|
||||
assert adapter.format_message(code) == code
|
||||
|
||||
def test_inline_code_preserved(self, adapter):
|
||||
text = "Use `**raw**` syntax"
|
||||
assert adapter.format_message(text) == "Use `**raw**` syntax"
|
||||
|
||||
def test_mixed_content(self, adapter):
|
||||
text = "**Bold** and *italic* with `code`"
|
||||
result = adapter.format_message(text)
|
||||
assert "*Bold*" in result
|
||||
assert "_italic_" in result
|
||||
assert "`code`" in result
|
||||
|
||||
def test_empty_string(self, adapter):
|
||||
assert adapter.format_message("") == ""
|
||||
|
||||
def test_none_passthrough(self, adapter):
|
||||
assert adapter.format_message(None) is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestReactions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestReactions:
|
||||
"""Test emoji reaction methods."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_reaction_calls_api(self, adapter):
|
||||
adapter._app.client.reactions_add = AsyncMock()
|
||||
result = await adapter._add_reaction("C123", "ts1", "eyes")
|
||||
assert result is True
|
||||
adapter._app.client.reactions_add.assert_called_once_with(
|
||||
channel="C123", timestamp="ts1", name="eyes"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_reaction_handles_error(self, adapter):
|
||||
adapter._app.client.reactions_add = AsyncMock(side_effect=Exception("already_reacted"))
|
||||
result = await adapter._add_reaction("C123", "ts1", "eyes")
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_reaction_calls_api(self, adapter):
|
||||
adapter._app.client.reactions_remove = AsyncMock()
|
||||
result = await adapter._remove_reaction("C123", "ts1", "eyes")
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reactions_in_message_flow(self, adapter):
|
||||
"""Reactions should be added on receipt and swapped on completion."""
|
||||
adapter._app.client.reactions_add = AsyncMock()
|
||||
adapter._app.client.reactions_remove = AsyncMock()
|
||||
adapter._app.client.users_info = AsyncMock(return_value={
|
||||
"user": {"profile": {"display_name": "Tyler"}}
|
||||
})
|
||||
|
||||
event = {
|
||||
"text": "hello",
|
||||
"user": "U_USER",
|
||||
"channel": "C123",
|
||||
"channel_type": "im",
|
||||
"ts": "1234567890.000001",
|
||||
}
|
||||
await adapter._handle_slack_message(event)
|
||||
|
||||
# Should have added 👀, then removed 👀, then added ✅
|
||||
add_calls = adapter._app.client.reactions_add.call_args_list
|
||||
remove_calls = adapter._app.client.reactions_remove.call_args_list
|
||||
assert len(add_calls) == 2
|
||||
assert add_calls[0].kwargs["name"] == "eyes"
|
||||
assert add_calls[1].kwargs["name"] == "white_check_mark"
|
||||
assert len(remove_calls) == 1
|
||||
assert remove_calls[0].kwargs["name"] == "eyes"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestUserNameResolution
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestUserNameResolution:
|
||||
"""Test user identity resolution."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolves_display_name(self, adapter):
|
||||
adapter._app.client.users_info = AsyncMock(return_value={
|
||||
"user": {"profile": {"display_name": "Tyler", "real_name": "Tyler B"}}
|
||||
})
|
||||
name = await adapter._resolve_user_name("U123")
|
||||
assert name == "Tyler"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_falls_back_to_real_name(self, adapter):
|
||||
adapter._app.client.users_info = AsyncMock(return_value={
|
||||
"user": {"profile": {"display_name": "", "real_name": "Tyler B"}}
|
||||
})
|
||||
name = await adapter._resolve_user_name("U123")
|
||||
assert name == "Tyler B"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_caches_result(self, adapter):
|
||||
adapter._app.client.users_info = AsyncMock(return_value={
|
||||
"user": {"profile": {"display_name": "Tyler"}}
|
||||
})
|
||||
await adapter._resolve_user_name("U123")
|
||||
await adapter._resolve_user_name("U123")
|
||||
# Only one API call despite two lookups
|
||||
assert adapter._app.client.users_info.call_count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_api_error(self, adapter):
|
||||
adapter._app.client.users_info = AsyncMock(side_effect=Exception("rate limited"))
|
||||
name = await adapter._resolve_user_name("U123")
|
||||
assert name == "U123" # Falls back to user_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_name_in_message_source(self, adapter):
|
||||
"""Message source should include resolved user name."""
|
||||
adapter._app.client.users_info = AsyncMock(return_value={
|
||||
"user": {"profile": {"display_name": "Tyler"}}
|
||||
})
|
||||
adapter._app.client.reactions_add = AsyncMock()
|
||||
adapter._app.client.reactions_remove = AsyncMock()
|
||||
|
||||
event = {
|
||||
"text": "hello",
|
||||
"user": "U_USER",
|
||||
"channel": "C123",
|
||||
"channel_type": "im",
|
||||
"ts": "1234567890.000001",
|
||||
}
|
||||
await adapter._handle_slack_message(event)
|
||||
|
||||
# Check the source in the MessageEvent passed to handle_message
|
||||
msg_event = adapter.handle_message.call_args[0][0]
|
||||
assert msg_event.source.user_name == "Tyler"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestSlashCommands — expanded command set
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSlashCommands:
|
||||
"""Test slash command routing."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compact_maps_to_compress(self, adapter):
|
||||
command = {"text": "compact", "user_id": "U1", "channel_id": "C1"}
|
||||
await adapter._handle_slash_command(command)
|
||||
msg = adapter.handle_message.call_args[0][0]
|
||||
assert msg.text == "/compress"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_command(self, adapter):
|
||||
command = {"text": "resume my session", "user_id": "U1", "channel_id": "C1"}
|
||||
await adapter._handle_slash_command(command)
|
||||
msg = adapter.handle_message.call_args[0][0]
|
||||
assert msg.text == "/resume my session"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_background_command(self, adapter):
|
||||
command = {"text": "background run tests", "user_id": "U1", "channel_id": "C1"}
|
||||
await adapter._handle_slash_command(command)
|
||||
msg = adapter.handle_message.call_args[0][0]
|
||||
assert msg.text == "/background run tests"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_usage_command(self, adapter):
|
||||
command = {"text": "usage", "user_id": "U1", "channel_id": "C1"}
|
||||
await adapter._handle_slash_command(command)
|
||||
msg = adapter.handle_message.call_args[0][0]
|
||||
assert msg.text == "/usage"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reasoning_command(self, adapter):
|
||||
command = {"text": "reasoning", "user_id": "U1", "channel_id": "C1"}
|
||||
await adapter._handle_slash_command(command)
|
||||
msg = adapter.handle_message.call_args[0][0]
|
||||
assert msg.text == "/reasoning"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestMessageSplitting
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMessageSplitting:
|
||||
"""Test that long messages are split before sending."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_long_message_split_into_chunks(self, adapter):
|
||||
"""Messages over MAX_MESSAGE_LENGTH should be split."""
|
||||
long_text = "x" * 45000 # Over Slack's 40k API limit
|
||||
adapter._app.client.chat_postMessage = AsyncMock(
|
||||
return_value={"ts": "ts1"}
|
||||
)
|
||||
await adapter.send("C123", long_text)
|
||||
# Should have been called multiple times
|
||||
assert adapter._app.client.chat_postMessage.call_count >= 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_short_message_single_send(self, adapter):
|
||||
"""Short messages should be sent in one call."""
|
||||
adapter._app.client.chat_postMessage = AsyncMock(
|
||||
return_value={"ts": "ts1"}
|
||||
)
|
||||
await adapter.send("C123", "hello world")
|
||||
assert adapter._app.client.chat_postMessage.call_count == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestReplyBroadcast
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestReplyBroadcast:
|
||||
"""Test reply_broadcast config option."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_disabled_by_default(self, adapter):
|
||||
adapter._app.client.chat_postMessage = AsyncMock(
|
||||
return_value={"ts": "ts1"}
|
||||
)
|
||||
await adapter.send("C123", "hi", metadata={"thread_id": "parent_ts"})
|
||||
kwargs = adapter._app.client.chat_postMessage.call_args.kwargs
|
||||
assert "reply_broadcast" not in kwargs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_enabled_via_config(self, adapter):
|
||||
adapter.config.extra["reply_broadcast"] = True
|
||||
adapter._app.client.chat_postMessage = AsyncMock(
|
||||
return_value={"ts": "ts1"}
|
||||
)
|
||||
await adapter.send("C123", "hi", metadata={"thread_id": "parent_ts"})
|
||||
kwargs = adapter._app.client.chat_postMessage.call_args.kwargs
|
||||
assert kwargs.get("reply_broadcast") is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestFallbackPreservesThreadContext
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestFallbackPreservesThreadContext:
|
||||
"""Bug fix: file upload fallbacks lost thread context (metadata) when
|
||||
calling super() without metadata, causing replies to appear outside
|
||||
the thread."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_image_file_fallback_preserves_thread(self, adapter, tmp_path):
|
||||
test_file = tmp_path / "photo.jpg"
|
||||
test_file.write_bytes(b"\xff\xd8\xff\xe0")
|
||||
|
||||
adapter._app.client.files_upload_v2 = AsyncMock(
|
||||
side_effect=Exception("upload failed")
|
||||
)
|
||||
adapter._app.client.chat_postMessage = AsyncMock(
|
||||
return_value={"ts": "msg_ts"}
|
||||
)
|
||||
|
||||
metadata = {"thread_id": "parent_ts_123"}
|
||||
await adapter.send_image_file(
|
||||
chat_id="C123",
|
||||
image_path=str(test_file),
|
||||
caption="test image",
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
call_kwargs = adapter._app.client.chat_postMessage.call_args.kwargs
|
||||
assert call_kwargs.get("thread_ts") == "parent_ts_123"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_video_fallback_preserves_thread(self, adapter, tmp_path):
|
||||
test_file = tmp_path / "clip.mp4"
|
||||
test_file.write_bytes(b"\x00\x00\x00\x1c")
|
||||
|
||||
adapter._app.client.files_upload_v2 = AsyncMock(
|
||||
side_effect=Exception("upload failed")
|
||||
)
|
||||
adapter._app.client.chat_postMessage = AsyncMock(
|
||||
return_value={"ts": "msg_ts"}
|
||||
)
|
||||
|
||||
metadata = {"thread_id": "parent_ts_456"}
|
||||
await adapter.send_video(
|
||||
chat_id="C123",
|
||||
video_path=str(test_file),
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
call_kwargs = adapter._app.client.chat_postMessage.call_args.kwargs
|
||||
assert call_kwargs.get("thread_ts") == "parent_ts_456"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_document_fallback_preserves_thread(self, adapter, tmp_path):
|
||||
test_file = tmp_path / "report.pdf"
|
||||
test_file.write_bytes(b"%PDF-1.4")
|
||||
|
||||
adapter._app.client.files_upload_v2 = AsyncMock(
|
||||
side_effect=Exception("upload failed")
|
||||
)
|
||||
adapter._app.client.chat_postMessage = AsyncMock(
|
||||
return_value={"ts": "msg_ts"}
|
||||
)
|
||||
|
||||
metadata = {"thread_id": "parent_ts_789"}
|
||||
await adapter.send_document(
|
||||
chat_id="C123",
|
||||
file_path=str(test_file),
|
||||
caption="report",
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
call_kwargs = adapter._app.client.chat_postMessage.call_args.kwargs
|
||||
assert call_kwargs.get("thread_ts") == "parent_ts_789"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_image_file_fallback_includes_caption(self, adapter, tmp_path):
|
||||
test_file = tmp_path / "photo.jpg"
|
||||
test_file.write_bytes(b"\xff\xd8\xff\xe0")
|
||||
|
||||
adapter._app.client.files_upload_v2 = AsyncMock(
|
||||
side_effect=Exception("upload failed")
|
||||
)
|
||||
adapter._app.client.chat_postMessage = AsyncMock(
|
||||
return_value={"ts": "msg_ts"}
|
||||
)
|
||||
|
||||
await adapter.send_image_file(
|
||||
chat_id="C123",
|
||||
image_path=str(test_file),
|
||||
caption="important screenshot",
|
||||
)
|
||||
|
||||
call_kwargs = adapter._app.client.chat_postMessage.call_args.kwargs
|
||||
assert "important screenshot" in call_kwargs["text"]
|
||||
215
hermes_code/tests/gateway/test_sms.py
Normal file
215
hermes_code/tests/gateway/test_sms.py
Normal file
|
|
@ -0,0 +1,215 @@
|
|||
"""Tests for SMS (Twilio) platform integration.
|
||||
|
||||
Covers config loading, format/truncate, echo prevention,
|
||||
requirements check, and toolset verification.
|
||||
"""
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform, PlatformConfig, HomeChannel
|
||||
|
||||
|
||||
# ── Config loading ──────────────────────────────────────────────────
|
||||
|
||||
class TestSmsConfigLoading:
|
||||
"""Verify _apply_env_overrides wires SMS correctly."""
|
||||
|
||||
def test_sms_platform_enum_exists(self):
|
||||
assert Platform.SMS.value == "sms"
|
||||
|
||||
def test_env_overrides_create_sms_config(self):
|
||||
from gateway.config import load_gateway_config
|
||||
|
||||
env = {
|
||||
"TWILIO_ACCOUNT_SID": "ACtest123",
|
||||
"TWILIO_AUTH_TOKEN": "token_abc",
|
||||
"TWILIO_PHONE_NUMBER": "+15551234567",
|
||||
}
|
||||
with patch.dict(os.environ, env, clear=False):
|
||||
config = load_gateway_config()
|
||||
assert Platform.SMS in config.platforms
|
||||
pc = config.platforms[Platform.SMS]
|
||||
assert pc.enabled is True
|
||||
assert pc.api_key == "token_abc"
|
||||
|
||||
def test_env_overrides_set_home_channel(self):
|
||||
from gateway.config import load_gateway_config
|
||||
|
||||
env = {
|
||||
"TWILIO_ACCOUNT_SID": "ACtest123",
|
||||
"TWILIO_AUTH_TOKEN": "token_abc",
|
||||
"TWILIO_PHONE_NUMBER": "+15551234567",
|
||||
"SMS_HOME_CHANNEL": "+15559876543",
|
||||
"SMS_HOME_CHANNEL_NAME": "My Phone",
|
||||
}
|
||||
with patch.dict(os.environ, env, clear=False):
|
||||
config = load_gateway_config()
|
||||
hc = config.platforms[Platform.SMS].home_channel
|
||||
assert hc is not None
|
||||
assert hc.chat_id == "+15559876543"
|
||||
assert hc.name == "My Phone"
|
||||
assert hc.platform == Platform.SMS
|
||||
|
||||
def test_sms_in_connected_platforms(self):
|
||||
from gateway.config import load_gateway_config
|
||||
|
||||
env = {
|
||||
"TWILIO_ACCOUNT_SID": "ACtest123",
|
||||
"TWILIO_AUTH_TOKEN": "token_abc",
|
||||
}
|
||||
with patch.dict(os.environ, env, clear=False):
|
||||
config = load_gateway_config()
|
||||
connected = config.get_connected_platforms()
|
||||
assert Platform.SMS in connected
|
||||
|
||||
|
||||
# ── Format / truncate ───────────────────────────────────────────────
|
||||
|
||||
class TestSmsFormatAndTruncate:
|
||||
"""Test SmsAdapter.format_message strips markdown."""
|
||||
|
||||
def _make_adapter(self):
|
||||
from gateway.platforms.sms import SmsAdapter
|
||||
|
||||
env = {
|
||||
"TWILIO_ACCOUNT_SID": "ACtest",
|
||||
"TWILIO_AUTH_TOKEN": "tok",
|
||||
"TWILIO_PHONE_NUMBER": "+15550001111",
|
||||
}
|
||||
with patch.dict(os.environ, env):
|
||||
pc = PlatformConfig(enabled=True, api_key="tok")
|
||||
adapter = object.__new__(SmsAdapter)
|
||||
adapter.config = pc
|
||||
adapter._platform = Platform.SMS
|
||||
adapter._account_sid = "ACtest"
|
||||
adapter._auth_token = "tok"
|
||||
adapter._from_number = "+15550001111"
|
||||
return adapter
|
||||
|
||||
def test_strips_bold(self):
|
||||
adapter = self._make_adapter()
|
||||
assert adapter.format_message("**hello**") == "hello"
|
||||
|
||||
def test_strips_italic(self):
|
||||
adapter = self._make_adapter()
|
||||
assert adapter.format_message("*world*") == "world"
|
||||
|
||||
def test_strips_code_blocks(self):
|
||||
adapter = self._make_adapter()
|
||||
result = adapter.format_message("```python\nprint('hi')\n```")
|
||||
assert "```" not in result
|
||||
assert "print('hi')" in result
|
||||
|
||||
def test_strips_inline_code(self):
|
||||
adapter = self._make_adapter()
|
||||
assert adapter.format_message("`code`") == "code"
|
||||
|
||||
def test_strips_headers(self):
|
||||
adapter = self._make_adapter()
|
||||
assert adapter.format_message("## Title") == "Title"
|
||||
|
||||
def test_strips_links(self):
|
||||
adapter = self._make_adapter()
|
||||
assert adapter.format_message("[click](https://example.com)") == "click"
|
||||
|
||||
def test_collapses_newlines(self):
|
||||
adapter = self._make_adapter()
|
||||
result = adapter.format_message("a\n\n\n\nb")
|
||||
assert result == "a\n\nb"
|
||||
|
||||
|
||||
# ── Echo prevention ────────────────────────────────────────────────
|
||||
|
||||
class TestSmsEchoPrevention:
|
||||
"""Adapter should ignore messages from its own number."""
|
||||
|
||||
def test_own_number_detection(self):
|
||||
"""The adapter stores _from_number for echo prevention."""
|
||||
from gateway.platforms.sms import SmsAdapter
|
||||
|
||||
env = {
|
||||
"TWILIO_ACCOUNT_SID": "ACtest",
|
||||
"TWILIO_AUTH_TOKEN": "tok",
|
||||
"TWILIO_PHONE_NUMBER": "+15550001111",
|
||||
}
|
||||
with patch.dict(os.environ, env):
|
||||
pc = PlatformConfig(enabled=True, api_key="tok")
|
||||
adapter = SmsAdapter(pc)
|
||||
assert adapter._from_number == "+15550001111"
|
||||
|
||||
|
||||
# ── Requirements check ─────────────────────────────────────────────
|
||||
|
||||
class TestSmsRequirements:
|
||||
def test_check_sms_requirements_missing_sid(self):
|
||||
from gateway.platforms.sms import check_sms_requirements
|
||||
|
||||
env = {"TWILIO_AUTH_TOKEN": "tok"}
|
||||
with patch.dict(os.environ, env, clear=True):
|
||||
assert check_sms_requirements() is False
|
||||
|
||||
def test_check_sms_requirements_missing_token(self):
|
||||
from gateway.platforms.sms import check_sms_requirements
|
||||
|
||||
env = {"TWILIO_ACCOUNT_SID": "ACtest"}
|
||||
with patch.dict(os.environ, env, clear=True):
|
||||
assert check_sms_requirements() is False
|
||||
|
||||
def test_check_sms_requirements_both_set(self):
|
||||
from gateway.platforms.sms import check_sms_requirements
|
||||
|
||||
env = {
|
||||
"TWILIO_ACCOUNT_SID": "ACtest",
|
||||
"TWILIO_AUTH_TOKEN": "tok",
|
||||
}
|
||||
with patch.dict(os.environ, env, clear=False):
|
||||
# Only returns True if aiohttp is also importable
|
||||
result = check_sms_requirements()
|
||||
try:
|
||||
import aiohttp # noqa: F401
|
||||
assert result is True
|
||||
except ImportError:
|
||||
assert result is False
|
||||
|
||||
|
||||
# ── Toolset verification ───────────────────────────────────────────
|
||||
|
||||
class TestSmsToolset:
|
||||
def test_hermes_sms_toolset_exists(self):
|
||||
from toolsets import get_toolset
|
||||
|
||||
ts = get_toolset("hermes-sms")
|
||||
assert ts is not None
|
||||
assert "tools" in ts
|
||||
|
||||
def test_hermes_sms_in_gateway_includes(self):
|
||||
from toolsets import get_toolset
|
||||
|
||||
gw = get_toolset("hermes-gateway")
|
||||
assert gw is not None
|
||||
assert "hermes-sms" in gw["includes"]
|
||||
|
||||
def test_sms_platform_hint_exists(self):
|
||||
from agent.prompt_builder import PLATFORM_HINTS
|
||||
|
||||
assert "sms" in PLATFORM_HINTS
|
||||
assert "concise" in PLATFORM_HINTS["sms"].lower()
|
||||
|
||||
def test_sms_in_scheduler_platform_map(self):
|
||||
"""Verify cron scheduler recognizes 'sms' as a valid platform."""
|
||||
# Just check the Platform enum has SMS — the scheduler imports it dynamically
|
||||
assert Platform.SMS.value == "sms"
|
||||
|
||||
def test_sms_in_send_message_platform_map(self):
|
||||
"""Verify send_message_tool recognizes 'sms'."""
|
||||
# The platform_map is built inside _handle_send; verify SMS enum exists
|
||||
assert hasattr(Platform, "SMS")
|
||||
|
||||
def test_sms_in_cronjob_deliver_description(self):
|
||||
"""Verify cronjob_tools mentions sms in deliver description."""
|
||||
from tools.cronjob_tools import CRONJOB_SCHEMA
|
||||
deliver_desc = CRONJOB_SCHEMA["parameters"]["properties"]["deliver"]["description"]
|
||||
assert "sms" in deliver_desc.lower()
|
||||
81
hermes_code/tests/gateway/test_ssl_certs.py
Normal file
81
hermes_code/tests/gateway/test_ssl_certs.py
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
"""Tests for SSL certificate auto-detection in gateway/run.py."""
|
||||
|
||||
import importlib
|
||||
import os
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
|
||||
def _load_ensure_ssl():
|
||||
"""Import _ensure_ssl_certs fresh (gateway/run.py has heavy deps, so we
|
||||
extract just the function source to avoid importing the whole gateway)."""
|
||||
# We can test via the actual module since conftest isolates HERMES_HOME,
|
||||
# but we need to be careful about side effects. Instead, replicate the
|
||||
# logic in a controlled way.
|
||||
from types import ModuleType
|
||||
import textwrap, ssl as _ssl # noqa: F401
|
||||
|
||||
code = textwrap.dedent("""\
|
||||
import os, ssl
|
||||
|
||||
def _ensure_ssl_certs():
|
||||
if "SSL_CERT_FILE" in os.environ:
|
||||
return
|
||||
paths = ssl.get_default_verify_paths()
|
||||
for candidate in (paths.cafile, paths.openssl_cafile):
|
||||
if candidate and os.path.exists(candidate):
|
||||
os.environ["SSL_CERT_FILE"] = candidate
|
||||
return
|
||||
try:
|
||||
import certifi
|
||||
os.environ["SSL_CERT_FILE"] = certifi.where()
|
||||
return
|
||||
except ImportError:
|
||||
pass
|
||||
for candidate in (
|
||||
"/etc/ssl/certs/ca-certificates.crt",
|
||||
"/etc/ssl/cert.pem",
|
||||
):
|
||||
if os.path.exists(candidate):
|
||||
os.environ["SSL_CERT_FILE"] = candidate
|
||||
return
|
||||
""")
|
||||
mod = ModuleType("_ssl_helper")
|
||||
exec(code, mod.__dict__)
|
||||
return mod._ensure_ssl_certs
|
||||
|
||||
|
||||
class TestEnsureSslCerts:
|
||||
def test_respects_existing_env_var(self):
|
||||
fn = _load_ensure_ssl()
|
||||
with patch.dict(os.environ, {"SSL_CERT_FILE": "/custom/ca.pem"}):
|
||||
fn()
|
||||
assert os.environ["SSL_CERT_FILE"] == "/custom/ca.pem"
|
||||
|
||||
def test_sets_from_ssl_default_paths(self, tmp_path):
|
||||
fn = _load_ensure_ssl()
|
||||
cert = tmp_path / "ca.crt"
|
||||
cert.write_text("FAKE CERT")
|
||||
|
||||
mock_paths = MagicMock()
|
||||
mock_paths.cafile = str(cert)
|
||||
mock_paths.openssl_cafile = None
|
||||
|
||||
env = {k: v for k, v in os.environ.items() if k != "SSL_CERT_FILE"}
|
||||
with patch.dict(os.environ, env, clear=True), \
|
||||
patch("ssl.get_default_verify_paths", return_value=mock_paths):
|
||||
fn()
|
||||
assert os.environ.get("SSL_CERT_FILE") == str(cert)
|
||||
|
||||
def test_no_op_when_nothing_found(self):
|
||||
fn = _load_ensure_ssl()
|
||||
mock_paths = MagicMock()
|
||||
mock_paths.cafile = None
|
||||
mock_paths.openssl_cafile = None
|
||||
|
||||
env = {k: v for k, v in os.environ.items() if k != "SSL_CERT_FILE"}
|
||||
with patch.dict(os.environ, env, clear=True), \
|
||||
patch("ssl.get_default_verify_paths", return_value=mock_paths), \
|
||||
patch("os.path.exists", return_value=False), \
|
||||
patch.dict("sys.modules", {"certifi": None}):
|
||||
fn()
|
||||
assert "SSL_CERT_FILE" not in os.environ
|
||||
157
hermes_code/tests/gateway/test_status.py
Normal file
157
hermes_code/tests/gateway/test_status.py
Normal file
|
|
@ -0,0 +1,157 @@
|
|||
"""Tests for gateway runtime status tracking."""
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
from gateway import status
|
||||
|
||||
|
||||
class TestGatewayPidState:
|
||||
def test_write_pid_file_records_gateway_metadata(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
status.write_pid_file()
|
||||
|
||||
payload = json.loads((tmp_path / "gateway.pid").read_text())
|
||||
assert payload["pid"] == os.getpid()
|
||||
assert payload["kind"] == "hermes-gateway"
|
||||
assert isinstance(payload["argv"], list)
|
||||
assert payload["argv"]
|
||||
|
||||
def test_get_running_pid_rejects_live_non_gateway_pid(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
pid_path = tmp_path / "gateway.pid"
|
||||
pid_path.write_text(str(os.getpid()))
|
||||
|
||||
assert status.get_running_pid() is None
|
||||
assert not pid_path.exists()
|
||||
|
||||
def test_get_running_pid_accepts_gateway_metadata_when_cmdline_unavailable(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
pid_path = tmp_path / "gateway.pid"
|
||||
pid_path.write_text(json.dumps({
|
||||
"pid": os.getpid(),
|
||||
"kind": "hermes-gateway",
|
||||
"argv": ["python", "-m", "hermes_cli.main", "gateway"],
|
||||
"start_time": 123,
|
||||
}))
|
||||
|
||||
monkeypatch.setattr(status.os, "kill", lambda pid, sig: None)
|
||||
monkeypatch.setattr(status, "_get_process_start_time", lambda pid: 123)
|
||||
monkeypatch.setattr(status, "_read_process_cmdline", lambda pid: None)
|
||||
|
||||
assert status.get_running_pid() == os.getpid()
|
||||
|
||||
def test_get_running_pid_accepts_script_style_gateway_cmdline(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
pid_path = tmp_path / "gateway.pid"
|
||||
pid_path.write_text(json.dumps({
|
||||
"pid": os.getpid(),
|
||||
"kind": "hermes-gateway",
|
||||
"argv": ["/venv/bin/python", "/repo/hermes_cli/main.py", "gateway", "run", "--replace"],
|
||||
"start_time": 123,
|
||||
}))
|
||||
|
||||
monkeypatch.setattr(status.os, "kill", lambda pid, sig: None)
|
||||
monkeypatch.setattr(status, "_get_process_start_time", lambda pid: 123)
|
||||
monkeypatch.setattr(
|
||||
status,
|
||||
"_read_process_cmdline",
|
||||
lambda pid: "/venv/bin/python /repo/hermes_cli/main.py gateway run --replace",
|
||||
)
|
||||
|
||||
assert status.get_running_pid() == os.getpid()
|
||||
|
||||
|
||||
class TestGatewayRuntimeStatus:
|
||||
def test_write_runtime_status_overwrites_stale_pid_on_restart(self, tmp_path, monkeypatch):
|
||||
"""Regression: setdefault() preserved stale PID from previous process (#1631)."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
# Simulate a previous gateway run that left a state file with a stale PID
|
||||
state_path = tmp_path / "gateway_state.json"
|
||||
state_path.write_text(json.dumps({
|
||||
"pid": 99999,
|
||||
"start_time": 1000.0,
|
||||
"kind": "hermes-gateway",
|
||||
"platforms": {},
|
||||
"updated_at": "2025-01-01T00:00:00Z",
|
||||
}))
|
||||
|
||||
status.write_runtime_status(gateway_state="running")
|
||||
|
||||
payload = status.read_runtime_status()
|
||||
assert payload["pid"] == os.getpid(), "PID should be overwritten, not preserved via setdefault"
|
||||
assert payload["start_time"] != 1000.0, "start_time should be overwritten on restart"
|
||||
|
||||
def test_write_runtime_status_records_platform_failure(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
status.write_runtime_status(
|
||||
gateway_state="startup_failed",
|
||||
exit_reason="telegram conflict",
|
||||
platform="telegram",
|
||||
platform_state="fatal",
|
||||
error_code="telegram_polling_conflict",
|
||||
error_message="another poller is active",
|
||||
)
|
||||
|
||||
payload = status.read_runtime_status()
|
||||
assert payload["gateway_state"] == "startup_failed"
|
||||
assert payload["exit_reason"] == "telegram conflict"
|
||||
assert payload["platforms"]["telegram"]["state"] == "fatal"
|
||||
assert payload["platforms"]["telegram"]["error_code"] == "telegram_polling_conflict"
|
||||
assert payload["platforms"]["telegram"]["error_message"] == "another poller is active"
|
||||
|
||||
|
||||
class TestScopedLocks:
|
||||
def test_acquire_scoped_lock_rejects_live_other_process(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_GATEWAY_LOCK_DIR", str(tmp_path / "locks"))
|
||||
lock_path = tmp_path / "locks" / "telegram-bot-token-2bb80d537b1da3e3.lock"
|
||||
lock_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
lock_path.write_text(json.dumps({
|
||||
"pid": 99999,
|
||||
"start_time": 123,
|
||||
"kind": "hermes-gateway",
|
||||
}))
|
||||
|
||||
monkeypatch.setattr(status.os, "kill", lambda pid, sig: None)
|
||||
monkeypatch.setattr(status, "_get_process_start_time", lambda pid: 123)
|
||||
|
||||
acquired, existing = status.acquire_scoped_lock("telegram-bot-token", "secret", metadata={"platform": "telegram"})
|
||||
|
||||
assert acquired is False
|
||||
assert existing["pid"] == 99999
|
||||
|
||||
def test_acquire_scoped_lock_replaces_stale_record(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_GATEWAY_LOCK_DIR", str(tmp_path / "locks"))
|
||||
lock_path = tmp_path / "locks" / "telegram-bot-token-2bb80d537b1da3e3.lock"
|
||||
lock_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
lock_path.write_text(json.dumps({
|
||||
"pid": 99999,
|
||||
"start_time": 123,
|
||||
"kind": "hermes-gateway",
|
||||
}))
|
||||
|
||||
def fake_kill(pid, sig):
|
||||
raise ProcessLookupError
|
||||
|
||||
monkeypatch.setattr(status.os, "kill", fake_kill)
|
||||
|
||||
acquired, existing = status.acquire_scoped_lock("telegram-bot-token", "secret", metadata={"platform": "telegram"})
|
||||
|
||||
assert acquired is True
|
||||
payload = json.loads(lock_path.read_text())
|
||||
assert payload["pid"] == os.getpid()
|
||||
assert payload["metadata"]["platform"] == "telegram"
|
||||
|
||||
def test_release_scoped_lock_only_removes_current_owner(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_GATEWAY_LOCK_DIR", str(tmp_path / "locks"))
|
||||
|
||||
acquired, _ = status.acquire_scoped_lock("telegram-bot-token", "secret", metadata={"platform": "telegram"})
|
||||
assert acquired is True
|
||||
lock_path = tmp_path / "locks" / "telegram-bot-token-2bb80d537b1da3e3.lock"
|
||||
assert lock_path.exists()
|
||||
|
||||
status.release_scoped_lock("telegram-bot-token", "secret")
|
||||
assert not lock_path.exists()
|
||||
140
hermes_code/tests/gateway/test_status_command.py
Normal file
140
hermes_code/tests/gateway/test_status_command.py
Normal file
|
|
@ -0,0 +1,140 @@
|
|||
"""Tests for gateway /status behavior and token persistence."""
|
||||
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import MessageEvent
|
||||
from gateway.session import SessionEntry, SessionSource, build_session_key
|
||||
|
||||
|
||||
def _make_source() -> SessionSource:
|
||||
return SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
user_id="u1",
|
||||
chat_id="c1",
|
||||
user_name="tester",
|
||||
chat_type="dm",
|
||||
)
|
||||
|
||||
|
||||
def _make_event(text: str) -> MessageEvent:
|
||||
return MessageEvent(
|
||||
text=text,
|
||||
source=_make_source(),
|
||||
message_id="m1",
|
||||
)
|
||||
|
||||
|
||||
def _make_runner(session_entry: SessionEntry):
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.config = GatewayConfig(
|
||||
platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")}
|
||||
)
|
||||
adapter = MagicMock()
|
||||
adapter.send = AsyncMock()
|
||||
runner.adapters = {Platform.TELEGRAM: adapter}
|
||||
runner._voice_mode = {}
|
||||
runner.hooks = SimpleNamespace(emit=AsyncMock(), loaded_hooks=False)
|
||||
runner.session_store = MagicMock()
|
||||
runner.session_store.get_or_create_session.return_value = session_entry
|
||||
runner.session_store.load_transcript.return_value = []
|
||||
runner.session_store.has_any_sessions.return_value = True
|
||||
runner.session_store.append_to_transcript = MagicMock()
|
||||
runner.session_store.rewrite_transcript = MagicMock()
|
||||
runner.session_store.update_session = MagicMock()
|
||||
runner._running_agents = {}
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._session_db = None
|
||||
runner._reasoning_config = None
|
||||
runner._provider_routing = {}
|
||||
runner._fallback_model = None
|
||||
runner._show_reasoning = False
|
||||
runner._is_user_authorized = lambda _source: True
|
||||
runner._set_session_env = lambda _context: None
|
||||
runner._should_send_voice_reply = lambda *_args, **_kwargs: False
|
||||
runner._send_voice_reply = AsyncMock()
|
||||
runner._capture_gateway_honcho_if_configured = lambda *args, **kwargs: None
|
||||
runner._emit_gateway_run_progress = AsyncMock()
|
||||
return runner
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_status_command_reports_running_agent_without_interrupt(monkeypatch):
|
||||
session_entry = SessionEntry(
|
||||
session_key=build_session_key(_make_source()),
|
||||
session_id="sess-1",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_type="dm",
|
||||
total_tokens=321,
|
||||
)
|
||||
runner = _make_runner(session_entry)
|
||||
running_agent = MagicMock()
|
||||
runner._running_agents[build_session_key(_make_source())] = running_agent
|
||||
|
||||
result = await runner._handle_message(_make_event("/status"))
|
||||
|
||||
assert "**Tokens:** 321" in result
|
||||
assert "**Agent Running:** Yes ⚡" in result
|
||||
running_agent.interrupt.assert_not_called()
|
||||
assert runner._pending_messages == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_persists_agent_token_counts(monkeypatch):
|
||||
import gateway.run as gateway_run
|
||||
|
||||
session_entry = SessionEntry(
|
||||
session_key=build_session_key(_make_source()),
|
||||
session_id="sess-1",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_type="dm",
|
||||
)
|
||||
runner = _make_runner(session_entry)
|
||||
runner.session_store.load_transcript.return_value = [{"role": "user", "content": "earlier"}]
|
||||
runner._run_agent = AsyncMock(
|
||||
return_value={
|
||||
"final_response": "ok",
|
||||
"messages": [],
|
||||
"tools": [],
|
||||
"history_offset": 0,
|
||||
"last_prompt_tokens": 80,
|
||||
"input_tokens": 120,
|
||||
"output_tokens": 45,
|
||||
"model": "openai/test-model",
|
||||
}
|
||||
)
|
||||
|
||||
monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})
|
||||
monkeypatch.setattr(
|
||||
"agent.model_metadata.get_model_context_length",
|
||||
lambda *_args, **_kwargs: 100000,
|
||||
)
|
||||
|
||||
result = await runner._handle_message(_make_event("hello"))
|
||||
|
||||
assert result == "ok"
|
||||
runner.session_store.update_session.assert_called_once_with(
|
||||
session_entry.session_key,
|
||||
input_tokens=120,
|
||||
output_tokens=45,
|
||||
cache_read_tokens=0,
|
||||
cache_write_tokens=0,
|
||||
last_prompt_tokens=80,
|
||||
model="openai/test-model",
|
||||
estimated_cost_usd=None,
|
||||
cost_status=None,
|
||||
cost_source=None,
|
||||
provider=None,
|
||||
base_url=None,
|
||||
)
|
||||
127
hermes_code/tests/gateway/test_sticker_cache.py
Normal file
127
hermes_code/tests/gateway/test_sticker_cache.py
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
"""Tests for gateway/sticker_cache.py — sticker description cache."""
|
||||
|
||||
import json
|
||||
import time
|
||||
from unittest.mock import patch
|
||||
|
||||
from gateway.sticker_cache import (
|
||||
_load_cache,
|
||||
_save_cache,
|
||||
get_cached_description,
|
||||
cache_sticker_description,
|
||||
build_sticker_injection,
|
||||
build_animated_sticker_injection,
|
||||
STICKER_VISION_PROMPT,
|
||||
)
|
||||
|
||||
|
||||
class TestLoadSaveCache:
|
||||
def test_load_missing_file(self, tmp_path):
|
||||
with patch("gateway.sticker_cache.CACHE_PATH", tmp_path / "nope.json"):
|
||||
assert _load_cache() == {}
|
||||
|
||||
def test_load_corrupt_file(self, tmp_path):
|
||||
bad_file = tmp_path / "bad.json"
|
||||
bad_file.write_text("not json{{{")
|
||||
with patch("gateway.sticker_cache.CACHE_PATH", bad_file):
|
||||
assert _load_cache() == {}
|
||||
|
||||
def test_save_and_load_roundtrip(self, tmp_path):
|
||||
cache_file = tmp_path / "cache.json"
|
||||
data = {"abc123": {"description": "A cat", "emoji": "", "set_name": "", "cached_at": 1.0}}
|
||||
with patch("gateway.sticker_cache.CACHE_PATH", cache_file):
|
||||
_save_cache(data)
|
||||
loaded = _load_cache()
|
||||
assert loaded == data
|
||||
|
||||
def test_save_creates_parent_dirs(self, tmp_path):
|
||||
cache_file = tmp_path / "sub" / "dir" / "cache.json"
|
||||
with patch("gateway.sticker_cache.CACHE_PATH", cache_file):
|
||||
_save_cache({"key": "value"})
|
||||
assert cache_file.exists()
|
||||
|
||||
|
||||
class TestCacheSticker:
|
||||
def test_cache_and_retrieve(self, tmp_path):
|
||||
cache_file = tmp_path / "cache.json"
|
||||
with patch("gateway.sticker_cache.CACHE_PATH", cache_file):
|
||||
cache_sticker_description("uid_1", "A happy dog", emoji="🐕", set_name="Dogs")
|
||||
result = get_cached_description("uid_1")
|
||||
|
||||
assert result is not None
|
||||
assert result["description"] == "A happy dog"
|
||||
assert result["emoji"] == "🐕"
|
||||
assert result["set_name"] == "Dogs"
|
||||
assert "cached_at" in result
|
||||
|
||||
def test_missing_sticker_returns_none(self, tmp_path):
|
||||
cache_file = tmp_path / "cache.json"
|
||||
with patch("gateway.sticker_cache.CACHE_PATH", cache_file):
|
||||
result = get_cached_description("nonexistent")
|
||||
assert result is None
|
||||
|
||||
def test_overwrite_existing(self, tmp_path):
|
||||
cache_file = tmp_path / "cache.json"
|
||||
with patch("gateway.sticker_cache.CACHE_PATH", cache_file):
|
||||
cache_sticker_description("uid_1", "Old description")
|
||||
cache_sticker_description("uid_1", "New description")
|
||||
result = get_cached_description("uid_1")
|
||||
|
||||
assert result["description"] == "New description"
|
||||
|
||||
def test_multiple_stickers(self, tmp_path):
|
||||
cache_file = tmp_path / "cache.json"
|
||||
with patch("gateway.sticker_cache.CACHE_PATH", cache_file):
|
||||
cache_sticker_description("uid_1", "Cat")
|
||||
cache_sticker_description("uid_2", "Dog")
|
||||
r1 = get_cached_description("uid_1")
|
||||
r2 = get_cached_description("uid_2")
|
||||
|
||||
assert r1["description"] == "Cat"
|
||||
assert r2["description"] == "Dog"
|
||||
|
||||
|
||||
class TestBuildStickerInjection:
|
||||
def test_exact_format_no_context(self):
|
||||
result = build_sticker_injection("A cat waving")
|
||||
assert result == '[The user sent a sticker~ It shows: "A cat waving" (=^.w.^=)]'
|
||||
|
||||
def test_exact_format_emoji_only(self):
|
||||
result = build_sticker_injection("A cat", emoji="😀")
|
||||
assert result == '[The user sent a sticker 😀~ It shows: "A cat" (=^.w.^=)]'
|
||||
|
||||
def test_exact_format_emoji_and_set_name(self):
|
||||
result = build_sticker_injection("A cat", emoji="😀", set_name="MyPack")
|
||||
assert result == '[The user sent a sticker 😀 from "MyPack"~ It shows: "A cat" (=^.w.^=)]'
|
||||
|
||||
def test_set_name_without_emoji_ignored(self):
|
||||
"""set_name alone (no emoji) produces no context — only emoji+set_name triggers 'from' clause."""
|
||||
result = build_sticker_injection("A cat", set_name="MyPack")
|
||||
assert result == '[The user sent a sticker~ It shows: "A cat" (=^.w.^=)]'
|
||||
assert "MyPack" not in result
|
||||
|
||||
def test_description_with_quotes(self):
|
||||
result = build_sticker_injection('A "happy" dog')
|
||||
assert '"A \\"happy\\" dog"' not in result # no escaping happens
|
||||
assert 'A "happy" dog' in result
|
||||
|
||||
def test_empty_description(self):
|
||||
result = build_sticker_injection("")
|
||||
assert result == '[The user sent a sticker~ It shows: "" (=^.w.^=)]'
|
||||
|
||||
|
||||
class TestBuildAnimatedStickerInjection:
|
||||
def test_exact_format_with_emoji(self):
|
||||
result = build_animated_sticker_injection(emoji="🎉")
|
||||
assert result == (
|
||||
"[The user sent an animated sticker 🎉~ "
|
||||
"I can't see animated ones yet, but the emoji suggests: 🎉]"
|
||||
)
|
||||
|
||||
def test_exact_format_without_emoji(self):
|
||||
result = build_animated_sticker_injection()
|
||||
assert result == "[The user sent an animated sticker~ I can't see animated ones yet]"
|
||||
|
||||
def test_empty_emoji_same_as_no_emoji(self):
|
||||
result = build_animated_sticker_injection(emoji="")
|
||||
assert result == build_animated_sticker_injection()
|
||||
77
hermes_code/tests/gateway/test_stt_config.py
Normal file
77
hermes_code/tests/gateway/test_stt_config.py
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
"""Gateway STT config tests — honor stt.enabled: false from config.yaml."""
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from gateway.config import GatewayConfig, load_gateway_config
|
||||
|
||||
|
||||
def test_gateway_config_stt_disabled_from_dict_nested():
|
||||
config = GatewayConfig.from_dict({"stt": {"enabled": False}})
|
||||
assert config.stt_enabled is False
|
||||
|
||||
|
||||
def test_load_gateway_config_bridges_stt_enabled_from_config_yaml(tmp_path, monkeypatch):
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
(hermes_home / "config.yaml").write_text(
|
||||
yaml.dump({"stt": {"enabled": False}}),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
|
||||
config = load_gateway_config()
|
||||
|
||||
assert config.stt_enabled is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enrich_message_with_transcription_skips_when_stt_disabled():
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = GatewayRunner.__new__(GatewayRunner)
|
||||
runner.config = GatewayConfig(stt_enabled=False)
|
||||
|
||||
with patch(
|
||||
"tools.transcription_tools.transcribe_audio",
|
||||
side_effect=AssertionError("transcribe_audio should not be called when STT is disabled"),
|
||||
), patch(
|
||||
"tools.transcription_tools.get_stt_model_from_config",
|
||||
return_value=None,
|
||||
):
|
||||
result = await runner._enrich_message_with_transcription(
|
||||
"caption",
|
||||
["/tmp/voice.ogg"],
|
||||
)
|
||||
|
||||
assert "transcription is disabled" in result.lower()
|
||||
assert "caption" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enrich_message_with_transcription_avoids_bogus_no_provider_message_for_backend_key_errors():
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = GatewayRunner.__new__(GatewayRunner)
|
||||
runner.config = GatewayConfig(stt_enabled=True)
|
||||
|
||||
with patch(
|
||||
"tools.transcription_tools.transcribe_audio",
|
||||
return_value={"success": False, "error": "VOICE_TOOLS_OPENAI_KEY not set"},
|
||||
), patch(
|
||||
"tools.transcription_tools.get_stt_model_from_config",
|
||||
return_value=None,
|
||||
):
|
||||
result = await runner._enrich_message_with_transcription(
|
||||
"caption",
|
||||
["/tmp/voice.ogg"],
|
||||
)
|
||||
|
||||
assert "No STT provider is configured" not in result
|
||||
assert "trouble transcribing" in result
|
||||
assert "caption" in result
|
||||
241
hermes_code/tests/gateway/test_telegram_conflict.py
Normal file
241
hermes_code/tests/gateway/test_telegram_conflict.py
Normal file
|
|
@ -0,0 +1,241 @@
|
|||
import asyncio
|
||||
import sys
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import PlatformConfig
|
||||
|
||||
|
||||
def _ensure_telegram_mock():
|
||||
if "telegram" in sys.modules and hasattr(sys.modules["telegram"], "__file__"):
|
||||
return
|
||||
|
||||
telegram_mod = MagicMock()
|
||||
telegram_mod.ext.ContextTypes.DEFAULT_TYPE = type(None)
|
||||
telegram_mod.constants.ParseMode.MARKDOWN_V2 = "MarkdownV2"
|
||||
telegram_mod.constants.ChatType.GROUP = "group"
|
||||
telegram_mod.constants.ChatType.SUPERGROUP = "supergroup"
|
||||
telegram_mod.constants.ChatType.CHANNEL = "channel"
|
||||
telegram_mod.constants.ChatType.PRIVATE = "private"
|
||||
|
||||
for name in ("telegram", "telegram.ext", "telegram.constants"):
|
||||
sys.modules.setdefault(name, telegram_mod)
|
||||
|
||||
|
||||
_ensure_telegram_mock()
|
||||
|
||||
from gateway.platforms.telegram import TelegramAdapter # noqa: E402
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_rejects_same_host_token_lock(monkeypatch):
|
||||
adapter = TelegramAdapter(PlatformConfig(enabled=True, token="secret-token"))
|
||||
|
||||
monkeypatch.setattr(
|
||||
"gateway.status.acquire_scoped_lock",
|
||||
lambda scope, identity, metadata=None: (False, {"pid": 4242}),
|
||||
)
|
||||
|
||||
ok = await adapter.connect()
|
||||
|
||||
assert ok is False
|
||||
assert adapter.fatal_error_code == "telegram_token_lock"
|
||||
assert adapter.has_fatal_error is True
|
||||
assert "already using this Telegram bot token" in adapter.fatal_error_message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_polling_conflict_retries_before_fatal(monkeypatch):
|
||||
"""A single 409 should trigger a retry, not an immediate fatal error."""
|
||||
adapter = TelegramAdapter(PlatformConfig(enabled=True, token="***"))
|
||||
fatal_handler = AsyncMock()
|
||||
adapter.set_fatal_error_handler(fatal_handler)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"gateway.status.acquire_scoped_lock",
|
||||
lambda scope, identity, metadata=None: (True, None),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"gateway.status.release_scoped_lock",
|
||||
lambda scope, identity: None,
|
||||
)
|
||||
|
||||
captured = {}
|
||||
|
||||
async def fake_start_polling(**kwargs):
|
||||
captured["error_callback"] = kwargs["error_callback"]
|
||||
|
||||
updater = SimpleNamespace(
|
||||
start_polling=AsyncMock(side_effect=fake_start_polling),
|
||||
stop=AsyncMock(),
|
||||
running=True,
|
||||
)
|
||||
bot = SimpleNamespace(set_my_commands=AsyncMock())
|
||||
app = SimpleNamespace(
|
||||
bot=bot,
|
||||
updater=updater,
|
||||
add_handler=MagicMock(),
|
||||
initialize=AsyncMock(),
|
||||
start=AsyncMock(),
|
||||
)
|
||||
builder = MagicMock()
|
||||
builder.token.return_value = builder
|
||||
builder.build.return_value = app
|
||||
monkeypatch.setattr("gateway.platforms.telegram.Application", SimpleNamespace(builder=MagicMock(return_value=builder)))
|
||||
|
||||
# Speed up retries for testing
|
||||
monkeypatch.setattr("asyncio.sleep", AsyncMock())
|
||||
|
||||
ok = await adapter.connect()
|
||||
|
||||
assert ok is True
|
||||
assert callable(captured["error_callback"])
|
||||
|
||||
conflict = type("Conflict", (Exception,), {})
|
||||
|
||||
# First conflict: should retry, NOT be fatal
|
||||
captured["error_callback"](conflict("Conflict: terminated by other getUpdates request"))
|
||||
await asyncio.sleep(0)
|
||||
await asyncio.sleep(0)
|
||||
# Give the scheduled task a chance to run
|
||||
for _ in range(10):
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert adapter.has_fatal_error is False, "First conflict should not be fatal"
|
||||
assert adapter._polling_conflict_count == 0, "Count should reset after successful retry"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_polling_conflict_becomes_fatal_after_retries(monkeypatch):
|
||||
"""After exhausting retries, the conflict should become fatal."""
|
||||
adapter = TelegramAdapter(PlatformConfig(enabled=True, token="***"))
|
||||
fatal_handler = AsyncMock()
|
||||
adapter.set_fatal_error_handler(fatal_handler)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"gateway.status.acquire_scoped_lock",
|
||||
lambda scope, identity, metadata=None: (True, None),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"gateway.status.release_scoped_lock",
|
||||
lambda scope, identity: None,
|
||||
)
|
||||
|
||||
captured = {}
|
||||
|
||||
async def fake_start_polling(**kwargs):
|
||||
captured["error_callback"] = kwargs["error_callback"]
|
||||
|
||||
# Make start_polling fail on retries to exhaust retries
|
||||
call_count = {"n": 0}
|
||||
|
||||
async def failing_start_polling(**kwargs):
|
||||
call_count["n"] += 1
|
||||
if call_count["n"] == 1:
|
||||
# First call (initial connect) succeeds
|
||||
captured["error_callback"] = kwargs["error_callback"]
|
||||
else:
|
||||
# Retry calls fail
|
||||
raise Exception("Connection refused")
|
||||
|
||||
updater = SimpleNamespace(
|
||||
start_polling=AsyncMock(side_effect=failing_start_polling),
|
||||
stop=AsyncMock(),
|
||||
running=True,
|
||||
)
|
||||
bot = SimpleNamespace(set_my_commands=AsyncMock())
|
||||
app = SimpleNamespace(
|
||||
bot=bot,
|
||||
updater=updater,
|
||||
add_handler=MagicMock(),
|
||||
initialize=AsyncMock(),
|
||||
start=AsyncMock(),
|
||||
)
|
||||
builder = MagicMock()
|
||||
builder.token.return_value = builder
|
||||
builder.build.return_value = app
|
||||
monkeypatch.setattr("gateway.platforms.telegram.Application", SimpleNamespace(builder=MagicMock(return_value=builder)))
|
||||
|
||||
# Speed up retries for testing
|
||||
monkeypatch.setattr("asyncio.sleep", AsyncMock())
|
||||
|
||||
ok = await adapter.connect()
|
||||
assert ok is True
|
||||
|
||||
conflict = type("Conflict", (Exception,), {})
|
||||
|
||||
# Directly call _handle_polling_conflict to avoid event-loop scheduling
|
||||
# complexity. Each call simulates one 409 from Telegram.
|
||||
for i in range(4):
|
||||
await adapter._handle_polling_conflict(
|
||||
conflict("Conflict: terminated by other getUpdates request")
|
||||
)
|
||||
|
||||
# After 3 failed retries (count 1-3 each enter the retry branch but
|
||||
# start_polling raises), the 4th conflict pushes count to 4 which
|
||||
# exceeds MAX_CONFLICT_RETRIES (3), entering the fatal branch.
|
||||
assert adapter.fatal_error_code == "telegram_polling_conflict", (
|
||||
f"Expected fatal after 4 conflicts, got code={adapter.fatal_error_code}, "
|
||||
f"count={adapter._polling_conflict_count}"
|
||||
)
|
||||
assert adapter.has_fatal_error is True
|
||||
fatal_handler.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_marks_retryable_fatal_error_for_startup_network_failure(monkeypatch):
|
||||
adapter = TelegramAdapter(PlatformConfig(enabled=True, token="***"))
|
||||
|
||||
monkeypatch.setattr(
|
||||
"gateway.status.acquire_scoped_lock",
|
||||
lambda scope, identity, metadata=None: (True, None),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"gateway.status.release_scoped_lock",
|
||||
lambda scope, identity: None,
|
||||
)
|
||||
|
||||
builder = MagicMock()
|
||||
builder.token.return_value = builder
|
||||
app = SimpleNamespace(
|
||||
bot=SimpleNamespace(),
|
||||
updater=SimpleNamespace(),
|
||||
add_handler=MagicMock(),
|
||||
initialize=AsyncMock(side_effect=RuntimeError("Temporary failure in name resolution")),
|
||||
start=AsyncMock(),
|
||||
)
|
||||
builder.build.return_value = app
|
||||
monkeypatch.setattr("gateway.platforms.telegram.Application", SimpleNamespace(builder=MagicMock(return_value=builder)))
|
||||
|
||||
ok = await adapter.connect()
|
||||
|
||||
assert ok is False
|
||||
assert adapter.fatal_error_code == "telegram_connect_error"
|
||||
assert adapter.fatal_error_retryable is True
|
||||
assert "Temporary failure in name resolution" in adapter.fatal_error_message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_skips_inactive_updater_and_app(monkeypatch):
|
||||
adapter = TelegramAdapter(PlatformConfig(enabled=True, token="***"))
|
||||
|
||||
updater = SimpleNamespace(running=False, stop=AsyncMock())
|
||||
app = SimpleNamespace(
|
||||
updater=updater,
|
||||
running=False,
|
||||
stop=AsyncMock(),
|
||||
shutdown=AsyncMock(),
|
||||
)
|
||||
adapter._app = app
|
||||
|
||||
warning = MagicMock()
|
||||
monkeypatch.setattr("gateway.platforms.telegram.logger.warning", warning)
|
||||
|
||||
await adapter.disconnect()
|
||||
|
||||
updater.stop.assert_not_awaited()
|
||||
app.stop.assert_not_awaited()
|
||||
app.shutdown.assert_awaited_once()
|
||||
warning.assert_not_called()
|
||||
694
hermes_code/tests/gateway/test_telegram_documents.py
Normal file
694
hermes_code/tests/gateway/test_telegram_documents.py
Normal file
|
|
@ -0,0 +1,694 @@
|
|||
"""
|
||||
Tests for Telegram document handling in gateway/platforms/telegram.py.
|
||||
|
||||
Covers: document type detection, download/cache flow, size limits,
|
||||
text injection, error handling.
|
||||
|
||||
Note: python-telegram-bot may not be installed in the test environment.
|
||||
We mock the telegram module at import time to avoid collection errors.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import (
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
SendResult,
|
||||
SUPPORTED_DOCUMENT_TYPES,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mock the telegram package if it's not installed
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _ensure_telegram_mock():
|
||||
"""Install mock telegram modules so TelegramAdapter can be imported."""
|
||||
if "telegram" in sys.modules and hasattr(sys.modules["telegram"], "__file__"):
|
||||
# Real library is installed — no mocking needed
|
||||
return
|
||||
|
||||
telegram_mod = MagicMock()
|
||||
# ContextTypes needs DEFAULT_TYPE as an actual attribute for the annotation
|
||||
telegram_mod.ext.ContextTypes.DEFAULT_TYPE = type(None)
|
||||
telegram_mod.constants.ParseMode.MARKDOWN_V2 = "MarkdownV2"
|
||||
telegram_mod.constants.ChatType.GROUP = "group"
|
||||
telegram_mod.constants.ChatType.SUPERGROUP = "supergroup"
|
||||
telegram_mod.constants.ChatType.CHANNEL = "channel"
|
||||
telegram_mod.constants.ChatType.PRIVATE = "private"
|
||||
|
||||
for name in ("telegram", "telegram.ext", "telegram.constants"):
|
||||
sys.modules.setdefault(name, telegram_mod)
|
||||
|
||||
|
||||
_ensure_telegram_mock()
|
||||
|
||||
# Now we can safely import
|
||||
from gateway.platforms.telegram import TelegramAdapter # noqa: E402
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers to build mock Telegram objects
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_file_obj(data: bytes = b"hello"):
|
||||
"""Create a mock Telegram File with download_as_bytearray."""
|
||||
f = AsyncMock()
|
||||
f.download_as_bytearray = AsyncMock(return_value=bytearray(data))
|
||||
f.file_path = "documents/file.pdf"
|
||||
return f
|
||||
|
||||
|
||||
def _make_document(
|
||||
file_name="report.pdf",
|
||||
mime_type="application/pdf",
|
||||
file_size=1024,
|
||||
file_obj=None,
|
||||
):
|
||||
"""Create a mock Telegram Document object."""
|
||||
doc = MagicMock()
|
||||
doc.file_name = file_name
|
||||
doc.mime_type = mime_type
|
||||
doc.file_size = file_size
|
||||
doc.get_file = AsyncMock(return_value=file_obj or _make_file_obj())
|
||||
return doc
|
||||
|
||||
|
||||
def _make_message(document=None, caption=None, media_group_id=None, photo=None):
|
||||
"""Build a mock Telegram Message with the given document/photo."""
|
||||
msg = MagicMock()
|
||||
msg.message_id = 42
|
||||
msg.text = caption or ""
|
||||
msg.caption = caption
|
||||
msg.date = None
|
||||
# Media flags — all None except explicit payload
|
||||
msg.photo = photo
|
||||
msg.video = None
|
||||
msg.audio = None
|
||||
msg.voice = None
|
||||
msg.sticker = None
|
||||
msg.document = document
|
||||
msg.media_group_id = media_group_id
|
||||
# Chat / user
|
||||
msg.chat = MagicMock()
|
||||
msg.chat.id = 100
|
||||
msg.chat.type = "private"
|
||||
msg.chat.title = None
|
||||
msg.chat.full_name = "Test User"
|
||||
msg.from_user = MagicMock()
|
||||
msg.from_user.id = 1
|
||||
msg.from_user.full_name = "Test User"
|
||||
msg.message_thread_id = None
|
||||
return msg
|
||||
|
||||
|
||||
def _make_update(msg):
|
||||
"""Wrap a message in a mock Update."""
|
||||
update = MagicMock()
|
||||
update.message = msg
|
||||
return update
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture()
|
||||
def adapter():
|
||||
config = PlatformConfig(enabled=True, token="fake-token")
|
||||
a = TelegramAdapter(config)
|
||||
# Capture events instead of processing them
|
||||
a.handle_message = AsyncMock()
|
||||
return a
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _redirect_cache(tmp_path, monkeypatch):
|
||||
"""Point document cache to tmp_path so tests don't touch ~/.hermes."""
|
||||
monkeypatch.setattr(
|
||||
"gateway.platforms.base.DOCUMENT_CACHE_DIR", tmp_path / "doc_cache"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestDocumentTypeDetection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDocumentTypeDetection:
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_detected_explicitly(self, adapter):
|
||||
doc = _make_document()
|
||||
msg = _make_message(document=doc)
|
||||
update = _make_update(msg)
|
||||
await adapter._handle_media_message(update, MagicMock())
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert event.message_type == MessageType.DOCUMENT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_is_document(self, adapter):
|
||||
"""When no specific media attr is set, message_type defaults to DOCUMENT."""
|
||||
msg = _make_message()
|
||||
msg.document = None # no media at all
|
||||
update = _make_update(msg)
|
||||
await adapter._handle_media_message(update, MagicMock())
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert event.message_type == MessageType.DOCUMENT
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestDocumentDownloadBlock
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_photo(file_obj=None):
|
||||
photo = MagicMock()
|
||||
photo.get_file = AsyncMock(return_value=file_obj or _make_file_obj(b"photo-bytes"))
|
||||
return photo
|
||||
|
||||
|
||||
class TestDocumentDownloadBlock:
|
||||
@pytest.mark.asyncio
|
||||
async def test_supported_pdf_is_cached(self, adapter):
|
||||
pdf_bytes = b"%PDF-1.4 fake"
|
||||
file_obj = _make_file_obj(pdf_bytes)
|
||||
doc = _make_document(file_name="report.pdf", file_size=1024, file_obj=file_obj)
|
||||
msg = _make_message(document=doc)
|
||||
update = _make_update(msg)
|
||||
|
||||
await adapter._handle_media_message(update, MagicMock())
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert len(event.media_urls) == 1
|
||||
assert os.path.exists(event.media_urls[0])
|
||||
assert event.media_types == ["application/pdf"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_supported_txt_injects_content(self, adapter):
|
||||
content = b"Hello from a text file"
|
||||
file_obj = _make_file_obj(content)
|
||||
doc = _make_document(
|
||||
file_name="notes.txt", mime_type="text/plain",
|
||||
file_size=len(content), file_obj=file_obj,
|
||||
)
|
||||
msg = _make_message(document=doc)
|
||||
update = _make_update(msg)
|
||||
|
||||
await adapter._handle_media_message(update, MagicMock())
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert "Hello from a text file" in event.text
|
||||
assert "[Content of notes.txt]" in event.text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_supported_md_injects_content(self, adapter):
|
||||
content = b"# Title\nSome markdown"
|
||||
file_obj = _make_file_obj(content)
|
||||
doc = _make_document(
|
||||
file_name="readme.md", mime_type="text/markdown",
|
||||
file_size=len(content), file_obj=file_obj,
|
||||
)
|
||||
msg = _make_message(document=doc)
|
||||
update = _make_update(msg)
|
||||
|
||||
await adapter._handle_media_message(update, MagicMock())
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert "# Title" in event.text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_caption_preserved_with_injection(self, adapter):
|
||||
content = b"file text"
|
||||
file_obj = _make_file_obj(content)
|
||||
doc = _make_document(
|
||||
file_name="doc.txt", mime_type="text/plain",
|
||||
file_size=len(content), file_obj=file_obj,
|
||||
)
|
||||
msg = _make_message(document=doc, caption="Please summarize")
|
||||
update = _make_update(msg)
|
||||
|
||||
await adapter._handle_media_message(update, MagicMock())
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert "file text" in event.text
|
||||
assert "Please summarize" in event.text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unsupported_type_rejected(self, adapter):
|
||||
doc = _make_document(file_name="archive.zip", mime_type="application/zip", file_size=100)
|
||||
msg = _make_message(document=doc)
|
||||
update = _make_update(msg)
|
||||
|
||||
await adapter._handle_media_message(update, MagicMock())
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert "Unsupported document type" in event.text
|
||||
assert ".zip" in event.text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oversized_file_rejected(self, adapter):
|
||||
doc = _make_document(file_name="huge.pdf", file_size=25 * 1024 * 1024)
|
||||
msg = _make_message(document=doc)
|
||||
update = _make_update(msg)
|
||||
|
||||
await adapter._handle_media_message(update, MagicMock())
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert "too large" in event.text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_none_file_size_rejected(self, adapter):
|
||||
"""Security fix: file_size=None must be rejected (not silently allowed)."""
|
||||
doc = _make_document(file_name="tricky.pdf", file_size=None)
|
||||
msg = _make_message(document=doc)
|
||||
update = _make_update(msg)
|
||||
|
||||
await adapter._handle_media_message(update, MagicMock())
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert "too large" in event.text or "could not be verified" in event.text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_filename_uses_mime_lookup(self, adapter):
|
||||
"""No file_name but valid mime_type should resolve to extension."""
|
||||
content = b"some pdf bytes"
|
||||
file_obj = _make_file_obj(content)
|
||||
doc = _make_document(
|
||||
file_name=None, mime_type="application/pdf",
|
||||
file_size=len(content), file_obj=file_obj,
|
||||
)
|
||||
msg = _make_message(document=doc)
|
||||
update = _make_update(msg)
|
||||
|
||||
await adapter._handle_media_message(update, MagicMock())
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert len(event.media_urls) == 1
|
||||
assert event.media_types == ["application/pdf"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_filename_and_mime_rejected(self, adapter):
|
||||
doc = _make_document(file_name=None, mime_type=None, file_size=100)
|
||||
msg = _make_message(document=doc)
|
||||
update = _make_update(msg)
|
||||
|
||||
await adapter._handle_media_message(update, MagicMock())
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert "Unsupported" in event.text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unicode_decode_error_handled(self, adapter):
|
||||
"""Binary bytes that aren't valid UTF-8 in a .txt — content not injected but file still cached."""
|
||||
binary = bytes(range(128, 256)) # not valid UTF-8
|
||||
file_obj = _make_file_obj(binary)
|
||||
doc = _make_document(
|
||||
file_name="binary.txt", mime_type="text/plain",
|
||||
file_size=len(binary), file_obj=file_obj,
|
||||
)
|
||||
msg = _make_message(document=doc)
|
||||
update = _make_update(msg)
|
||||
|
||||
await adapter._handle_media_message(update, MagicMock())
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
# File should still be cached
|
||||
assert len(event.media_urls) == 1
|
||||
assert os.path.exists(event.media_urls[0])
|
||||
# Content NOT injected — text should be empty (no caption set)
|
||||
assert "[Content of" not in (event.text or "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_injection_capped(self, adapter):
|
||||
"""A .txt file over 100 KB should NOT have its content injected."""
|
||||
large = b"x" * (200 * 1024) # 200 KB
|
||||
file_obj = _make_file_obj(large)
|
||||
doc = _make_document(
|
||||
file_name="big.txt", mime_type="text/plain",
|
||||
file_size=len(large), file_obj=file_obj,
|
||||
)
|
||||
msg = _make_message(document=doc)
|
||||
update = _make_update(msg)
|
||||
|
||||
await adapter._handle_media_message(update, MagicMock())
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
# File should be cached
|
||||
assert len(event.media_urls) == 1
|
||||
# Content should NOT be injected
|
||||
assert "[Content of" not in (event.text or "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_exception_handled(self, adapter):
|
||||
"""If get_file() raises, the handler logs the error without crashing."""
|
||||
doc = _make_document(file_name="crash.pdf", file_size=100)
|
||||
doc.get_file = AsyncMock(side_effect=RuntimeError("Telegram API down"))
|
||||
msg = _make_message(document=doc)
|
||||
update = _make_update(msg)
|
||||
|
||||
# Should not raise
|
||||
await adapter._handle_media_message(update, MagicMock())
|
||||
# handle_message should still be called (the handler catches the exception)
|
||||
adapter.handle_message.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestMediaGroups — media group (album) buffering
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMediaGroups:
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_album_photo_burst_is_buffered_and_combined(self, adapter):
|
||||
first_photo = _make_photo(_make_file_obj(b"first"))
|
||||
second_photo = _make_photo(_make_file_obj(b"second"))
|
||||
|
||||
msg1 = _make_message(caption="two images", photo=[first_photo])
|
||||
msg2 = _make_message(photo=[second_photo])
|
||||
|
||||
with patch("gateway.platforms.telegram.cache_image_from_bytes", side_effect=["/tmp/burst-one.jpg", "/tmp/burst-two.jpg"]):
|
||||
await adapter._handle_media_message(_make_update(msg1), MagicMock())
|
||||
await adapter._handle_media_message(_make_update(msg2), MagicMock())
|
||||
assert adapter.handle_message.await_count == 0
|
||||
await asyncio.sleep(adapter.MEDIA_GROUP_WAIT_SECONDS + 0.05)
|
||||
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
event = adapter.handle_message.await_args.args[0]
|
||||
assert event.text == "two images"
|
||||
assert event.media_urls == ["/tmp/burst-one.jpg", "/tmp/burst-two.jpg"]
|
||||
assert len(event.media_types) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_photo_album_is_buffered_and_combined(self, adapter):
|
||||
first_photo = _make_photo(_make_file_obj(b"first"))
|
||||
second_photo = _make_photo(_make_file_obj(b"second"))
|
||||
|
||||
msg1 = _make_message(caption="two images", media_group_id="album-1", photo=[first_photo])
|
||||
msg2 = _make_message(media_group_id="album-1", photo=[second_photo])
|
||||
|
||||
with patch("gateway.platforms.telegram.cache_image_from_bytes", side_effect=["/tmp/one.jpg", "/tmp/two.jpg"]):
|
||||
await adapter._handle_media_message(_make_update(msg1), MagicMock())
|
||||
await adapter._handle_media_message(_make_update(msg2), MagicMock())
|
||||
assert adapter.handle_message.await_count == 0
|
||||
await asyncio.sleep(adapter.MEDIA_GROUP_WAIT_SECONDS + 0.05)
|
||||
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert event.text == "two images"
|
||||
assert event.media_urls == ["/tmp/one.jpg", "/tmp/two.jpg"]
|
||||
assert len(event.media_types) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_cancels_pending_media_group_flush(self, adapter):
|
||||
first_photo = _make_photo(_make_file_obj(b"first"))
|
||||
msg = _make_message(caption="two images", media_group_id="album-2", photo=[first_photo])
|
||||
|
||||
with patch("gateway.platforms.telegram.cache_image_from_bytes", return_value="/tmp/one.jpg"):
|
||||
await adapter._handle_media_message(_make_update(msg), MagicMock())
|
||||
|
||||
assert "album-2" in adapter._media_group_events
|
||||
assert "album-2" in adapter._media_group_tasks
|
||||
|
||||
await adapter.disconnect()
|
||||
await asyncio.sleep(adapter.MEDIA_GROUP_WAIT_SECONDS + 0.05)
|
||||
|
||||
assert adapter._media_group_events == {}
|
||||
assert adapter._media_group_tasks == {}
|
||||
adapter.handle_message.assert_not_awaited()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestSendDocument — outbound file attachment delivery
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSendDocument:
|
||||
"""Tests for TelegramAdapter.send_document() — sending files to users."""
|
||||
|
||||
@pytest.fixture()
|
||||
def connected_adapter(self, adapter):
|
||||
"""Adapter with a mock bot attached."""
|
||||
bot = AsyncMock()
|
||||
adapter._bot = bot
|
||||
return adapter
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_document_success(self, connected_adapter, tmp_path):
|
||||
"""A local file is sent via bot.send_document and returns success."""
|
||||
# Create a real temp file
|
||||
test_file = tmp_path / "report.pdf"
|
||||
test_file.write_bytes(b"%PDF-1.4 fake content")
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.message_id = 99
|
||||
connected_adapter._bot.send_document = AsyncMock(return_value=mock_msg)
|
||||
|
||||
result = await connected_adapter.send_document(
|
||||
chat_id="12345",
|
||||
file_path=str(test_file),
|
||||
caption="Here's the report",
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.message_id == "99"
|
||||
connected_adapter._bot.send_document.assert_called_once()
|
||||
call_kwargs = connected_adapter._bot.send_document.call_args[1]
|
||||
assert call_kwargs["chat_id"] == 12345
|
||||
assert call_kwargs["filename"] == "report.pdf"
|
||||
assert call_kwargs["caption"] == "Here's the report"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_document_custom_filename(self, connected_adapter, tmp_path):
|
||||
"""The file_name parameter overrides the basename for display."""
|
||||
test_file = tmp_path / "doc_abc123_ugly.csv"
|
||||
test_file.write_bytes(b"a,b,c\n1,2,3")
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.message_id = 100
|
||||
connected_adapter._bot.send_document = AsyncMock(return_value=mock_msg)
|
||||
|
||||
result = await connected_adapter.send_document(
|
||||
chat_id="12345",
|
||||
file_path=str(test_file),
|
||||
file_name="clean_data.csv",
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
call_kwargs = connected_adapter._bot.send_document.call_args[1]
|
||||
assert call_kwargs["filename"] == "clean_data.csv"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_document_file_not_found(self, connected_adapter):
|
||||
"""Missing file returns error without calling Telegram API."""
|
||||
result = await connected_adapter.send_document(
|
||||
chat_id="12345",
|
||||
file_path="/nonexistent/file.pdf",
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert "not found" in result.error.lower()
|
||||
connected_adapter._bot.send_document.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_document_not_connected(self, adapter):
|
||||
"""If bot is None, returns not connected error."""
|
||||
result = await adapter.send_document(
|
||||
chat_id="12345",
|
||||
file_path="/some/file.pdf",
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert "Not connected" in result.error
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_document_caption_truncated(self, connected_adapter, tmp_path):
|
||||
"""Captions longer than 1024 chars are truncated."""
|
||||
test_file = tmp_path / "data.json"
|
||||
test_file.write_bytes(b"{}")
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.message_id = 101
|
||||
connected_adapter._bot.send_document = AsyncMock(return_value=mock_msg)
|
||||
|
||||
long_caption = "x" * 2000
|
||||
await connected_adapter.send_document(
|
||||
chat_id="12345",
|
||||
file_path=str(test_file),
|
||||
caption=long_caption,
|
||||
)
|
||||
|
||||
call_kwargs = connected_adapter._bot.send_document.call_args[1]
|
||||
assert len(call_kwargs["caption"]) == 1024
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_document_api_error_falls_back(self, connected_adapter, tmp_path):
|
||||
"""If Telegram API raises, falls back to base class text message."""
|
||||
test_file = tmp_path / "file.pdf"
|
||||
test_file.write_bytes(b"data")
|
||||
|
||||
connected_adapter._bot.send_document = AsyncMock(
|
||||
side_effect=RuntimeError("Telegram API error")
|
||||
)
|
||||
|
||||
# The base fallback calls self.send() which is also on _bot, so mock it
|
||||
# to avoid cascading errors.
|
||||
connected_adapter.send = AsyncMock(
|
||||
return_value=SendResult(success=True, message_id="fallback")
|
||||
)
|
||||
|
||||
result = await connected_adapter.send_document(
|
||||
chat_id="12345",
|
||||
file_path=str(test_file),
|
||||
)
|
||||
|
||||
# Should have fallen back to base class
|
||||
assert result.success is True
|
||||
assert result.message_id == "fallback"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_document_reply_to(self, connected_adapter, tmp_path):
|
||||
"""reply_to parameter is forwarded as reply_to_message_id."""
|
||||
test_file = tmp_path / "spec.md"
|
||||
test_file.write_bytes(b"# Spec")
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.message_id = 102
|
||||
connected_adapter._bot.send_document = AsyncMock(return_value=mock_msg)
|
||||
|
||||
await connected_adapter.send_document(
|
||||
chat_id="12345",
|
||||
file_path=str(test_file),
|
||||
reply_to="50",
|
||||
)
|
||||
|
||||
call_kwargs = connected_adapter._bot.send_document.call_args[1]
|
||||
assert call_kwargs["reply_to_message_id"] == 50
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_document_thread_id(self, connected_adapter, tmp_path):
|
||||
"""metadata thread_id is forwarded as message_thread_id (required for Telegram forum groups)."""
|
||||
test_file = tmp_path / "report.pdf"
|
||||
test_file.write_bytes(b"%PDF-1.4 data")
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.message_id = 103
|
||||
connected_adapter._bot.send_document = AsyncMock(return_value=mock_msg)
|
||||
|
||||
await connected_adapter.send_document(
|
||||
chat_id="12345",
|
||||
file_path=str(test_file),
|
||||
metadata={"thread_id": "789"},
|
||||
)
|
||||
|
||||
call_kwargs = connected_adapter._bot.send_document.call_args[1]
|
||||
assert call_kwargs["message_thread_id"] == 789
|
||||
|
||||
|
||||
class TestTelegramPhotoBatching:
|
||||
@pytest.mark.asyncio
|
||||
async def test_flush_photo_batch_does_not_drop_newer_scheduled_task(self, adapter):
|
||||
old_task = MagicMock()
|
||||
new_task = MagicMock()
|
||||
batch_key = "session:photo-burst"
|
||||
adapter._pending_photo_batch_tasks[batch_key] = new_task
|
||||
adapter._pending_photo_batches[batch_key] = MessageEvent(
|
||||
text="",
|
||||
message_type=MessageType.PHOTO,
|
||||
source=SimpleNamespace(channel_id="chat-1"),
|
||||
media_urls=["/tmp/a.jpg"],
|
||||
media_types=["image/jpeg"],
|
||||
)
|
||||
|
||||
with (
|
||||
patch("gateway.platforms.telegram.asyncio.current_task", return_value=old_task),
|
||||
patch("gateway.platforms.telegram.asyncio.sleep", new=AsyncMock()),
|
||||
):
|
||||
await adapter._flush_photo_batch(batch_key)
|
||||
|
||||
assert adapter._pending_photo_batch_tasks[batch_key] is new_task
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_cancels_pending_photo_batch_tasks(self, adapter):
|
||||
task = MagicMock()
|
||||
task.done.return_value = False
|
||||
adapter._pending_photo_batch_tasks["session:photo-burst"] = task
|
||||
adapter._pending_photo_batches["session:photo-burst"] = MessageEvent(
|
||||
text="",
|
||||
message_type=MessageType.PHOTO,
|
||||
source=SimpleNamespace(channel_id="chat-1"),
|
||||
)
|
||||
adapter._app = MagicMock()
|
||||
adapter._app.updater.stop = AsyncMock()
|
||||
adapter._app.stop = AsyncMock()
|
||||
adapter._app.shutdown = AsyncMock()
|
||||
|
||||
await adapter.disconnect()
|
||||
|
||||
task.cancel.assert_called_once()
|
||||
assert adapter._pending_photo_batch_tasks == {}
|
||||
assert adapter._pending_photo_batches == {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestSendVideo — outbound video delivery
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSendVideo:
|
||||
"""Tests for TelegramAdapter.send_video() — sending videos to users."""
|
||||
|
||||
@pytest.fixture()
|
||||
def connected_adapter(self, adapter):
|
||||
bot = AsyncMock()
|
||||
adapter._bot = bot
|
||||
return adapter
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_video_success(self, connected_adapter, tmp_path):
|
||||
test_file = tmp_path / "clip.mp4"
|
||||
test_file.write_bytes(b"\x00\x00\x00\x1c" + b"ftyp" + b"\x00" * 100)
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.message_id = 200
|
||||
connected_adapter._bot.send_video = AsyncMock(return_value=mock_msg)
|
||||
|
||||
result = await connected_adapter.send_video(
|
||||
chat_id="12345",
|
||||
video_path=str(test_file),
|
||||
caption="Check this out",
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.message_id == "200"
|
||||
connected_adapter._bot.send_video.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_video_file_not_found(self, connected_adapter):
|
||||
result = await connected_adapter.send_video(
|
||||
chat_id="12345",
|
||||
video_path="/nonexistent/video.mp4",
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert "not found" in result.error.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_video_not_connected(self, adapter):
|
||||
result = await adapter.send_video(
|
||||
chat_id="12345",
|
||||
video_path="/some/video.mp4",
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert "Not connected" in result.error
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_video_thread_id(self, connected_adapter, tmp_path):
|
||||
"""metadata thread_id is forwarded as message_thread_id (required for Telegram forum groups)."""
|
||||
test_file = tmp_path / "clip.mp4"
|
||||
test_file.write_bytes(b"\x00\x00\x00\x1c" + b"ftyp" + b"\x00" * 100)
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.message_id = 201
|
||||
connected_adapter._bot.send_video = AsyncMock(return_value=mock_msg)
|
||||
|
||||
await connected_adapter.send_video(
|
||||
chat_id="12345",
|
||||
video_path=str(test_file),
|
||||
metadata={"thread_id": "789"},
|
||||
)
|
||||
|
||||
call_kwargs = connected_adapter._bot.send_video.call_args[1]
|
||||
assert call_kwargs["message_thread_id"] == 789
|
||||
538
hermes_code/tests/gateway/test_telegram_format.py
Normal file
538
hermes_code/tests/gateway/test_telegram_format.py
Normal file
|
|
@ -0,0 +1,538 @@
|
|||
"""Tests for Telegram MarkdownV2 formatting in gateway/platforms/telegram.py.
|
||||
|
||||
Covers: _escape_mdv2 (pure function), format_message (markdown-to-MarkdownV2
|
||||
conversion pipeline), and edge cases that could produce invalid MarkdownV2
|
||||
or corrupt user-visible content.
|
||||
"""
|
||||
|
||||
import re
|
||||
import sys
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import PlatformConfig
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mock the telegram package if it's not installed
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _ensure_telegram_mock():
|
||||
if "telegram" in sys.modules and hasattr(sys.modules["telegram"], "__file__"):
|
||||
return
|
||||
mod = MagicMock()
|
||||
mod.ext.ContextTypes.DEFAULT_TYPE = type(None)
|
||||
mod.constants.ParseMode.MARKDOWN_V2 = "MarkdownV2"
|
||||
mod.constants.ChatType.GROUP = "group"
|
||||
mod.constants.ChatType.SUPERGROUP = "supergroup"
|
||||
mod.constants.ChatType.CHANNEL = "channel"
|
||||
mod.constants.ChatType.PRIVATE = "private"
|
||||
for name in ("telegram", "telegram.ext", "telegram.constants"):
|
||||
sys.modules.setdefault(name, mod)
|
||||
|
||||
|
||||
_ensure_telegram_mock()
|
||||
|
||||
from gateway.platforms.telegram import TelegramAdapter, _escape_mdv2, _strip_mdv2 # noqa: E402
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture()
|
||||
def adapter():
|
||||
config = PlatformConfig(enabled=True, token="fake-token")
|
||||
return TelegramAdapter(config)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# _escape_mdv2
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestEscapeMdv2:
|
||||
def test_escapes_all_special_characters(self):
|
||||
special = r'_*[]()~`>#+-=|{}.!\ '
|
||||
escaped = _escape_mdv2(special)
|
||||
# Every special char should be preceded by backslash
|
||||
for ch in r'_*[]()~`>#+-=|{}.!\ ':
|
||||
if ch == ' ':
|
||||
continue
|
||||
assert f'\\{ch}' in escaped
|
||||
|
||||
def test_empty_string(self):
|
||||
assert _escape_mdv2("") == ""
|
||||
|
||||
def test_no_special_characters(self):
|
||||
assert _escape_mdv2("hello world 123") == "hello world 123"
|
||||
|
||||
def test_backslash_escaped(self):
|
||||
assert _escape_mdv2("a\\b") == "a\\\\b"
|
||||
|
||||
def test_dot_escaped(self):
|
||||
assert _escape_mdv2("v2.0") == "v2\\.0"
|
||||
|
||||
def test_exclamation_escaped(self):
|
||||
assert _escape_mdv2("wow!") == "wow\\!"
|
||||
|
||||
def test_mixed_text_and_specials(self):
|
||||
result = _escape_mdv2("Hello (world)!")
|
||||
assert result == "Hello \\(world\\)\\!"
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# format_message - basic conversions
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestFormatMessageBasic:
|
||||
def test_empty_string(self, adapter):
|
||||
assert adapter.format_message("") == ""
|
||||
|
||||
def test_none_input(self, adapter):
|
||||
# content is falsy, returned as-is
|
||||
assert adapter.format_message(None) is None
|
||||
|
||||
def test_plain_text_specials_escaped(self, adapter):
|
||||
result = adapter.format_message("Price is $5.00!")
|
||||
assert "\\." in result
|
||||
assert "\\!" in result
|
||||
|
||||
def test_plain_text_no_markdown(self, adapter):
|
||||
result = adapter.format_message("Hello world")
|
||||
assert result == "Hello world"
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# format_message - code blocks
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestFormatMessageCodeBlocks:
|
||||
def test_fenced_code_block_preserved(self, adapter):
|
||||
text = "Before\n```python\nprint('hello')\n```\nAfter"
|
||||
result = adapter.format_message(text)
|
||||
# Code block contents must NOT be escaped
|
||||
assert "```python\nprint('hello')\n```" in result
|
||||
# But "After" should have no escaping needed (plain text)
|
||||
assert "After" in result
|
||||
|
||||
def test_inline_code_preserved(self, adapter):
|
||||
text = "Use `my_var` here"
|
||||
result = adapter.format_message(text)
|
||||
# Inline code content must NOT be escaped
|
||||
assert "`my_var`" in result
|
||||
# The surrounding text's underscore-free content should be fine
|
||||
assert "Use" in result
|
||||
|
||||
def test_code_block_special_chars_not_escaped(self, adapter):
|
||||
text = "```\nif (x > 0) { return !x; }\n```"
|
||||
result = adapter.format_message(text)
|
||||
# Inside code block, > and ! and { should NOT be escaped
|
||||
assert "if (x > 0) { return !x; }" in result
|
||||
|
||||
def test_inline_code_special_chars_not_escaped(self, adapter):
|
||||
text = "Run `rm -rf ./*` carefully"
|
||||
result = adapter.format_message(text)
|
||||
assert "`rm -rf ./*`" in result
|
||||
|
||||
def test_multiple_code_blocks(self, adapter):
|
||||
text = "```\nblock1\n```\ntext\n```\nblock2\n```"
|
||||
result = adapter.format_message(text)
|
||||
assert "block1" in result
|
||||
assert "block2" in result
|
||||
# "text" between blocks should be present
|
||||
assert "text" in result
|
||||
|
||||
def test_inline_code_backslashes_escaped(self, adapter):
|
||||
r"""Backslashes in inline code must be escaped for MarkdownV2."""
|
||||
text = r"Check `C:\ProgramData\VMware\` path"
|
||||
result = adapter.format_message(text)
|
||||
assert r"`C:\\ProgramData\\VMware\\`" in result
|
||||
|
||||
def test_fenced_code_block_backslashes_escaped(self, adapter):
|
||||
r"""Backslashes in fenced code blocks must be escaped for MarkdownV2."""
|
||||
text = "```\npath = r'C:\\Users\\test'\n```"
|
||||
result = adapter.format_message(text)
|
||||
assert r"C:\\Users\\test" in result
|
||||
|
||||
def test_fenced_code_block_backticks_escaped(self, adapter):
|
||||
r"""Backticks inside fenced code blocks must be escaped for MarkdownV2."""
|
||||
text = "```\necho `hostname`\n```"
|
||||
result = adapter.format_message(text)
|
||||
assert r"echo \`hostname\`" in result
|
||||
|
||||
def test_inline_code_no_double_escape(self, adapter):
|
||||
r"""Already-escaped backslashes should not be quadruple-escaped."""
|
||||
text = r"Use `\\server\share`"
|
||||
result = adapter.format_message(text)
|
||||
# \\ in input → \\\\ in output (each \ escaped once)
|
||||
assert r"`\\\\server\\share`" in result
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# format_message - bold and italic
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestFormatMessageBoldItalic:
|
||||
def test_bold_converted(self, adapter):
|
||||
result = adapter.format_message("This is **bold** text")
|
||||
# MarkdownV2 bold uses single *
|
||||
assert "*bold*" in result
|
||||
# Original ** should be gone
|
||||
assert "**" not in result
|
||||
|
||||
def test_italic_converted(self, adapter):
|
||||
result = adapter.format_message("This is *italic* text")
|
||||
# MarkdownV2 italic uses _
|
||||
assert "_italic_" in result
|
||||
|
||||
def test_bold_with_special_chars(self, adapter):
|
||||
result = adapter.format_message("**hello.world!**")
|
||||
# Content inside bold should be escaped
|
||||
assert "*hello\\.world\\!*" in result
|
||||
|
||||
def test_italic_with_special_chars(self, adapter):
|
||||
result = adapter.format_message("*hello.world*")
|
||||
assert "_hello\\.world_" in result
|
||||
|
||||
def test_bold_and_italic_in_same_line(self, adapter):
|
||||
result = adapter.format_message("**bold** and *italic*")
|
||||
assert "*bold*" in result
|
||||
assert "_italic_" in result
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# format_message - headers
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestFormatMessageHeaders:
|
||||
def test_h1_converted_to_bold(self, adapter):
|
||||
result = adapter.format_message("# Title")
|
||||
# Header becomes bold in MarkdownV2
|
||||
assert "*Title*" in result
|
||||
# Hash should be removed
|
||||
assert "#" not in result
|
||||
|
||||
def test_h2_converted(self, adapter):
|
||||
result = adapter.format_message("## Subtitle")
|
||||
assert "*Subtitle*" in result
|
||||
|
||||
def test_header_with_inner_bold_stripped(self, adapter):
|
||||
# Headers strip redundant **...** inside
|
||||
result = adapter.format_message("## **Important**")
|
||||
# Should be *Important* not ***Important***
|
||||
assert "*Important*" in result
|
||||
count = result.count("*")
|
||||
# Should have exactly 2 asterisks (open + close)
|
||||
assert count == 2
|
||||
|
||||
def test_header_with_special_chars(self, adapter):
|
||||
result = adapter.format_message("# Hello (World)!")
|
||||
assert "\\(" in result
|
||||
assert "\\)" in result
|
||||
assert "\\!" in result
|
||||
|
||||
def test_multiline_headers(self, adapter):
|
||||
text = "# First\nSome text\n## Second"
|
||||
result = adapter.format_message(text)
|
||||
assert "*First*" in result
|
||||
assert "*Second*" in result
|
||||
assert "Some text" in result
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# format_message - links
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestFormatMessageLinks:
|
||||
def test_markdown_link_converted(self, adapter):
|
||||
result = adapter.format_message("[Click here](https://example.com)")
|
||||
assert "[Click here](https://example.com)" in result
|
||||
|
||||
def test_link_display_text_escaped(self, adapter):
|
||||
result = adapter.format_message("[Hello!](https://example.com)")
|
||||
# The ! in display text should be escaped
|
||||
assert "Hello\\!" in result
|
||||
|
||||
def test_link_url_parentheses_escaped(self, adapter):
|
||||
result = adapter.format_message("[link](https://example.com/path_(1))")
|
||||
# The ) in URL should be escaped
|
||||
assert "\\)" in result
|
||||
|
||||
def test_link_with_surrounding_text(self, adapter):
|
||||
result = adapter.format_message("Visit [Google](https://google.com) today.")
|
||||
assert "[Google](https://google.com)" in result
|
||||
assert "today\\." in result
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# format_message - BUG: italic regex spans newlines
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestItalicNewlineBug:
|
||||
r"""Italic regex ``\*([^*]+)\*`` matched across newlines, corrupting content.
|
||||
|
||||
This affects bullet lists using * markers and any text where * appears
|
||||
at the end of one line and start of another.
|
||||
"""
|
||||
|
||||
def test_bullet_list_not_corrupted(self, adapter):
|
||||
"""Bullet list items using * must NOT be merged into italic."""
|
||||
text = "* Item one\n* Item two\n* Item three"
|
||||
result = adapter.format_message(text)
|
||||
# Each item should appear in the output (not eaten by italic conversion)
|
||||
assert "Item one" in result
|
||||
assert "Item two" in result
|
||||
assert "Item three" in result
|
||||
# Should NOT contain _ (italic markers) wrapping list items
|
||||
assert "_" not in result or "Item" not in result.split("_")[1] if "_" in result else True
|
||||
|
||||
def test_asterisk_list_items_preserved(self, adapter):
|
||||
"""Each * list item should remain as a separate line, not become italic."""
|
||||
text = "* Alpha\n* Beta"
|
||||
result = adapter.format_message(text)
|
||||
# Both items must be present in output
|
||||
assert "Alpha" in result
|
||||
assert "Beta" in result
|
||||
# The text between first * and second * must NOT become italic
|
||||
lines = result.split("\n")
|
||||
assert len(lines) >= 2
|
||||
|
||||
def test_italic_does_not_span_lines(self, adapter):
|
||||
"""*text on\nmultiple lines* should NOT become italic."""
|
||||
text = "Start *across\nlines* end"
|
||||
result = adapter.format_message(text)
|
||||
# Should NOT have underscore italic markers wrapping cross-line text
|
||||
# If this fails, the italic regex is matching across newlines
|
||||
assert "_across\nlines_" not in result
|
||||
|
||||
def test_single_line_italic_still_works(self, adapter):
|
||||
"""Normal single-line italic must still convert correctly."""
|
||||
text = "This is *italic* text"
|
||||
result = adapter.format_message(text)
|
||||
assert "_italic_" in result
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# format_message - strikethrough
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestFormatMessageStrikethrough:
|
||||
def test_strikethrough_converted(self, adapter):
|
||||
result = adapter.format_message("This is ~~deleted~~ text")
|
||||
assert "~deleted~" in result
|
||||
assert "~~" not in result
|
||||
|
||||
def test_strikethrough_with_special_chars(self, adapter):
|
||||
result = adapter.format_message("~~hello.world!~~")
|
||||
assert "~hello\\.world\\!~" in result
|
||||
|
||||
def test_strikethrough_in_code_not_converted(self, adapter):
|
||||
result = adapter.format_message("`~~not struck~~`")
|
||||
assert "`~~not struck~~`" in result
|
||||
|
||||
def test_strikethrough_with_bold(self, adapter):
|
||||
result = adapter.format_message("**bold** and ~~struck~~")
|
||||
assert "*bold*" in result
|
||||
assert "~struck~" in result
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# format_message - spoiler
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestFormatMessageSpoiler:
|
||||
def test_spoiler_converted(self, adapter):
|
||||
result = adapter.format_message("This is ||hidden|| text")
|
||||
assert "||hidden||" in result
|
||||
|
||||
def test_spoiler_with_special_chars(self, adapter):
|
||||
result = adapter.format_message("||hello.world!||")
|
||||
assert "||hello\\.world\\!||" in result
|
||||
|
||||
def test_spoiler_in_code_not_converted(self, adapter):
|
||||
result = adapter.format_message("`||not spoiler||`")
|
||||
assert "`||not spoiler||`" in result
|
||||
|
||||
def test_spoiler_pipes_not_escaped(self, adapter):
|
||||
"""The || delimiters must not be escaped as \\|\\|."""
|
||||
result = adapter.format_message("||secret||")
|
||||
assert "\\|\\|" not in result
|
||||
assert "||secret||" in result
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# format_message - blockquote
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestFormatMessageBlockquote:
|
||||
def test_blockquote_converted(self, adapter):
|
||||
result = adapter.format_message("> This is a quote")
|
||||
assert "> This is a quote" in result
|
||||
# > must NOT be escaped
|
||||
assert "\\>" not in result
|
||||
|
||||
def test_blockquote_with_special_chars(self, adapter):
|
||||
result = adapter.format_message("> Hello (world)!")
|
||||
assert "> Hello \\(world\\)\\!" in result
|
||||
assert "\\>" not in result
|
||||
|
||||
def test_blockquote_multiline(self, adapter):
|
||||
text = "> Line one\n> Line two"
|
||||
result = adapter.format_message(text)
|
||||
assert "> Line one" in result
|
||||
assert "> Line two" in result
|
||||
assert "\\>" not in result
|
||||
|
||||
def test_blockquote_in_code_not_converted(self, adapter):
|
||||
result = adapter.format_message("```\n> not a quote\n```")
|
||||
assert "> not a quote" in result
|
||||
|
||||
def test_nested_blockquote(self, adapter):
|
||||
result = adapter.format_message(">> Nested quote")
|
||||
assert ">> Nested quote" in result
|
||||
assert "\\>" not in result
|
||||
|
||||
def test_gt_in_middle_of_line_still_escaped(self, adapter):
|
||||
"""Only > at line start is a blockquote; mid-line > should be escaped."""
|
||||
result = adapter.format_message("5 > 3")
|
||||
assert "\\>" in result
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# format_message - mixed/complex
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestFormatMessageComplex:
|
||||
def test_code_block_with_bold_outside(self, adapter):
|
||||
text = "**Note:**\n```\ncode here\n```"
|
||||
result = adapter.format_message(text)
|
||||
assert "*Note:*" in result or "*Note\\:*" in result
|
||||
assert "```\ncode here\n```" in result
|
||||
|
||||
def test_bold_inside_code_not_converted(self, adapter):
|
||||
"""Bold markers inside code blocks should not be converted."""
|
||||
text = "```\n**not bold**\n```"
|
||||
result = adapter.format_message(text)
|
||||
assert "**not bold**" in result
|
||||
|
||||
def test_link_inside_code_not_converted(self, adapter):
|
||||
text = "`[not a link](url)`"
|
||||
result = adapter.format_message(text)
|
||||
assert "`[not a link](url)`" in result
|
||||
|
||||
def test_header_after_code_block(self, adapter):
|
||||
text = "```\ncode\n```\n## Title"
|
||||
result = adapter.format_message(text)
|
||||
assert "*Title*" in result
|
||||
assert "```\ncode\n```" in result
|
||||
|
||||
def test_multiple_bold_segments(self, adapter):
|
||||
result = adapter.format_message("**a** and **b** and **c**")
|
||||
assert result.count("*") >= 6 # 3 bold pairs = 6 asterisks
|
||||
|
||||
def test_special_chars_in_plain_text(self, adapter):
|
||||
result = adapter.format_message("Price: $5.00 (50% off!)")
|
||||
assert "\\." in result
|
||||
assert "\\(" in result
|
||||
assert "\\)" in result
|
||||
assert "\\!" in result
|
||||
|
||||
def test_empty_bold(self, adapter):
|
||||
"""**** (empty bold) should not crash."""
|
||||
result = adapter.format_message("****")
|
||||
assert result is not None
|
||||
|
||||
def test_empty_code_block(self, adapter):
|
||||
result = adapter.format_message("```\n```")
|
||||
assert "```" in result
|
||||
|
||||
def test_placeholder_collision(self, adapter):
|
||||
"""Many formatting elements should not cause placeholder collisions."""
|
||||
text = (
|
||||
"# Header\n"
|
||||
"**bold1** *italic1* `code1`\n"
|
||||
"**bold2** *italic2* `code2`\n"
|
||||
"```\nblock\n```\n"
|
||||
"[link](https://url.com)"
|
||||
)
|
||||
result = adapter.format_message(text)
|
||||
# No placeholder tokens should leak into output
|
||||
assert "\x00" not in result
|
||||
# All elements should be present
|
||||
assert "Header" in result
|
||||
assert "block" in result
|
||||
assert "url.com" in result
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# _strip_mdv2 — plaintext fallback
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestStripMdv2:
|
||||
def test_removes_escape_backslashes(self):
|
||||
assert _strip_mdv2(r"hello\.world\!") == "hello.world!"
|
||||
|
||||
def test_removes_bold_markers(self):
|
||||
assert _strip_mdv2("*bold text*") == "bold text"
|
||||
|
||||
def test_removes_italic_markers(self):
|
||||
assert _strip_mdv2("_italic text_") == "italic text"
|
||||
|
||||
def test_removes_both_bold_and_italic(self):
|
||||
result = _strip_mdv2("*bold* and _italic_")
|
||||
assert result == "bold and italic"
|
||||
|
||||
def test_preserves_snake_case(self):
|
||||
assert _strip_mdv2("my_variable_name") == "my_variable_name"
|
||||
|
||||
def test_preserves_multi_underscore_identifier(self):
|
||||
assert _strip_mdv2("some_func_call here") == "some_func_call here"
|
||||
|
||||
def test_plain_text_unchanged(self):
|
||||
assert _strip_mdv2("plain text") == "plain text"
|
||||
|
||||
def test_empty_string(self):
|
||||
assert _strip_mdv2("") == ""
|
||||
|
||||
def test_removes_strikethrough_markers(self):
|
||||
assert _strip_mdv2("~struck text~") == "struck text"
|
||||
|
||||
def test_removes_spoiler_markers(self):
|
||||
assert _strip_mdv2("||hidden text||") == "hidden text"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_escapes_chunk_indicator_for_markdownv2(adapter):
|
||||
adapter.MAX_MESSAGE_LENGTH = 80
|
||||
adapter._bot = MagicMock()
|
||||
|
||||
sent_texts = []
|
||||
|
||||
async def _fake_send_message(**kwargs):
|
||||
sent_texts.append(kwargs["text"])
|
||||
msg = MagicMock()
|
||||
msg.message_id = len(sent_texts)
|
||||
return msg
|
||||
|
||||
adapter._bot.send_message = AsyncMock(side_effect=_fake_send_message)
|
||||
|
||||
content = ("**bold** chunk content " * 12).strip()
|
||||
result = await adapter.send("123", content)
|
||||
|
||||
assert result.success is True
|
||||
assert len(sent_texts) > 1
|
||||
assert re.search(r" \\\([0-9]+/[0-9]+\\\)$", sent_texts[0])
|
||||
assert re.search(r" \\\([0-9]+/[0-9]+\\\)$", sent_texts[-1])
|
||||
49
hermes_code/tests/gateway/test_telegram_photo_interrupts.py
Normal file
49
hermes_code/tests/gateway/test_telegram_photo_interrupts.py
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
import asyncio
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import MessageEvent, MessageType
|
||||
from gateway.session import SessionSource, build_session_key
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
|
||||
class _PendingAdapter:
|
||||
def __init__(self):
|
||||
self._pending_messages = {}
|
||||
|
||||
|
||||
def _make_runner():
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.config = GatewayConfig(platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")})
|
||||
runner.adapters = {Platform.TELEGRAM: _PendingAdapter()}
|
||||
runner._running_agents = {}
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._voice_mode = {}
|
||||
runner._is_user_authorized = lambda _source: True
|
||||
return runner
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_does_not_priority_interrupt_photo_followup():
|
||||
runner = _make_runner()
|
||||
source = SessionSource(platform=Platform.TELEGRAM, chat_id="12345", chat_type="dm")
|
||||
session_key = build_session_key(source)
|
||||
running_agent = MagicMock()
|
||||
runner._running_agents[session_key] = running_agent
|
||||
|
||||
event = MessageEvent(
|
||||
text="caption",
|
||||
message_type=MessageType.PHOTO,
|
||||
source=source,
|
||||
media_urls=["/tmp/photo-a.jpg"],
|
||||
media_types=["image/jpeg"],
|
||||
)
|
||||
|
||||
result = await runner._handle_message(event)
|
||||
|
||||
assert result is None
|
||||
running_agent.interrupt.assert_not_called()
|
||||
assert runner.adapters[Platform.TELEGRAM]._pending_messages[session_key] is event
|
||||
121
hermes_code/tests/gateway/test_telegram_text_batching.py
Normal file
121
hermes_code/tests/gateway/test_telegram_text_batching.py
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
"""Tests for Telegram text message aggregation.
|
||||
|
||||
When a user sends a long message, Telegram clients split it into multiple
|
||||
updates. The TelegramAdapter should buffer rapid successive text messages
|
||||
from the same session and aggregate them before dispatching.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import MessageEvent, MessageType, SessionSource
|
||||
|
||||
|
||||
def _make_adapter():
|
||||
"""Create a minimal TelegramAdapter for testing text batching."""
|
||||
from gateway.platforms.telegram import TelegramAdapter
|
||||
|
||||
config = PlatformConfig(enabled=True, token="test-token")
|
||||
adapter = object.__new__(TelegramAdapter)
|
||||
adapter._platform = Platform.TELEGRAM
|
||||
adapter.config = config
|
||||
adapter._pending_text_batches = {}
|
||||
adapter._pending_text_batch_tasks = {}
|
||||
adapter._text_batch_delay_seconds = 0.1 # fast for tests
|
||||
adapter._active_sessions = {}
|
||||
adapter._pending_messages = {}
|
||||
adapter._message_handler = AsyncMock()
|
||||
adapter.handle_message = AsyncMock()
|
||||
return adapter
|
||||
|
||||
|
||||
def _make_event(text: str, chat_id: str = "12345") -> MessageEvent:
|
||||
return MessageEvent(
|
||||
text=text,
|
||||
message_type=MessageType.TEXT,
|
||||
source=SessionSource(platform=Platform.TELEGRAM, chat_id=chat_id, chat_type="dm"),
|
||||
)
|
||||
|
||||
|
||||
class TestTextBatching:
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_message_dispatched_after_delay(self):
|
||||
adapter = _make_adapter()
|
||||
event = _make_event("hello world")
|
||||
|
||||
adapter._enqueue_text_event(event)
|
||||
|
||||
# Not dispatched yet
|
||||
adapter.handle_message.assert_not_called()
|
||||
|
||||
# Wait for flush
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
adapter.handle_message.assert_called_once()
|
||||
dispatched = adapter.handle_message.call_args[0][0]
|
||||
assert dispatched.text == "hello world"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_split_messages_aggregated(self):
|
||||
"""Two rapid messages from the same chat should be merged."""
|
||||
adapter = _make_adapter()
|
||||
|
||||
adapter._enqueue_text_event(_make_event("This is part one of a long"))
|
||||
await asyncio.sleep(0.02) # small gap, within batch window
|
||||
adapter._enqueue_text_event(_make_event("message that was split by Telegram."))
|
||||
|
||||
# Not dispatched yet (timer restarted)
|
||||
adapter.handle_message.assert_not_called()
|
||||
|
||||
# Wait for flush
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
adapter.handle_message.assert_called_once()
|
||||
dispatched = adapter.handle_message.call_args[0][0]
|
||||
assert "part one" in dispatched.text
|
||||
assert "split by Telegram" in dispatched.text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_three_way_split_aggregated(self):
|
||||
"""Three rapid messages should all merge."""
|
||||
adapter = _make_adapter()
|
||||
|
||||
adapter._enqueue_text_event(_make_event("chunk 1"))
|
||||
await asyncio.sleep(0.02)
|
||||
adapter._enqueue_text_event(_make_event("chunk 2"))
|
||||
await asyncio.sleep(0.02)
|
||||
adapter._enqueue_text_event(_make_event("chunk 3"))
|
||||
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
adapter.handle_message.assert_called_once()
|
||||
text = adapter.handle_message.call_args[0][0].text
|
||||
assert "chunk 1" in text
|
||||
assert "chunk 2" in text
|
||||
assert "chunk 3" in text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_different_chats_not_merged(self):
|
||||
"""Messages from different chats should be separate batches."""
|
||||
adapter = _make_adapter()
|
||||
|
||||
adapter._enqueue_text_event(_make_event("from user A", chat_id="111"))
|
||||
adapter._enqueue_text_event(_make_event("from user B", chat_id="222"))
|
||||
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
assert adapter.handle_message.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_cleans_up_after_flush(self):
|
||||
"""After flushing, internal state should be clean."""
|
||||
adapter = _make_adapter()
|
||||
|
||||
adapter._enqueue_text_event(_make_event("test"))
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
assert len(adapter._pending_text_batches) == 0
|
||||
assert len(adapter._pending_text_batch_tasks) == 0
|
||||
208
hermes_code/tests/gateway/test_title_command.py
Normal file
208
hermes_code/tests/gateway/test_title_command.py
Normal file
|
|
@ -0,0 +1,208 @@
|
|||
"""Tests for /title gateway slash command.
|
||||
|
||||
Tests the _handle_title_command handler (set/show session titles)
|
||||
across all gateway messenger platforms.
|
||||
"""
|
||||
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform
|
||||
from gateway.platforms.base import MessageEvent
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
def _make_event(text="/title", platform=Platform.TELEGRAM,
|
||||
user_id="12345", chat_id="67890"):
|
||||
"""Build a MessageEvent for testing."""
|
||||
source = SessionSource(
|
||||
platform=platform,
|
||||
user_id=user_id,
|
||||
chat_id=chat_id,
|
||||
user_name="testuser",
|
||||
)
|
||||
return MessageEvent(text=text, source=source)
|
||||
|
||||
|
||||
def _make_runner(session_db=None):
|
||||
"""Create a bare GatewayRunner with a mock session_store and optional session_db."""
|
||||
from gateway.run import GatewayRunner
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.adapters = {}
|
||||
runner._voice_mode = {}
|
||||
runner._session_db = session_db
|
||||
|
||||
# Mock session_store that returns a session entry with a known session_id
|
||||
mock_session_entry = MagicMock()
|
||||
mock_session_entry.session_id = "test_session_123"
|
||||
mock_session_entry.session_key = "telegram:12345:67890"
|
||||
mock_store = MagicMock()
|
||||
mock_store.get_or_create_session.return_value = mock_session_entry
|
||||
runner.session_store = mock_store
|
||||
|
||||
return runner
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _handle_title_command
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHandleTitleCommand:
|
||||
"""Tests for GatewayRunner._handle_title_command."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_title(self, tmp_path):
|
||||
"""Setting a title returns confirmation."""
|
||||
from hermes_state import SessionDB
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
db.create_session("test_session_123", "telegram")
|
||||
|
||||
runner = _make_runner(session_db=db)
|
||||
event = _make_event(text="/title My Research Project")
|
||||
result = await runner._handle_title_command(event)
|
||||
assert "My Research Project" in result
|
||||
assert "✏️" in result
|
||||
|
||||
# Verify in DB
|
||||
assert db.get_session_title("test_session_123") == "My Research Project"
|
||||
db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_show_title_when_set(self, tmp_path):
|
||||
"""Showing title when one is set returns the title."""
|
||||
from hermes_state import SessionDB
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
db.create_session("test_session_123", "telegram")
|
||||
db.set_session_title("test_session_123", "Existing Title")
|
||||
|
||||
runner = _make_runner(session_db=db)
|
||||
event = _make_event(text="/title")
|
||||
result = await runner._handle_title_command(event)
|
||||
assert "Existing Title" in result
|
||||
assert "📌" in result
|
||||
db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_show_title_when_not_set(self, tmp_path):
|
||||
"""Showing title when none is set returns usage hint."""
|
||||
from hermes_state import SessionDB
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
db.create_session("test_session_123", "telegram")
|
||||
|
||||
runner = _make_runner(session_db=db)
|
||||
event = _make_event(text="/title")
|
||||
result = await runner._handle_title_command(event)
|
||||
assert "No title set" in result
|
||||
assert "/title" in result
|
||||
db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_title_conflict(self, tmp_path):
|
||||
"""Setting a title already used by another session returns error."""
|
||||
from hermes_state import SessionDB
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
db.create_session("other_session", "telegram")
|
||||
db.set_session_title("other_session", "Taken Title")
|
||||
db.create_session("test_session_123", "telegram")
|
||||
|
||||
runner = _make_runner(session_db=db)
|
||||
event = _make_event(text="/title Taken Title")
|
||||
result = await runner._handle_title_command(event)
|
||||
assert "already in use" in result
|
||||
assert "⚠️" in result
|
||||
db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_session_db(self):
|
||||
"""Returns error when session database is not available."""
|
||||
runner = _make_runner(session_db=None)
|
||||
event = _make_event(text="/title My Title")
|
||||
result = await runner._handle_title_command(event)
|
||||
assert "not available" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_title_too_long(self, tmp_path):
|
||||
"""Setting a title that exceeds max length returns error."""
|
||||
from hermes_state import SessionDB
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
db.create_session("test_session_123", "telegram")
|
||||
|
||||
runner = _make_runner(session_db=db)
|
||||
long_title = "A" * 150
|
||||
event = _make_event(text=f"/title {long_title}")
|
||||
result = await runner._handle_title_command(event)
|
||||
assert "too long" in result
|
||||
assert "⚠️" in result
|
||||
db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_title_control_chars_sanitized(self, tmp_path):
|
||||
"""Control characters are stripped and sanitized title is stored."""
|
||||
from hermes_state import SessionDB
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
db.create_session("test_session_123", "telegram")
|
||||
|
||||
runner = _make_runner(session_db=db)
|
||||
event = _make_event(text="/title hello\x00world")
|
||||
result = await runner._handle_title_command(event)
|
||||
assert "helloworld" in result
|
||||
assert db.get_session_title("test_session_123") == "helloworld"
|
||||
db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_title_only_control_chars(self, tmp_path):
|
||||
"""Title with only control chars returns empty error."""
|
||||
from hermes_state import SessionDB
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
db.create_session("test_session_123", "telegram")
|
||||
|
||||
runner = _make_runner(session_db=db)
|
||||
event = _make_event(text="/title \x00\x01\x02")
|
||||
result = await runner._handle_title_command(event)
|
||||
assert "empty after cleanup" in result
|
||||
db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_works_across_platforms(self, tmp_path):
|
||||
"""The /title command works for Discord, Slack, and WhatsApp too."""
|
||||
from hermes_state import SessionDB
|
||||
for platform in [Platform.DISCORD, Platform.TELEGRAM]:
|
||||
db = SessionDB(db_path=tmp_path / f"state_{platform.value}.db")
|
||||
db.create_session("test_session_123", platform.value)
|
||||
|
||||
runner = _make_runner(session_db=db)
|
||||
event = _make_event(text="/title Cross-Platform Test", platform=platform)
|
||||
result = await runner._handle_title_command(event)
|
||||
assert "Cross-Platform Test" in result
|
||||
assert db.get_session_title("test_session_123") == "Cross-Platform Test"
|
||||
db.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# /title in help and known_commands
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTitleInHelp:
|
||||
"""Verify /title appears in help text and known commands."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_title_in_help_output(self):
|
||||
"""The /help output includes /title."""
|
||||
runner = _make_runner()
|
||||
event = _make_event(text="/help")
|
||||
# Need hooks for help command
|
||||
from gateway.hooks import HookRegistry
|
||||
runner.hooks = HookRegistry()
|
||||
result = await runner._handle_help_command(event)
|
||||
assert "/title" in result
|
||||
|
||||
def test_title_is_known_command(self):
|
||||
"""The /title command is in the _known_commands set."""
|
||||
from gateway.run import GatewayRunner
|
||||
import inspect
|
||||
source = inspect.getsource(GatewayRunner._handle_message)
|
||||
assert '"title"' in source
|
||||
267
hermes_code/tests/gateway/test_transcript_offset.py
Normal file
267
hermes_code/tests/gateway/test_transcript_offset.py
Normal file
|
|
@ -0,0 +1,267 @@
|
|||
"""Tests for transcript history offset fix.
|
||||
|
||||
Regression tests for a bug where the gateway transcript lost 1 message
|
||||
per turn from turn 2 onwards. The raw transcript history includes
|
||||
``session_meta`` entries that are filtered out before being passed to
|
||||
the agent. The agent returns messages built from this filtered history
|
||||
plus new messages from the current turn.
|
||||
|
||||
The old code used ``len(history)`` (raw count, includes session_meta)
|
||||
to slice ``agent_messages``, which caused the slice to skip valid new
|
||||
messages. The fix adds ``history_offset`` (the filtered history length)
|
||||
to ``_run_agent``'s return dict and uses it for the slice.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers - replicate the filtering logic from _run_agent
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _filter_history(history: list) -> list:
|
||||
"""Replicate the agent_history filtering from GatewayRunner._run_agent.
|
||||
|
||||
Strips session_meta and system messages, exactly as the real code does.
|
||||
"""
|
||||
agent_history = []
|
||||
for msg in history:
|
||||
role = msg.get("role")
|
||||
if not role:
|
||||
continue
|
||||
if role in ("session_meta",):
|
||||
continue
|
||||
if role == "system":
|
||||
continue
|
||||
|
||||
has_tool_calls = "tool_calls" in msg
|
||||
has_tool_call_id = "tool_call_id" in msg
|
||||
is_tool_message = role == "tool"
|
||||
|
||||
if has_tool_calls or has_tool_call_id or is_tool_message:
|
||||
clean_msg = {k: v for k, v in msg.items() if k != "timestamp"}
|
||||
agent_history.append(clean_msg)
|
||||
else:
|
||||
content = msg.get("content")
|
||||
if content:
|
||||
agent_history.append({"role": role, "content": content})
|
||||
return agent_history
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTranscriptHistoryOffset:
|
||||
"""Verify the transcript extraction uses the filtered history length."""
|
||||
|
||||
def test_session_meta_causes_offset_mismatch(self):
|
||||
"""Turn 2: session_meta makes len(history) > len(agent_history).
|
||||
|
||||
- history (raw): 1 session_meta + 2 conversation = 3 entries
|
||||
- agent_history (filtered): 2 entries
|
||||
- Agent returns 2 old + 2 new = 4 messages
|
||||
- OLD: agent_messages[3:] = 1 message (lost the user message)
|
||||
- FIX: agent_messages[2:] = 2 messages (correct)
|
||||
"""
|
||||
history = [
|
||||
{"role": "session_meta", "tools": [], "model": "gpt-4",
|
||||
"platform": "telegram", "timestamp": "t0"},
|
||||
{"role": "user", "content": "Hello", "timestamp": "t1"},
|
||||
{"role": "assistant", "content": "Hi there!", "timestamp": "t1"},
|
||||
]
|
||||
|
||||
agent_history = _filter_history(history)
|
||||
assert len(agent_history) == 2 # session_meta stripped
|
||||
|
||||
# Agent returns: filtered history (2) + new turn (2)
|
||||
agent_messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
{"role": "user", "content": "What is Python?"},
|
||||
{"role": "assistant", "content": "A programming language."},
|
||||
]
|
||||
|
||||
# OLD behavior: len(history) = 3, skips too many
|
||||
old_offset = len(history)
|
||||
old_new = (agent_messages[old_offset:]
|
||||
if len(agent_messages) > old_offset
|
||||
else agent_messages)
|
||||
assert len(old_new) == 1 # BUG: lost the user message
|
||||
|
||||
# FIXED behavior: history_offset = 2
|
||||
history_offset = len(agent_history)
|
||||
fixed_new = (agent_messages[history_offset:]
|
||||
if len(agent_messages) > history_offset
|
||||
else [])
|
||||
assert len(fixed_new) == 2
|
||||
assert fixed_new[0]["content"] == "What is Python?"
|
||||
assert fixed_new[1]["content"] == "A programming language."
|
||||
|
||||
def test_no_session_meta_same_result(self):
|
||||
"""First turn has no session_meta, so both approaches agree."""
|
||||
history = []
|
||||
agent_history = _filter_history(history)
|
||||
assert len(agent_history) == 0
|
||||
|
||||
agent_messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi!"},
|
||||
]
|
||||
|
||||
old_new = (agent_messages[len(history):]
|
||||
if len(agent_messages) > len(history)
|
||||
else agent_messages)
|
||||
fixed_new = (agent_messages[len(agent_history):]
|
||||
if len(agent_messages) > len(agent_history)
|
||||
else [])
|
||||
|
||||
assert old_new == fixed_new
|
||||
assert len(fixed_new) == 2
|
||||
|
||||
def test_multiple_session_meta_larger_drift(self):
|
||||
"""Two session_meta entries double the offset error.
|
||||
|
||||
This can happen when the session spans tool definition changes
|
||||
or model switches that each write a new session_meta record.
|
||||
"""
|
||||
history = [
|
||||
{"role": "session_meta", "tools": [], "timestamp": "t0"},
|
||||
{"role": "user", "content": "msg1", "timestamp": "t1"},
|
||||
{"role": "assistant", "content": "reply1", "timestamp": "t1"},
|
||||
{"role": "session_meta", "tools": ["new_tool"], "timestamp": "t2"},
|
||||
{"role": "user", "content": "msg2", "timestamp": "t3"},
|
||||
{"role": "assistant", "content": "reply2", "timestamp": "t3"},
|
||||
]
|
||||
|
||||
agent_history = _filter_history(history)
|
||||
assert len(agent_history) == 4
|
||||
assert len(history) == 6 # 2 extra session_meta entries
|
||||
|
||||
# Agent returns 4 old + 2 new = 6 total
|
||||
agent_messages = [
|
||||
{"role": "user", "content": "msg1"},
|
||||
{"role": "assistant", "content": "reply1"},
|
||||
{"role": "user", "content": "msg2"},
|
||||
{"role": "assistant", "content": "reply2"},
|
||||
{"role": "user", "content": "msg3"},
|
||||
{"role": "assistant", "content": "reply3"},
|
||||
]
|
||||
|
||||
# OLD: len(history) == len(agent_messages) == 6 -> else branch
|
||||
old_offset = len(history)
|
||||
old_new = (agent_messages[old_offset:]
|
||||
if len(agent_messages) > old_offset
|
||||
else agent_messages)
|
||||
# BUG: treats ALL messages as new (duplicates entire history)
|
||||
assert old_new == agent_messages
|
||||
|
||||
# FIXED: history_offset = 4
|
||||
fixed_new = (agent_messages[len(agent_history):]
|
||||
if len(agent_messages) > len(agent_history)
|
||||
else [])
|
||||
assert len(fixed_new) == 2
|
||||
assert fixed_new[0]["content"] == "msg3"
|
||||
assert fixed_new[1]["content"] == "reply3"
|
||||
|
||||
def test_system_messages_also_filtered(self):
|
||||
"""system messages in history are also stripped from agent_history."""
|
||||
history = [
|
||||
{"role": "session_meta", "tools": [], "timestamp": "t0"},
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hi", "timestamp": "t1"},
|
||||
{"role": "assistant", "content": "Hello!", "timestamp": "t1"},
|
||||
]
|
||||
|
||||
agent_history = _filter_history(history)
|
||||
assert len(agent_history) == 2 # only user + assistant
|
||||
|
||||
agent_messages = [
|
||||
{"role": "user", "content": "Hi"},
|
||||
{"role": "assistant", "content": "Hello!"},
|
||||
{"role": "user", "content": "New question"},
|
||||
{"role": "assistant", "content": "New answer"},
|
||||
]
|
||||
|
||||
# OLD: len(history) = 4, skips everything
|
||||
old_offset = len(history)
|
||||
old_new = (agent_messages[old_offset:]
|
||||
if len(agent_messages) > old_offset
|
||||
else agent_messages)
|
||||
assert old_new == agent_messages # BUG: all treated as new
|
||||
|
||||
# FIXED
|
||||
fixed_new = (agent_messages[len(agent_history):]
|
||||
if len(agent_messages) > len(agent_history)
|
||||
else [])
|
||||
assert len(fixed_new) == 2
|
||||
assert fixed_new[0]["content"] == "New question"
|
||||
|
||||
def test_else_branch_returns_empty_list(self):
|
||||
"""When agent has fewer messages than offset, return [] not all.
|
||||
|
||||
The old code had ``else agent_messages`` which would treat the
|
||||
entire message list as new when the agent compressed or dropped
|
||||
messages. The fix changes this to ``else []``, falling through
|
||||
to the simple user/assistant fallback path.
|
||||
"""
|
||||
history = [
|
||||
{"role": "session_meta", "tools": [], "timestamp": "t0"},
|
||||
{"role": "user", "content": "Hello", "timestamp": "t1"},
|
||||
{"role": "assistant", "content": "Hi!", "timestamp": "t1"},
|
||||
]
|
||||
|
||||
# Agent compressed and returned fewer messages than history
|
||||
agent_messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi!"},
|
||||
]
|
||||
|
||||
history_offset = len(_filter_history(history)) # 2
|
||||
new_messages = (agent_messages[history_offset:]
|
||||
if len(agent_messages) > history_offset
|
||||
else [])
|
||||
# 2 == 2, so no new messages - falls to fallback
|
||||
assert new_messages == []
|
||||
|
||||
def test_tool_call_messages_preserved_in_filter(self):
|
||||
"""Tool call messages pass through the filter, keeping offset correct."""
|
||||
history = [
|
||||
{"role": "session_meta", "tools": [], "timestamp": "t0"},
|
||||
{"role": "user", "content": "Search for cats", "timestamp": "t1"},
|
||||
{"role": "assistant", "content": None, "timestamp": "t1",
|
||||
"tool_calls": [{"id": "tc1", "function": {"name": "web_search"}}]},
|
||||
{"role": "tool", "tool_call_id": "tc1",
|
||||
"content": "Results about cats", "timestamp": "t1"},
|
||||
{"role": "assistant", "content": "Here are results.",
|
||||
"timestamp": "t1"},
|
||||
]
|
||||
|
||||
agent_history = _filter_history(history)
|
||||
# session_meta filtered, but tool_calls/tool messages kept
|
||||
assert len(agent_history) == 4
|
||||
assert len(history) == 5 # 1 session_meta extra
|
||||
|
||||
agent_messages = [
|
||||
{"role": "user", "content": "Search for cats"},
|
||||
{"role": "assistant", "content": None,
|
||||
"tool_calls": [{"id": "tc1", "function": {"name": "web_search"}}]},
|
||||
{"role": "tool", "tool_call_id": "tc1", "content": "Results about cats"},
|
||||
{"role": "assistant", "content": "Here are results."},
|
||||
{"role": "user", "content": "Now search for dogs"},
|
||||
{"role": "assistant", "content": "Dog results here."},
|
||||
]
|
||||
|
||||
# OLD: len(history) = 5, agent_messages[5:] = 1 message (lost user msg)
|
||||
old_new = (agent_messages[len(history):]
|
||||
if len(agent_messages) > len(history)
|
||||
else agent_messages)
|
||||
assert len(old_new) == 1 # BUG
|
||||
|
||||
# FIXED
|
||||
fixed_new = (agent_messages[len(agent_history):]
|
||||
if len(agent_messages) > len(agent_history)
|
||||
else [])
|
||||
assert len(fixed_new) == 2
|
||||
assert fixed_new[0]["content"] == "Now search for dogs"
|
||||
assert fixed_new[1]["content"] == "Dog results here."
|
||||
137
hermes_code/tests/gateway/test_unauthorized_dm_behavior.py
Normal file
137
hermes_code/tests/gateway/test_unauthorized_dm_behavior.py
Normal file
|
|
@ -0,0 +1,137 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import MessageEvent
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
def _clear_auth_env(monkeypatch) -> None:
|
||||
for key in (
|
||||
"TELEGRAM_ALLOWED_USERS",
|
||||
"DISCORD_ALLOWED_USERS",
|
||||
"WHATSAPP_ALLOWED_USERS",
|
||||
"SLACK_ALLOWED_USERS",
|
||||
"SIGNAL_ALLOWED_USERS",
|
||||
"EMAIL_ALLOWED_USERS",
|
||||
"SMS_ALLOWED_USERS",
|
||||
"MATTERMOST_ALLOWED_USERS",
|
||||
"MATRIX_ALLOWED_USERS",
|
||||
"DINGTALK_ALLOWED_USERS",
|
||||
"GATEWAY_ALLOWED_USERS",
|
||||
"TELEGRAM_ALLOW_ALL_USERS",
|
||||
"DISCORD_ALLOW_ALL_USERS",
|
||||
"WHATSAPP_ALLOW_ALL_USERS",
|
||||
"SLACK_ALLOW_ALL_USERS",
|
||||
"SIGNAL_ALLOW_ALL_USERS",
|
||||
"EMAIL_ALLOW_ALL_USERS",
|
||||
"SMS_ALLOW_ALL_USERS",
|
||||
"MATTERMOST_ALLOW_ALL_USERS",
|
||||
"MATRIX_ALLOW_ALL_USERS",
|
||||
"DINGTALK_ALLOW_ALL_USERS",
|
||||
"GATEWAY_ALLOW_ALL_USERS",
|
||||
):
|
||||
monkeypatch.delenv(key, raising=False)
|
||||
|
||||
|
||||
def _make_event(platform: Platform, user_id: str, chat_id: str) -> MessageEvent:
|
||||
return MessageEvent(
|
||||
text="hello",
|
||||
message_id="m1",
|
||||
source=SessionSource(
|
||||
platform=platform,
|
||||
user_id=user_id,
|
||||
chat_id=chat_id,
|
||||
user_name="tester",
|
||||
chat_type="dm",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _make_runner(platform: Platform, config: GatewayConfig):
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.config = config
|
||||
adapter = SimpleNamespace(send=AsyncMock())
|
||||
runner.adapters = {platform: adapter}
|
||||
runner.pairing_store = MagicMock()
|
||||
runner.pairing_store.is_approved.return_value = False
|
||||
return runner, adapter
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unauthorized_dm_pairs_by_default(monkeypatch):
|
||||
_clear_auth_env(monkeypatch)
|
||||
config = GatewayConfig(
|
||||
platforms={Platform.WHATSAPP: PlatformConfig(enabled=True)},
|
||||
)
|
||||
runner, adapter = _make_runner(Platform.WHATSAPP, config)
|
||||
runner.pairing_store.generate_code.return_value = "ABC12DEF"
|
||||
|
||||
result = await runner._handle_message(
|
||||
_make_event(
|
||||
Platform.WHATSAPP,
|
||||
"15551234567@s.whatsapp.net",
|
||||
"15551234567@s.whatsapp.net",
|
||||
)
|
||||
)
|
||||
|
||||
assert result is None
|
||||
runner.pairing_store.generate_code.assert_called_once_with(
|
||||
"whatsapp",
|
||||
"15551234567@s.whatsapp.net",
|
||||
"tester",
|
||||
)
|
||||
adapter.send.assert_awaited_once()
|
||||
assert "ABC12DEF" in adapter.send.await_args.args[1]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unauthorized_whatsapp_dm_can_be_ignored(monkeypatch):
|
||||
_clear_auth_env(monkeypatch)
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.WHATSAPP: PlatformConfig(
|
||||
enabled=True,
|
||||
extra={"unauthorized_dm_behavior": "ignore"},
|
||||
),
|
||||
},
|
||||
)
|
||||
runner, adapter = _make_runner(Platform.WHATSAPP, config)
|
||||
|
||||
result = await runner._handle_message(
|
||||
_make_event(
|
||||
Platform.WHATSAPP,
|
||||
"15551234567@s.whatsapp.net",
|
||||
"15551234567@s.whatsapp.net",
|
||||
)
|
||||
)
|
||||
|
||||
assert result is None
|
||||
runner.pairing_store.generate_code.assert_not_called()
|
||||
adapter.send.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_ignore_suppresses_pairing_reply(monkeypatch):
|
||||
_clear_auth_env(monkeypatch)
|
||||
config = GatewayConfig(
|
||||
unauthorized_dm_behavior="ignore",
|
||||
platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")},
|
||||
)
|
||||
runner, adapter = _make_runner(Platform.TELEGRAM, config)
|
||||
|
||||
result = await runner._handle_message(
|
||||
_make_event(
|
||||
Platform.TELEGRAM,
|
||||
"12345",
|
||||
"12345",
|
||||
)
|
||||
)
|
||||
|
||||
assert result is None
|
||||
runner.pairing_store.generate_code.assert_not_called()
|
||||
adapter.send.assert_not_awaited()
|
||||
637
hermes_code/tests/gateway/test_update_command.py
Normal file
637
hermes_code/tests/gateway/test_update_command.py
Normal file
|
|
@ -0,0 +1,637 @@
|
|||
"""Tests for /update gateway slash command.
|
||||
|
||||
Tests both the _handle_update_command handler (spawns update process) and
|
||||
the _send_update_notification startup hook (sends results after restart).
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform
|
||||
from gateway.platforms.base import MessageEvent
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
def _make_event(text="/update", platform=Platform.TELEGRAM,
|
||||
user_id="12345", chat_id="67890"):
|
||||
"""Build a MessageEvent for testing."""
|
||||
source = SessionSource(
|
||||
platform=platform,
|
||||
user_id=user_id,
|
||||
chat_id=chat_id,
|
||||
user_name="testuser",
|
||||
)
|
||||
return MessageEvent(text=text, source=source)
|
||||
|
||||
|
||||
def _make_runner():
|
||||
"""Create a bare GatewayRunner without calling __init__."""
|
||||
from gateway.run import GatewayRunner
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.adapters = {}
|
||||
runner._voice_mode = {}
|
||||
return runner
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _handle_update_command
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHandleUpdateCommand:
|
||||
"""Tests for GatewayRunner._handle_update_command."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_git_directory(self, tmp_path):
|
||||
"""Returns an error when .git does not exist."""
|
||||
runner = _make_runner()
|
||||
event = _make_event()
|
||||
# Point _hermes_home to tmp_path and project_root to a dir without .git
|
||||
fake_root = tmp_path / "project"
|
||||
fake_root.mkdir()
|
||||
with patch("gateway.run._hermes_home", tmp_path), \
|
||||
patch("gateway.run.Path") as MockPath:
|
||||
# Path(__file__).parent.parent.resolve() -> fake_root
|
||||
MockPath.return_value = MagicMock()
|
||||
MockPath.__truediv__ = Path.__truediv__
|
||||
# Easier: just patch the __file__ resolution in the method
|
||||
pass
|
||||
|
||||
# Simpler approach — mock at method level using a wrapper
|
||||
from gateway.run import GatewayRunner
|
||||
runner = _make_runner()
|
||||
|
||||
with patch("gateway.run._hermes_home", tmp_path):
|
||||
# The handler does Path(__file__).parent.parent.resolve()
|
||||
# We need to make project_root / '.git' not exist.
|
||||
# Since Path(__file__) resolves to the real gateway/run.py,
|
||||
# project_root will be the real hermes-agent dir (which HAS .git).
|
||||
# Patch Path to control this.
|
||||
original_path = Path
|
||||
|
||||
class FakePath(type(Path())):
|
||||
pass
|
||||
|
||||
# Actually, simplest: just patch the specific file attr
|
||||
fake_file = str(fake_root / "gateway" / "run.py")
|
||||
(fake_root / "gateway").mkdir(parents=True)
|
||||
(fake_root / "gateway" / "run.py").touch()
|
||||
|
||||
with patch("gateway.run.__file__", fake_file):
|
||||
result = await runner._handle_update_command(event)
|
||||
|
||||
assert "Not a git repository" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_hermes_binary(self, tmp_path):
|
||||
"""Returns error when hermes is not on PATH and hermes_cli is not importable."""
|
||||
runner = _make_runner()
|
||||
event = _make_event()
|
||||
|
||||
# Create project dir WITH .git
|
||||
fake_root = tmp_path / "project"
|
||||
fake_root.mkdir()
|
||||
(fake_root / ".git").mkdir()
|
||||
(fake_root / "gateway").mkdir()
|
||||
(fake_root / "gateway" / "run.py").touch()
|
||||
fake_file = str(fake_root / "gateway" / "run.py")
|
||||
|
||||
with patch("gateway.run._hermes_home", tmp_path), \
|
||||
patch("gateway.run.__file__", fake_file), \
|
||||
patch("shutil.which", return_value=None), \
|
||||
patch("importlib.util.find_spec", return_value=None):
|
||||
result = await runner._handle_update_command(event)
|
||||
|
||||
assert "Could not locate" in result
|
||||
assert "hermes update" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_to_sys_executable(self, tmp_path):
|
||||
"""Falls back to sys.executable -m hermes_cli.main when hermes not on PATH."""
|
||||
runner = _make_runner()
|
||||
event = _make_event()
|
||||
|
||||
fake_root = tmp_path / "project"
|
||||
fake_root.mkdir()
|
||||
(fake_root / ".git").mkdir()
|
||||
(fake_root / "gateway").mkdir()
|
||||
(fake_root / "gateway" / "run.py").touch()
|
||||
fake_file = str(fake_root / "gateway" / "run.py")
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
|
||||
mock_popen = MagicMock()
|
||||
fake_spec = MagicMock()
|
||||
|
||||
with patch("gateway.run._hermes_home", hermes_home), \
|
||||
patch("gateway.run.__file__", fake_file), \
|
||||
patch("shutil.which", return_value=None), \
|
||||
patch("importlib.util.find_spec", return_value=fake_spec), \
|
||||
patch("subprocess.Popen", mock_popen):
|
||||
result = await runner._handle_update_command(event)
|
||||
|
||||
assert "Starting Hermes update" in result
|
||||
call_args = mock_popen.call_args[0][0]
|
||||
# The update_cmd uses sys.executable -m hermes_cli.main
|
||||
joined = " ".join(call_args) if isinstance(call_args, list) else call_args
|
||||
assert "hermes_cli.main" in joined or "bash" in call_args[0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_hermes_bin_prefers_which(self, tmp_path):
|
||||
"""_resolve_hermes_bin returns argv parts from shutil.which when available."""
|
||||
from gateway.run import _resolve_hermes_bin
|
||||
|
||||
with patch("shutil.which", return_value="/custom/path/hermes"):
|
||||
result = _resolve_hermes_bin()
|
||||
|
||||
assert result == ["/custom/path/hermes"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_hermes_bin_fallback(self):
|
||||
"""_resolve_hermes_bin falls back to sys.executable argv when which fails."""
|
||||
import sys
|
||||
from gateway.run import _resolve_hermes_bin
|
||||
|
||||
fake_spec = MagicMock()
|
||||
with patch("shutil.which", return_value=None), \
|
||||
patch("importlib.util.find_spec", return_value=fake_spec):
|
||||
result = _resolve_hermes_bin()
|
||||
|
||||
assert result == [sys.executable, "-m", "hermes_cli.main"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_hermes_bin_returns_none_when_both_fail(self):
|
||||
"""_resolve_hermes_bin returns None when both strategies fail."""
|
||||
from gateway.run import _resolve_hermes_bin
|
||||
|
||||
with patch("shutil.which", return_value=None), \
|
||||
patch("importlib.util.find_spec", return_value=None):
|
||||
result = _resolve_hermes_bin()
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_writes_pending_marker(self, tmp_path):
|
||||
"""Writes .update_pending.json with correct platform and chat info."""
|
||||
runner = _make_runner()
|
||||
event = _make_event(platform=Platform.TELEGRAM, chat_id="99999")
|
||||
|
||||
fake_root = tmp_path / "project"
|
||||
fake_root.mkdir()
|
||||
(fake_root / ".git").mkdir()
|
||||
(fake_root / "gateway").mkdir()
|
||||
(fake_root / "gateway" / "run.py").touch()
|
||||
fake_file = str(fake_root / "gateway" / "run.py")
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
|
||||
with patch("gateway.run._hermes_home", hermes_home), \
|
||||
patch("gateway.run.__file__", fake_file), \
|
||||
patch("shutil.which", side_effect=lambda x: "/usr/bin/hermes" if x == "hermes" else "/usr/bin/systemd-run"), \
|
||||
patch("subprocess.Popen"):
|
||||
result = await runner._handle_update_command(event)
|
||||
|
||||
pending_path = hermes_home / ".update_pending.json"
|
||||
assert pending_path.exists()
|
||||
data = json.loads(pending_path.read_text())
|
||||
assert data["platform"] == "telegram"
|
||||
assert data["chat_id"] == "99999"
|
||||
assert "timestamp" in data
|
||||
assert not (hermes_home / ".update_exit_code").exists()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_spawns_systemd_run(self, tmp_path):
|
||||
"""Uses systemd-run when available."""
|
||||
runner = _make_runner()
|
||||
event = _make_event()
|
||||
|
||||
fake_root = tmp_path / "project"
|
||||
fake_root.mkdir()
|
||||
(fake_root / ".git").mkdir()
|
||||
(fake_root / "gateway").mkdir()
|
||||
(fake_root / "gateway" / "run.py").touch()
|
||||
fake_file = str(fake_root / "gateway" / "run.py")
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
|
||||
mock_popen = MagicMock()
|
||||
with patch("gateway.run._hermes_home", hermes_home), \
|
||||
patch("gateway.run.__file__", fake_file), \
|
||||
patch("shutil.which", side_effect=lambda x: f"/usr/bin/{x}"), \
|
||||
patch("subprocess.Popen", mock_popen):
|
||||
result = await runner._handle_update_command(event)
|
||||
|
||||
# Verify systemd-run was used
|
||||
call_args = mock_popen.call_args[0][0]
|
||||
assert call_args[0] == "/usr/bin/systemd-run"
|
||||
assert "--scope" in call_args
|
||||
assert ".update_exit_code" in call_args[-1]
|
||||
assert "Starting Hermes update" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_nohup_when_no_systemd_run(self, tmp_path):
|
||||
"""Falls back to nohup when systemd-run is not available."""
|
||||
runner = _make_runner()
|
||||
event = _make_event()
|
||||
|
||||
fake_root = tmp_path / "project"
|
||||
fake_root.mkdir()
|
||||
(fake_root / ".git").mkdir()
|
||||
(fake_root / "gateway").mkdir()
|
||||
(fake_root / "gateway" / "run.py").touch()
|
||||
fake_file = str(fake_root / "gateway" / "run.py")
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
|
||||
mock_popen = MagicMock()
|
||||
|
||||
def which_no_systemd(x):
|
||||
if x == "hermes":
|
||||
return "/usr/bin/hermes"
|
||||
if x == "systemd-run":
|
||||
return None
|
||||
return None
|
||||
|
||||
with patch("gateway.run._hermes_home", hermes_home), \
|
||||
patch("gateway.run.__file__", fake_file), \
|
||||
patch("shutil.which", side_effect=which_no_systemd), \
|
||||
patch("subprocess.Popen", mock_popen):
|
||||
result = await runner._handle_update_command(event)
|
||||
|
||||
# Verify bash -c nohup fallback was used
|
||||
call_args = mock_popen.call_args[0][0]
|
||||
assert call_args[0] == "bash"
|
||||
assert "nohup" in call_args[2]
|
||||
assert ".update_exit_code" in call_args[2]
|
||||
assert "Starting Hermes update" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_popen_failure_cleans_up(self, tmp_path):
|
||||
"""Cleans up pending file and returns error on Popen failure."""
|
||||
runner = _make_runner()
|
||||
event = _make_event()
|
||||
|
||||
fake_root = tmp_path / "project"
|
||||
fake_root.mkdir()
|
||||
(fake_root / ".git").mkdir()
|
||||
(fake_root / "gateway").mkdir()
|
||||
(fake_root / "gateway" / "run.py").touch()
|
||||
fake_file = str(fake_root / "gateway" / "run.py")
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
|
||||
with patch("gateway.run._hermes_home", hermes_home), \
|
||||
patch("gateway.run.__file__", fake_file), \
|
||||
patch("shutil.which", side_effect=lambda x: f"/usr/bin/{x}"), \
|
||||
patch("subprocess.Popen", side_effect=OSError("spawn failed")):
|
||||
result = await runner._handle_update_command(event)
|
||||
|
||||
assert "Failed to start update" in result
|
||||
# Pending file should be cleaned up
|
||||
assert not (hermes_home / ".update_pending.json").exists()
|
||||
assert not (hermes_home / ".update_exit_code").exists()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_user_friendly_message(self, tmp_path):
|
||||
"""The success response is user-friendly."""
|
||||
runner = _make_runner()
|
||||
event = _make_event()
|
||||
|
||||
fake_root = tmp_path / "project"
|
||||
fake_root.mkdir()
|
||||
(fake_root / ".git").mkdir()
|
||||
(fake_root / "gateway").mkdir()
|
||||
(fake_root / "gateway" / "run.py").touch()
|
||||
fake_file = str(fake_root / "gateway" / "run.py")
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
|
||||
with patch("gateway.run._hermes_home", hermes_home), \
|
||||
patch("gateway.run.__file__", fake_file), \
|
||||
patch("shutil.which", side_effect=lambda x: f"/usr/bin/{x}"), \
|
||||
patch("subprocess.Popen"):
|
||||
result = await runner._handle_update_command(event)
|
||||
|
||||
assert "notify you when it's done" in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _send_update_notification
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSendUpdateNotification:
|
||||
"""Tests for GatewayRunner._send_update_notification."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_pending_file_is_noop(self, tmp_path):
|
||||
"""Does nothing when no pending file exists."""
|
||||
runner = _make_runner()
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
|
||||
with patch("gateway.run._hermes_home", hermes_home):
|
||||
# Should not raise
|
||||
await runner._send_update_notification()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_defers_notification_while_update_still_running(self, tmp_path):
|
||||
"""Returns False and keeps marker files when the update has not exited yet."""
|
||||
runner = _make_runner()
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
|
||||
pending_path = hermes_home / ".update_pending.json"
|
||||
pending_path.write_text(json.dumps({
|
||||
"platform": "telegram", "chat_id": "67890", "user_id": "12345",
|
||||
}))
|
||||
(hermes_home / ".update_output.txt").write_text("still running")
|
||||
|
||||
mock_adapter = AsyncMock()
|
||||
runner.adapters = {Platform.TELEGRAM: mock_adapter}
|
||||
|
||||
with patch("gateway.run._hermes_home", hermes_home):
|
||||
result = await runner._send_update_notification()
|
||||
|
||||
assert result is False
|
||||
mock_adapter.send.assert_not_called()
|
||||
assert pending_path.exists()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recovers_from_claimed_pending_file(self, tmp_path):
|
||||
"""A claimed pending file from a crashed notifier is still deliverable."""
|
||||
runner = _make_runner()
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
|
||||
claimed_path = hermes_home / ".update_pending.claimed.json"
|
||||
claimed_path.write_text(json.dumps({
|
||||
"platform": "telegram", "chat_id": "67890", "user_id": "12345",
|
||||
}))
|
||||
(hermes_home / ".update_output.txt").write_text("done")
|
||||
(hermes_home / ".update_exit_code").write_text("0")
|
||||
|
||||
mock_adapter = AsyncMock()
|
||||
runner.adapters = {Platform.TELEGRAM: mock_adapter}
|
||||
|
||||
with patch("gateway.run._hermes_home", hermes_home):
|
||||
result = await runner._send_update_notification()
|
||||
|
||||
assert result is True
|
||||
mock_adapter.send.assert_called_once()
|
||||
assert not claimed_path.exists()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sends_notification_with_output(self, tmp_path):
|
||||
"""Sends update output to the correct platform and chat."""
|
||||
runner = _make_runner()
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
|
||||
# Write pending marker
|
||||
pending = {
|
||||
"platform": "telegram",
|
||||
"chat_id": "67890",
|
||||
"user_id": "12345",
|
||||
"timestamp": "2026-03-04T21:00:00",
|
||||
}
|
||||
(hermes_home / ".update_pending.json").write_text(json.dumps(pending))
|
||||
(hermes_home / ".update_output.txt").write_text(
|
||||
"→ Found 3 new commit(s)\n✓ Code updated!\n✓ Update complete!"
|
||||
)
|
||||
(hermes_home / ".update_exit_code").write_text("0")
|
||||
|
||||
# Mock the adapter
|
||||
mock_adapter = AsyncMock()
|
||||
mock_adapter.send = AsyncMock()
|
||||
runner.adapters = {Platform.TELEGRAM: mock_adapter}
|
||||
|
||||
with patch("gateway.run._hermes_home", hermes_home):
|
||||
await runner._send_update_notification()
|
||||
|
||||
mock_adapter.send.assert_called_once()
|
||||
call_args = mock_adapter.send.call_args
|
||||
assert call_args[0][0] == "67890" # chat_id
|
||||
assert "Update complete" in call_args[0][1] or "update finished" in call_args[0][1].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_strips_ansi_codes(self, tmp_path):
|
||||
"""ANSI escape codes are removed from output."""
|
||||
runner = _make_runner()
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
|
||||
pending = {"platform": "telegram", "chat_id": "111", "user_id": "222"}
|
||||
(hermes_home / ".update_pending.json").write_text(json.dumps(pending))
|
||||
(hermes_home / ".update_output.txt").write_text(
|
||||
"\x1b[32m✓ Code updated!\x1b[0m\n\x1b[1mDone\x1b[0m"
|
||||
)
|
||||
(hermes_home / ".update_exit_code").write_text("0")
|
||||
|
||||
mock_adapter = AsyncMock()
|
||||
runner.adapters = {Platform.TELEGRAM: mock_adapter}
|
||||
|
||||
with patch("gateway.run._hermes_home", hermes_home):
|
||||
await runner._send_update_notification()
|
||||
|
||||
sent_text = mock_adapter.send.call_args[0][1]
|
||||
assert "\x1b[" not in sent_text
|
||||
assert "Code updated" in sent_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_truncates_long_output(self, tmp_path):
|
||||
"""Output longer than 3500 chars is truncated."""
|
||||
runner = _make_runner()
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
|
||||
pending = {"platform": "telegram", "chat_id": "111", "user_id": "222"}
|
||||
(hermes_home / ".update_pending.json").write_text(json.dumps(pending))
|
||||
(hermes_home / ".update_output.txt").write_text("x" * 5000)
|
||||
(hermes_home / ".update_exit_code").write_text("0")
|
||||
|
||||
mock_adapter = AsyncMock()
|
||||
runner.adapters = {Platform.TELEGRAM: mock_adapter}
|
||||
|
||||
with patch("gateway.run._hermes_home", hermes_home):
|
||||
await runner._send_update_notification()
|
||||
|
||||
sent_text = mock_adapter.send.call_args[0][1]
|
||||
# Should start with truncation marker
|
||||
assert "…" in sent_text
|
||||
# Total message should not be absurdly long
|
||||
assert len(sent_text) < 4500
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sends_failure_message_when_update_fails(self, tmp_path):
|
||||
"""Non-zero exit codes produce a failure notification with captured output."""
|
||||
runner = _make_runner()
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
|
||||
pending = {"platform": "telegram", "chat_id": "111", "user_id": "222"}
|
||||
(hermes_home / ".update_pending.json").write_text(json.dumps(pending))
|
||||
(hermes_home / ".update_output.txt").write_text("Traceback: boom")
|
||||
(hermes_home / ".update_exit_code").write_text("1")
|
||||
|
||||
mock_adapter = AsyncMock()
|
||||
runner.adapters = {Platform.TELEGRAM: mock_adapter}
|
||||
|
||||
with patch("gateway.run._hermes_home", hermes_home):
|
||||
result = await runner._send_update_notification()
|
||||
|
||||
assert result is True
|
||||
sent_text = mock_adapter.send.call_args[0][1]
|
||||
assert "update failed" in sent_text.lower()
|
||||
assert "Traceback: boom" in sent_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sends_generic_message_when_no_output(self, tmp_path):
|
||||
"""Sends a success message even if the output file is missing."""
|
||||
runner = _make_runner()
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
|
||||
pending = {"platform": "telegram", "chat_id": "111", "user_id": "222"}
|
||||
(hermes_home / ".update_pending.json").write_text(json.dumps(pending))
|
||||
# No .update_output.txt created
|
||||
(hermes_home / ".update_exit_code").write_text("0")
|
||||
|
||||
mock_adapter = AsyncMock()
|
||||
runner.adapters = {Platform.TELEGRAM: mock_adapter}
|
||||
|
||||
with patch("gateway.run._hermes_home", hermes_home):
|
||||
await runner._send_update_notification()
|
||||
|
||||
sent_text = mock_adapter.send.call_args[0][1]
|
||||
assert "finished successfully" in sent_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleans_up_files_after_notification(self, tmp_path):
|
||||
"""Both marker and output files are deleted after notification."""
|
||||
runner = _make_runner()
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
|
||||
pending_path = hermes_home / ".update_pending.json"
|
||||
output_path = hermes_home / ".update_output.txt"
|
||||
exit_code_path = hermes_home / ".update_exit_code"
|
||||
pending_path.write_text(json.dumps({
|
||||
"platform": "telegram", "chat_id": "111", "user_id": "222",
|
||||
}))
|
||||
output_path.write_text("✓ Done")
|
||||
exit_code_path.write_text("0")
|
||||
|
||||
mock_adapter = AsyncMock()
|
||||
runner.adapters = {Platform.TELEGRAM: mock_adapter}
|
||||
|
||||
with patch("gateway.run._hermes_home", hermes_home):
|
||||
await runner._send_update_notification()
|
||||
|
||||
assert not pending_path.exists()
|
||||
assert not output_path.exists()
|
||||
assert not exit_code_path.exists()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleans_up_on_error(self, tmp_path):
|
||||
"""Files are cleaned up even if notification fails."""
|
||||
runner = _make_runner()
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
|
||||
pending_path = hermes_home / ".update_pending.json"
|
||||
output_path = hermes_home / ".update_output.txt"
|
||||
exit_code_path = hermes_home / ".update_exit_code"
|
||||
pending_path.write_text(json.dumps({
|
||||
"platform": "telegram", "chat_id": "111", "user_id": "222",
|
||||
}))
|
||||
output_path.write_text("✓ Done")
|
||||
exit_code_path.write_text("0")
|
||||
|
||||
# Adapter send raises
|
||||
mock_adapter = AsyncMock()
|
||||
mock_adapter.send.side_effect = RuntimeError("network error")
|
||||
runner.adapters = {Platform.TELEGRAM: mock_adapter}
|
||||
|
||||
with patch("gateway.run._hermes_home", hermes_home):
|
||||
await runner._send_update_notification()
|
||||
|
||||
# Files should still be cleaned up (finally block)
|
||||
assert not pending_path.exists()
|
||||
assert not output_path.exists()
|
||||
assert not exit_code_path.exists()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_corrupt_pending_file(self, tmp_path):
|
||||
"""Gracefully handles a malformed pending JSON file."""
|
||||
runner = _make_runner()
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
|
||||
pending_path = hermes_home / ".update_pending.json"
|
||||
pending_path.write_text("{corrupt json!!")
|
||||
|
||||
with patch("gateway.run._hermes_home", hermes_home):
|
||||
# Should not raise
|
||||
await runner._send_update_notification()
|
||||
|
||||
# File should be cleaned up
|
||||
assert not pending_path.exists()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_adapter_for_platform(self, tmp_path):
|
||||
"""Does not crash if the platform adapter is not connected."""
|
||||
runner = _make_runner()
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
|
||||
pending = {"platform": "discord", "chat_id": "111", "user_id": "222"}
|
||||
pending_path = hermes_home / ".update_pending.json"
|
||||
output_path = hermes_home / ".update_output.txt"
|
||||
exit_code_path = hermes_home / ".update_exit_code"
|
||||
pending_path.write_text(json.dumps(pending))
|
||||
output_path.write_text("Done")
|
||||
exit_code_path.write_text("0")
|
||||
|
||||
# Only telegram adapter available, but pending says discord
|
||||
mock_adapter = AsyncMock()
|
||||
runner.adapters = {Platform.TELEGRAM: mock_adapter}
|
||||
|
||||
with patch("gateway.run._hermes_home", hermes_home):
|
||||
await runner._send_update_notification()
|
||||
|
||||
# send should not have been called (wrong platform)
|
||||
mock_adapter.send.assert_not_called()
|
||||
# Files should still be cleaned up
|
||||
assert not pending_path.exists()
|
||||
assert not exit_code_path.exists()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# /update in help and known_commands
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestUpdateInHelp:
|
||||
"""Verify /update appears in help text and known commands set."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_in_help_output(self):
|
||||
"""The /help output includes /update."""
|
||||
runner = _make_runner()
|
||||
event = _make_event(text="/help")
|
||||
result = await runner._handle_help_command(event)
|
||||
assert "/update" in result
|
||||
|
||||
def test_update_is_known_command(self):
|
||||
"""The /update command is in the help text (proxy for _known_commands)."""
|
||||
# _known_commands is local to _handle_message, so we verify by
|
||||
# checking the help output includes it.
|
||||
from gateway.run import GatewayRunner
|
||||
import inspect
|
||||
source = inspect.getsource(GatewayRunner._handle_message)
|
||||
assert '"update"' in source
|
||||
2632
hermes_code/tests/gateway/test_voice_command.py
Normal file
2632
hermes_code/tests/gateway/test_voice_command.py
Normal file
File diff suppressed because it is too large
Load diff
619
hermes_code/tests/gateway/test_webhook_adapter.py
Normal file
619
hermes_code/tests/gateway/test_webhook_adapter.py
Normal file
|
|
@ -0,0 +1,619 @@
|
|||
"""Unit tests for the generic webhook platform adapter.
|
||||
|
||||
Covers:
|
||||
- HMAC signature validation (GitHub, GitLab, generic)
|
||||
- Prompt rendering with dot-notation template variables
|
||||
- Event type filtering
|
||||
- HTTP handler behaviour (404, 202, health)
|
||||
- Idempotency cache (duplicate delivery IDs)
|
||||
- Rate limiting (fixed-window, per route)
|
||||
- Body size limits
|
||||
- INSECURE_NO_AUTH bypass
|
||||
- Session isolation for concurrent webhooks
|
||||
- Delivery info cleanup after send()
|
||||
- connect / disconnect lifecycle
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
from aiohttp.test_utils import TestClient, TestServer
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import MessageEvent, MessageType, SendResult
|
||||
from gateway.platforms.webhook import (
|
||||
WebhookAdapter,
|
||||
_INSECURE_NO_AUTH,
|
||||
check_webhook_requirements,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_config(
|
||||
routes=None,
|
||||
secret="",
|
||||
rate_limit=30,
|
||||
max_body_bytes=1_048_576,
|
||||
host="0.0.0.0",
|
||||
port=0, # let OS pick a free port in tests
|
||||
):
|
||||
"""Build a PlatformConfig suitable for WebhookAdapter."""
|
||||
extra = {
|
||||
"host": host,
|
||||
"port": port,
|
||||
"routes": routes or {},
|
||||
"rate_limit": rate_limit,
|
||||
"max_body_bytes": max_body_bytes,
|
||||
}
|
||||
if secret:
|
||||
extra["secret"] = secret
|
||||
return PlatformConfig(enabled=True, extra=extra)
|
||||
|
||||
|
||||
def _make_adapter(routes=None, **kwargs):
|
||||
"""Create a WebhookAdapter with sensible defaults for testing."""
|
||||
config = _make_config(routes=routes, **kwargs)
|
||||
return WebhookAdapter(config)
|
||||
|
||||
|
||||
def _create_app(adapter: WebhookAdapter) -> web.Application:
|
||||
"""Build the aiohttp Application from the adapter (without starting a full server)."""
|
||||
app = web.Application()
|
||||
app.router.add_get("/health", adapter._handle_health)
|
||||
app.router.add_post("/webhooks/{route_name}", adapter._handle_webhook)
|
||||
return app
|
||||
|
||||
|
||||
def _mock_request(headers=None, body=b"", content_length=None, match_info=None):
|
||||
"""Build a lightweight mock aiohttp request for non-HTTP tests."""
|
||||
req = MagicMock()
|
||||
req.headers = headers or {}
|
||||
req.content_length = content_length if content_length is not None else len(body)
|
||||
req.match_info = match_info or {}
|
||||
req.method = "POST"
|
||||
|
||||
async def _read():
|
||||
return body
|
||||
|
||||
req.read = _read
|
||||
return req
|
||||
|
||||
|
||||
def _github_signature(body: bytes, secret: str) -> str:
|
||||
"""Compute X-Hub-Signature-256 for *body* using *secret*."""
|
||||
return "sha256=" + hmac.new(
|
||||
secret.encode(), body, hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
|
||||
def _generic_signature(body: bytes, secret: str) -> str:
|
||||
"""Compute X-Webhook-Signature (plain HMAC-SHA256 hex) for *body*."""
|
||||
return hmac.new(secret.encode(), body, hashlib.sha256).hexdigest()
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Signature validation
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestValidateSignature:
|
||||
"""Tests for WebhookAdapter._validate_signature."""
|
||||
|
||||
def test_validate_github_signature_valid(self):
|
||||
"""Valid X-Hub-Signature-256 is accepted."""
|
||||
adapter = _make_adapter()
|
||||
body = b'{"action": "opened"}'
|
||||
secret = "webhook-secret-42"
|
||||
sig = _github_signature(body, secret)
|
||||
req = _mock_request(headers={"X-Hub-Signature-256": sig})
|
||||
assert adapter._validate_signature(req, body, secret) is True
|
||||
|
||||
def test_validate_github_signature_invalid(self):
|
||||
"""Wrong X-Hub-Signature-256 is rejected."""
|
||||
adapter = _make_adapter()
|
||||
body = b'{"action": "opened"}'
|
||||
secret = "webhook-secret-42"
|
||||
req = _mock_request(headers={"X-Hub-Signature-256": "sha256=deadbeef"})
|
||||
assert adapter._validate_signature(req, body, secret) is False
|
||||
|
||||
def test_validate_gitlab_token(self):
|
||||
"""GitLab plain-token match via X-Gitlab-Token."""
|
||||
adapter = _make_adapter()
|
||||
secret = "gl-token-value"
|
||||
req = _mock_request(headers={"X-Gitlab-Token": secret})
|
||||
assert adapter._validate_signature(req, b"{}", secret) is True
|
||||
|
||||
def test_validate_gitlab_token_wrong(self):
|
||||
"""Wrong X-Gitlab-Token is rejected."""
|
||||
adapter = _make_adapter()
|
||||
req = _mock_request(headers={"X-Gitlab-Token": "wrong"})
|
||||
assert adapter._validate_signature(req, b"{}", "correct") is False
|
||||
|
||||
def test_validate_no_signature_with_secret_rejects(self):
|
||||
"""Secret configured but no recognised signature header → reject."""
|
||||
adapter = _make_adapter()
|
||||
req = _mock_request(headers={}) # no sig headers at all
|
||||
assert adapter._validate_signature(req, b"{}", "my-secret") is False
|
||||
|
||||
def test_validate_no_secret_allows_all(self):
|
||||
"""When the secret is empty/falsy, the validator is never even called
|
||||
by the handler (secret check is 'if secret and secret != _INSECURE...').
|
||||
Verify that an empty secret isn't accidentally passed to the validator."""
|
||||
# This tests the semantics: empty secret means skip validation entirely.
|
||||
# The handler code does: if secret and secret != _INSECURE_NO_AUTH: validate
|
||||
# So with an empty secret, _validate_signature is never reached.
|
||||
# We just verify the code path is correct by constructing an adapter
|
||||
# with no secret and confirming the route config resolves to "".
|
||||
adapter = _make_adapter(
|
||||
routes={"test": {"prompt": "hello"}},
|
||||
secret="",
|
||||
)
|
||||
# The route has no secret, global secret is empty
|
||||
route_secret = adapter._routes["test"].get("secret", adapter._global_secret)
|
||||
assert not route_secret # empty → validation is skipped in handler
|
||||
|
||||
def test_validate_generic_signature_valid(self):
|
||||
"""Valid X-Webhook-Signature (generic HMAC-SHA256 hex) is accepted."""
|
||||
adapter = _make_adapter()
|
||||
body = b'{"event": "push"}'
|
||||
secret = "generic-secret"
|
||||
sig = _generic_signature(body, secret)
|
||||
req = _mock_request(headers={"X-Webhook-Signature": sig})
|
||||
assert adapter._validate_signature(req, body, secret) is True
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Prompt rendering
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestRenderPrompt:
|
||||
"""Tests for WebhookAdapter._render_prompt."""
|
||||
|
||||
def test_render_prompt_dot_notation(self):
|
||||
"""Dot-notation {pull_request.title} resolves nested keys."""
|
||||
adapter = _make_adapter()
|
||||
payload = {"pull_request": {"title": "Fix bug", "number": 42}}
|
||||
result = adapter._render_prompt(
|
||||
"PR #{pull_request.number}: {pull_request.title}",
|
||||
payload,
|
||||
"pull_request",
|
||||
"github",
|
||||
)
|
||||
assert result == "PR #42: Fix bug"
|
||||
|
||||
def test_render_prompt_missing_key_preserved(self):
|
||||
"""{nonexistent} is left as-is when key doesn't exist in payload."""
|
||||
adapter = _make_adapter()
|
||||
result = adapter._render_prompt(
|
||||
"Hello {nonexistent}!",
|
||||
{"action": "opened"},
|
||||
"push",
|
||||
"test",
|
||||
)
|
||||
assert "{nonexistent}" in result
|
||||
|
||||
def test_render_prompt_no_template_dumps_json(self):
|
||||
"""Empty template → JSON dump fallback with event/route context."""
|
||||
adapter = _make_adapter()
|
||||
payload = {"key": "value"}
|
||||
result = adapter._render_prompt("", payload, "push", "my-route")
|
||||
assert "push" in result
|
||||
assert "my-route" in result
|
||||
assert "key" in result
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Delivery extra rendering
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestRenderDeliveryExtra:
|
||||
def test_render_delivery_extra_templates(self):
|
||||
"""String values in deliver_extra are rendered with payload data."""
|
||||
adapter = _make_adapter()
|
||||
extra = {"repo": "{repository.full_name}", "pr_number": "{number}", "static": 42}
|
||||
payload = {"repository": {"full_name": "org/repo"}, "number": 7}
|
||||
result = adapter._render_delivery_extra(extra, payload)
|
||||
assert result["repo"] == "org/repo"
|
||||
assert result["pr_number"] == "7"
|
||||
assert result["static"] == 42 # non-string left as-is
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Event filtering
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestEventFilter:
|
||||
"""Tests for event type filtering in _handle_webhook."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_filter_accepts_matching(self):
|
||||
"""Matching event type passes through."""
|
||||
routes = {
|
||||
"gh": {
|
||||
"secret": _INSECURE_NO_AUTH,
|
||||
"events": ["pull_request"],
|
||||
"prompt": "PR: {action}",
|
||||
}
|
||||
}
|
||||
adapter = _make_adapter(routes=routes)
|
||||
# Stub handle_message to avoid running the agent
|
||||
adapter.handle_message = AsyncMock()
|
||||
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
resp = await cli.post(
|
||||
"/webhooks/gh",
|
||||
json={"action": "opened"},
|
||||
headers={"X-GitHub-Event": "pull_request"},
|
||||
)
|
||||
assert resp.status == 202
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_filter_rejects_non_matching(self):
|
||||
"""Non-matching event type returns 200 with status=ignored."""
|
||||
routes = {
|
||||
"gh": {
|
||||
"secret": _INSECURE_NO_AUTH,
|
||||
"events": ["pull_request"],
|
||||
"prompt": "test",
|
||||
}
|
||||
}
|
||||
adapter = _make_adapter(routes=routes)
|
||||
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
resp = await cli.post(
|
||||
"/webhooks/gh",
|
||||
json={"action": "opened"},
|
||||
headers={"X-GitHub-Event": "push"},
|
||||
)
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert data["status"] == "ignored"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_filter_empty_allows_all(self):
|
||||
"""No events list → accept any event type."""
|
||||
routes = {
|
||||
"all": {
|
||||
"secret": _INSECURE_NO_AUTH,
|
||||
"prompt": "got it",
|
||||
}
|
||||
}
|
||||
adapter = _make_adapter(routes=routes)
|
||||
adapter.handle_message = AsyncMock()
|
||||
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
resp = await cli.post(
|
||||
"/webhooks/all",
|
||||
json={"action": "any"},
|
||||
headers={"X-GitHub-Event": "whatever"},
|
||||
)
|
||||
assert resp.status == 202
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# HTTP handling
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestHTTPHandling:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_route_returns_404(self):
|
||||
"""POST to an unknown route returns 404."""
|
||||
adapter = _make_adapter(routes={"real": {"secret": _INSECURE_NO_AUTH, "prompt": "x"}})
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
resp = await cli.post("/webhooks/nonexistent", json={"a": 1})
|
||||
assert resp.status == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_webhook_handler_returns_202(self):
|
||||
"""Valid request returns 202 Accepted."""
|
||||
routes = {"test": {"secret": _INSECURE_NO_AUTH, "prompt": "hi"}}
|
||||
adapter = _make_adapter(routes=routes)
|
||||
adapter.handle_message = AsyncMock()
|
||||
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
resp = await cli.post("/webhooks/test", json={"data": "value"})
|
||||
assert resp.status == 202
|
||||
data = await resp.json()
|
||||
assert data["status"] == "accepted"
|
||||
assert data["route"] == "test"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_endpoint(self):
|
||||
"""GET /health returns 200 with status=ok."""
|
||||
adapter = _make_adapter()
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
resp = await cli.get("/health")
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert data["status"] == "ok"
|
||||
assert data["platform"] == "webhook"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_starts_server(self):
|
||||
"""connect() starts the HTTP listener and marks adapter as connected."""
|
||||
routes = {"r1": {"secret": _INSECURE_NO_AUTH, "prompt": "x"}}
|
||||
adapter = _make_adapter(routes=routes, port=0)
|
||||
# Use port 0 — the OS picks a free port, but aiohttp requires a real bind.
|
||||
# We just test that the method completes and marks connected.
|
||||
# Need to mock TCPSite to avoid actual binding.
|
||||
with patch("gateway.platforms.webhook.web.AppRunner") as MockRunner, \
|
||||
patch("gateway.platforms.webhook.web.TCPSite") as MockSite:
|
||||
mock_runner_inst = AsyncMock()
|
||||
MockRunner.return_value = mock_runner_inst
|
||||
mock_site_inst = AsyncMock()
|
||||
MockSite.return_value = mock_site_inst
|
||||
|
||||
result = await adapter.connect()
|
||||
assert result is True
|
||||
assert adapter.is_connected
|
||||
mock_runner_inst.setup.assert_awaited_once()
|
||||
mock_site_inst.start.assert_awaited_once()
|
||||
|
||||
await adapter.disconnect()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_cleans_up(self):
|
||||
"""disconnect() stops the server and marks adapter disconnected."""
|
||||
adapter = _make_adapter()
|
||||
# Simulate a runner that was previously set up
|
||||
mock_runner = AsyncMock()
|
||||
adapter._runner = mock_runner
|
||||
adapter._running = True
|
||||
|
||||
await adapter.disconnect()
|
||||
mock_runner.cleanup.assert_awaited_once()
|
||||
assert adapter._runner is None
|
||||
assert not adapter.is_connected
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Idempotency
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestIdempotency:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_duplicate_delivery_id_returns_200(self):
|
||||
"""Second request with same delivery ID returns 200 duplicate."""
|
||||
routes = {"idem": {"secret": _INSECURE_NO_AUTH, "prompt": "test"}}
|
||||
adapter = _make_adapter(routes=routes)
|
||||
adapter.handle_message = AsyncMock()
|
||||
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
headers = {"X-GitHub-Delivery": "delivery-123"}
|
||||
resp1 = await cli.post("/webhooks/idem", json={"a": 1}, headers=headers)
|
||||
assert resp1.status == 202
|
||||
|
||||
resp2 = await cli.post("/webhooks/idem", json={"a": 1}, headers=headers)
|
||||
assert resp2.status == 200
|
||||
data = await resp2.json()
|
||||
assert data["status"] == "duplicate"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expired_delivery_id_allows_reprocess(self):
|
||||
"""After TTL expires, the same delivery ID is accepted again."""
|
||||
routes = {"idem": {"secret": _INSECURE_NO_AUTH, "prompt": "test"}}
|
||||
adapter = _make_adapter(routes=routes)
|
||||
adapter._idempotency_ttl = 1 # 1 second TTL for test speed
|
||||
adapter.handle_message = AsyncMock()
|
||||
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
headers = {"X-GitHub-Delivery": "delivery-456"}
|
||||
|
||||
resp1 = await cli.post("/webhooks/idem", json={"x": 1}, headers=headers)
|
||||
assert resp1.status == 202
|
||||
|
||||
# Backdate the cache entry so it appears expired
|
||||
adapter._seen_deliveries["delivery-456"] = time.time() - 3700
|
||||
|
||||
resp2 = await cli.post("/webhooks/idem", json={"x": 1}, headers=headers)
|
||||
assert resp2.status == 202 # re-accepted
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Rate limiting
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestRateLimiting:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_rejects_excess(self):
|
||||
"""Exceeding the rate limit returns 429."""
|
||||
routes = {"limited": {"secret": _INSECURE_NO_AUTH, "prompt": "test"}}
|
||||
adapter = _make_adapter(routes=routes, rate_limit=2)
|
||||
adapter.handle_message = AsyncMock()
|
||||
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
# Two requests within limit
|
||||
for i in range(2):
|
||||
resp = await cli.post(
|
||||
"/webhooks/limited",
|
||||
json={"n": i},
|
||||
headers={"X-GitHub-Delivery": f"d-{i}"},
|
||||
)
|
||||
assert resp.status == 202, f"Request {i} should be accepted"
|
||||
|
||||
# Third request should be rate-limited
|
||||
resp = await cli.post(
|
||||
"/webhooks/limited",
|
||||
json={"n": 99},
|
||||
headers={"X-GitHub-Delivery": "d-99"},
|
||||
)
|
||||
assert resp.status == 429
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_window_resets(self):
|
||||
"""After the 60-second window passes, requests are allowed again."""
|
||||
routes = {"limited": {"secret": _INSECURE_NO_AUTH, "prompt": "test"}}
|
||||
adapter = _make_adapter(routes=routes, rate_limit=1)
|
||||
adapter.handle_message = AsyncMock()
|
||||
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
resp = await cli.post(
|
||||
"/webhooks/limited",
|
||||
json={"n": 1},
|
||||
headers={"X-GitHub-Delivery": "d-a"},
|
||||
)
|
||||
assert resp.status == 202
|
||||
|
||||
# Backdate all rate-limit timestamps to > 60 seconds ago
|
||||
adapter._rate_counts["limited"] = [time.time() - 120]
|
||||
|
||||
resp = await cli.post(
|
||||
"/webhooks/limited",
|
||||
json={"n": 2},
|
||||
headers={"X-GitHub-Delivery": "d-b"},
|
||||
)
|
||||
assert resp.status == 202 # allowed again
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Body size limit
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestBodySize:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oversized_payload_rejected(self):
|
||||
"""Content-Length > max_body_bytes returns 413."""
|
||||
routes = {"big": {"secret": _INSECURE_NO_AUTH, "prompt": "test"}}
|
||||
adapter = _make_adapter(routes=routes, max_body_bytes=100)
|
||||
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
large_payload = {"data": "x" * 200}
|
||||
resp = await cli.post(
|
||||
"/webhooks/big",
|
||||
json=large_payload,
|
||||
headers={"Content-Length": "999999"},
|
||||
)
|
||||
assert resp.status == 413
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# INSECURE_NO_AUTH
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestInsecureNoAuth:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_insecure_no_auth_skips_validation(self):
|
||||
"""Setting secret to _INSECURE_NO_AUTH bypasses signature check."""
|
||||
routes = {"open": {"secret": _INSECURE_NO_AUTH, "prompt": "hello"}}
|
||||
adapter = _make_adapter(routes=routes)
|
||||
adapter.handle_message = AsyncMock()
|
||||
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
# No signature header at all — should still be accepted
|
||||
resp = await cli.post("/webhooks/open", json={"test": True})
|
||||
assert resp.status == 202
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Session isolation
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestSessionIsolation:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_webhooks_get_independent_sessions(self):
|
||||
"""Two events on the same route produce different session keys."""
|
||||
routes = {"ci": {"secret": _INSECURE_NO_AUTH, "prompt": "build"}}
|
||||
adapter = _make_adapter(routes=routes)
|
||||
|
||||
captured_events = []
|
||||
|
||||
async def _capture(event):
|
||||
captured_events.append(event)
|
||||
|
||||
adapter.handle_message = _capture
|
||||
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
resp1 = await cli.post(
|
||||
"/webhooks/ci",
|
||||
json={"ref": "main"},
|
||||
headers={"X-GitHub-Delivery": "aaa-111"},
|
||||
)
|
||||
assert resp1.status == 202
|
||||
|
||||
resp2 = await cli.post(
|
||||
"/webhooks/ci",
|
||||
json={"ref": "dev"},
|
||||
headers={"X-GitHub-Delivery": "bbb-222"},
|
||||
)
|
||||
assert resp2.status == 202
|
||||
|
||||
# Wait for the async tasks to be created
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
assert len(captured_events) == 2
|
||||
ids = {ev.source.chat_id for ev in captured_events}
|
||||
assert len(ids) == 2, "Each delivery must have a unique session chat_id"
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Delivery info cleanup
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestDeliveryCleanup:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delivery_info_cleaned_after_send(self):
|
||||
"""send() pops delivery_info so the entry doesn't leak memory."""
|
||||
adapter = _make_adapter()
|
||||
chat_id = "webhook:test:d-xyz"
|
||||
adapter._delivery_info[chat_id] = {
|
||||
"deliver": "log",
|
||||
"deliver_extra": {},
|
||||
"payload": {"x": 1},
|
||||
}
|
||||
|
||||
result = await adapter.send(chat_id, "Agent response here")
|
||||
assert result.success is True
|
||||
assert chat_id not in adapter._delivery_info
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# check_webhook_requirements
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestCheckRequirements:
|
||||
def test_returns_true_when_aiohttp_available(self):
|
||||
assert check_webhook_requirements() is True
|
||||
|
||||
@patch("gateway.platforms.webhook.AIOHTTP_AVAILABLE", False)
|
||||
def test_returns_false_without_aiohttp(self):
|
||||
assert check_webhook_requirements() is False
|
||||
337
hermes_code/tests/gateway/test_webhook_integration.py
Normal file
337
hermes_code/tests/gateway/test_webhook_integration.py
Normal file
|
|
@ -0,0 +1,337 @@
|
|||
"""Integration tests for the generic webhook platform adapter.
|
||||
|
||||
These tests exercise end-to-end flows through the webhook adapter:
|
||||
1. GitHub PR webhook → agent MessageEvent created
|
||||
2. Skills config injects skill content into the prompt
|
||||
3. Cross-platform delivery routes to a mock Telegram adapter
|
||||
4. GitHub comment delivery invokes ``gh`` CLI (mocked subprocess)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
from aiohttp.test_utils import TestClient, TestServer
|
||||
|
||||
from gateway.config import (
|
||||
GatewayConfig,
|
||||
HomeChannel,
|
||||
Platform,
|
||||
PlatformConfig,
|
||||
)
|
||||
from gateway.platforms.base import MessageEvent, MessageType, SendResult
|
||||
from gateway.platforms.webhook import WebhookAdapter, _INSECURE_NO_AUTH
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_adapter(routes, **extra_kw) -> WebhookAdapter:
|
||||
"""Create a WebhookAdapter with the given routes."""
|
||||
extra = {"host": "0.0.0.0", "port": 0, "routes": routes}
|
||||
extra.update(extra_kw)
|
||||
config = PlatformConfig(enabled=True, extra=extra)
|
||||
return WebhookAdapter(config)
|
||||
|
||||
|
||||
def _create_app(adapter: WebhookAdapter) -> web.Application:
|
||||
"""Build the aiohttp Application from the adapter."""
|
||||
app = web.Application()
|
||||
app.router.add_get("/health", adapter._handle_health)
|
||||
app.router.add_post("/webhooks/{route_name}", adapter._handle_webhook)
|
||||
return app
|
||||
|
||||
|
||||
def _github_signature(body: bytes, secret: str) -> str:
|
||||
"""Compute X-Hub-Signature-256 for *body* using *secret*."""
|
||||
return "sha256=" + hmac.new(
|
||||
secret.encode(), body, hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
|
||||
# A realistic GitHub pull_request event payload (trimmed)
|
||||
GITHUB_PR_PAYLOAD = {
|
||||
"action": "opened",
|
||||
"number": 42,
|
||||
"pull_request": {
|
||||
"title": "Add webhook adapter",
|
||||
"body": "This PR adds a generic webhook platform adapter.",
|
||||
"html_url": "https://github.com/org/repo/pull/42",
|
||||
"user": {"login": "contributor"},
|
||||
"head": {"ref": "feature/webhooks"},
|
||||
"base": {"ref": "main"},
|
||||
},
|
||||
"repository": {
|
||||
"full_name": "org/repo",
|
||||
"html_url": "https://github.com/org/repo",
|
||||
},
|
||||
"sender": {"login": "contributor"},
|
||||
}
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Test 1: GitHub PR webhook triggers agent
|
||||
# ===================================================================
|
||||
|
||||
class TestGitHubPRWebhook:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_github_pr_webhook_triggers_agent(self):
|
||||
"""POST with a realistic GitHub PR payload should:
|
||||
1. Return 202 Accepted
|
||||
2. Call handle_message with a MessageEvent
|
||||
3. The event text contains the rendered prompt
|
||||
4. The event source has chat_type 'webhook'
|
||||
"""
|
||||
secret = "gh-webhook-test-secret"
|
||||
routes = {
|
||||
"github-pr": {
|
||||
"secret": secret,
|
||||
"events": ["pull_request"],
|
||||
"prompt": (
|
||||
"Review PR #{number} by {sender.login}: "
|
||||
"{pull_request.title}\n\n{pull_request.body}"
|
||||
),
|
||||
"deliver": "log",
|
||||
}
|
||||
}
|
||||
adapter = _make_adapter(routes)
|
||||
|
||||
captured_events: list[MessageEvent] = []
|
||||
|
||||
async def _capture(event: MessageEvent):
|
||||
captured_events.append(event)
|
||||
|
||||
adapter.handle_message = _capture
|
||||
|
||||
app = _create_app(adapter)
|
||||
body = json.dumps(GITHUB_PR_PAYLOAD).encode()
|
||||
sig = _github_signature(body, secret)
|
||||
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
resp = await cli.post(
|
||||
"/webhooks/github-pr",
|
||||
data=body,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"X-GitHub-Event": "pull_request",
|
||||
"X-Hub-Signature-256": sig,
|
||||
"X-GitHub-Delivery": "gh-delivery-001",
|
||||
},
|
||||
)
|
||||
assert resp.status == 202
|
||||
data = await resp.json()
|
||||
assert data["status"] == "accepted"
|
||||
assert data["route"] == "github-pr"
|
||||
assert data["event"] == "pull_request"
|
||||
assert data["delivery_id"] == "gh-delivery-001"
|
||||
|
||||
# Let the asyncio.create_task fire
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
assert len(captured_events) == 1
|
||||
event = captured_events[0]
|
||||
assert "Review PR #42 by contributor" in event.text
|
||||
assert "Add webhook adapter" in event.text
|
||||
assert event.source.chat_type == "webhook"
|
||||
assert event.source.platform == Platform.WEBHOOK
|
||||
assert "github-pr" in event.source.chat_id
|
||||
assert event.message_id == "gh-delivery-001"
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Test 2: Skills injected into prompt
|
||||
# ===================================================================
|
||||
|
||||
class TestSkillsInjection:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skills_injected_into_prompt(self):
|
||||
"""When a route has skills: [code-review], the adapter should
|
||||
call build_skill_invocation_message() and use its output as the
|
||||
prompt instead of the raw template render."""
|
||||
routes = {
|
||||
"pr-review": {
|
||||
"secret": _INSECURE_NO_AUTH,
|
||||
"events": ["pull_request"],
|
||||
"prompt": "Review this PR: {pull_request.title}",
|
||||
"skills": ["code-review"],
|
||||
}
|
||||
}
|
||||
adapter = _make_adapter(routes)
|
||||
|
||||
captured_events: list[MessageEvent] = []
|
||||
|
||||
async def _capture(event: MessageEvent):
|
||||
captured_events.append(event)
|
||||
|
||||
adapter.handle_message = _capture
|
||||
|
||||
skill_content = (
|
||||
"You are a code reviewer. Review the following:\n"
|
||||
"Review this PR: Add webhook adapter"
|
||||
)
|
||||
|
||||
# The imports are lazy (inside the handler), so patch the source module
|
||||
with patch(
|
||||
"agent.skill_commands.build_skill_invocation_message",
|
||||
return_value=skill_content,
|
||||
) as mock_build, patch(
|
||||
"agent.skill_commands.get_skill_commands",
|
||||
return_value={"/code-review": {"name": "code-review"}},
|
||||
):
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
resp = await cli.post(
|
||||
"/webhooks/pr-review",
|
||||
json=GITHUB_PR_PAYLOAD,
|
||||
headers={
|
||||
"X-GitHub-Event": "pull_request",
|
||||
"X-GitHub-Delivery": "skill-test-001",
|
||||
},
|
||||
)
|
||||
assert resp.status == 202
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
assert len(captured_events) == 1
|
||||
event = captured_events[0]
|
||||
# The prompt should be the skill content, not the raw template
|
||||
assert "You are a code reviewer" in event.text
|
||||
mock_build.assert_called_once()
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Test 3: Cross-platform delivery (webhook → Telegram)
|
||||
# ===================================================================
|
||||
|
||||
class TestCrossPlatformDelivery:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cross_platform_delivery(self):
|
||||
"""When deliver='telegram', the response is routed to the
|
||||
Telegram adapter via gateway_runner.adapters."""
|
||||
routes = {
|
||||
"alerts": {
|
||||
"secret": _INSECURE_NO_AUTH,
|
||||
"prompt": "Alert: {message}",
|
||||
"deliver": "telegram",
|
||||
"deliver_extra": {"chat_id": "12345"},
|
||||
}
|
||||
}
|
||||
adapter = _make_adapter(routes)
|
||||
adapter.handle_message = AsyncMock()
|
||||
|
||||
# Set up a mock gateway runner with a mock Telegram adapter
|
||||
mock_tg_adapter = AsyncMock()
|
||||
mock_tg_adapter.send = AsyncMock(return_value=SendResult(success=True))
|
||||
|
||||
mock_runner = MagicMock()
|
||||
mock_runner.adapters = {Platform.TELEGRAM: mock_tg_adapter}
|
||||
mock_runner.config = GatewayConfig(
|
||||
platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="fake")}
|
||||
)
|
||||
adapter.gateway_runner = mock_runner
|
||||
|
||||
# First, simulate a webhook POST to set up delivery_info
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
resp = await cli.post(
|
||||
"/webhooks/alerts",
|
||||
json={"message": "Server is on fire!"},
|
||||
headers={"X-GitHub-Delivery": "alert-001"},
|
||||
)
|
||||
assert resp.status == 202
|
||||
|
||||
# The adapter should have stored delivery info
|
||||
chat_id = "webhook:alerts:alert-001"
|
||||
assert chat_id in adapter._delivery_info
|
||||
|
||||
# Now call send() as if the agent has finished
|
||||
result = await adapter.send(chat_id, "I've acknowledged the alert.")
|
||||
|
||||
assert result.success is True
|
||||
mock_tg_adapter.send.assert_awaited_once_with(
|
||||
"12345", "I've acknowledged the alert."
|
||||
)
|
||||
# Delivery info should be cleaned up
|
||||
assert chat_id not in adapter._delivery_info
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Test 4: GitHub comment delivery via gh CLI
|
||||
# ===================================================================
|
||||
|
||||
class TestGitHubCommentDelivery:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_github_comment_delivery(self):
|
||||
"""When deliver='github_comment', the adapter invokes
|
||||
``gh pr comment`` via subprocess.run (mocked)."""
|
||||
routes = {
|
||||
"pr-bot": {
|
||||
"secret": _INSECURE_NO_AUTH,
|
||||
"prompt": "Review: {pull_request.title}",
|
||||
"deliver": "github_comment",
|
||||
"deliver_extra": {
|
||||
"repo": "{repository.full_name}",
|
||||
"pr_number": "{number}",
|
||||
},
|
||||
}
|
||||
}
|
||||
adapter = _make_adapter(routes)
|
||||
adapter.handle_message = AsyncMock()
|
||||
|
||||
# POST a webhook to set up delivery info
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
resp = await cli.post(
|
||||
"/webhooks/pr-bot",
|
||||
json=GITHUB_PR_PAYLOAD,
|
||||
headers={
|
||||
"X-GitHub-Event": "pull_request",
|
||||
"X-GitHub-Delivery": "gh-comment-001",
|
||||
},
|
||||
)
|
||||
assert resp.status == 202
|
||||
|
||||
chat_id = "webhook:pr-bot:gh-comment-001"
|
||||
assert chat_id in adapter._delivery_info
|
||||
|
||||
# Verify deliver_extra was rendered with payload data
|
||||
delivery = adapter._delivery_info[chat_id]
|
||||
assert delivery["deliver_extra"]["repo"] == "org/repo"
|
||||
assert delivery["deliver_extra"]["pr_number"] == "42"
|
||||
|
||||
# Mock subprocess.run and call send()
|
||||
mock_result = MagicMock()
|
||||
mock_result.returncode = 0
|
||||
mock_result.stdout = "Comment posted"
|
||||
mock_result.stderr = ""
|
||||
|
||||
with patch(
|
||||
"gateway.platforms.webhook.subprocess.run",
|
||||
return_value=mock_result,
|
||||
) as mock_run:
|
||||
result = await adapter.send(
|
||||
chat_id, "LGTM! The code looks great."
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
mock_run.assert_called_once_with(
|
||||
[
|
||||
"gh", "pr", "comment", "42",
|
||||
"--repo", "org/repo",
|
||||
"--body", "LGTM! The code looks great.",
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=30,
|
||||
)
|
||||
# Delivery info cleaned up
|
||||
assert chat_id not in adapter._delivery_info
|
||||
419
hermes_code/tests/gateway/test_whatsapp_connect.py
Normal file
419
hermes_code/tests/gateway/test_whatsapp_connect.py
Normal file
|
|
@ -0,0 +1,419 @@
|
|||
"""Tests for WhatsApp connect() error handling.
|
||||
|
||||
Regression tests for two bugs in WhatsAppAdapter.connect():
|
||||
|
||||
1. Uninitialized ``data`` variable: when ``resp.json()`` raised after the
|
||||
health endpoint returned HTTP 200, ``http_ready`` was set to True but
|
||||
``data`` was never assigned. The subsequent ``data.get("status")``
|
||||
check raised ``NameError``.
|
||||
|
||||
2. Bridge log file handle leaked on error paths: the file was opened before
|
||||
the health-check loop but never closed when ``connect()`` returned False.
|
||||
Repeated connection failures accumulated open file descriptors.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _AsyncCM:
|
||||
"""Minimal async context manager returning a fixed value."""
|
||||
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
async def __aenter__(self):
|
||||
return self.value
|
||||
|
||||
async def __aexit__(self, *exc):
|
||||
return False
|
||||
|
||||
|
||||
def _make_adapter():
|
||||
"""Create a WhatsAppAdapter with test attributes (bypass __init__)."""
|
||||
from gateway.platforms.whatsapp import WhatsAppAdapter
|
||||
|
||||
adapter = WhatsAppAdapter.__new__(WhatsAppAdapter)
|
||||
adapter.platform = Platform.WHATSAPP
|
||||
adapter.config = MagicMock()
|
||||
adapter._bridge_port = 19876
|
||||
adapter._bridge_script = "/tmp/test-bridge.js"
|
||||
adapter._session_path = Path("/tmp/test-wa-session")
|
||||
adapter._bridge_log_fh = None
|
||||
adapter._bridge_log = None
|
||||
adapter._bridge_process = None
|
||||
adapter._reply_prefix = None
|
||||
adapter._running = False
|
||||
adapter._message_handler = None
|
||||
adapter._fatal_error_code = None
|
||||
adapter._fatal_error_message = None
|
||||
adapter._fatal_error_retryable = True
|
||||
adapter._fatal_error_handler = None
|
||||
adapter._active_sessions = {}
|
||||
adapter._pending_messages = {}
|
||||
adapter._background_tasks = set()
|
||||
adapter._auto_tts_disabled_chats = set()
|
||||
adapter._message_queue = asyncio.Queue()
|
||||
return adapter
|
||||
|
||||
|
||||
def _mock_aiohttp(status=200, json_data=None, json_side_effect=None):
|
||||
"""Build a mock ``aiohttp.ClientSession`` returning a fixed response."""
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status = status
|
||||
if json_side_effect:
|
||||
mock_resp.json = AsyncMock(side_effect=json_side_effect)
|
||||
else:
|
||||
mock_resp.json = AsyncMock(return_value=json_data or {})
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.get = MagicMock(return_value=_AsyncCM(mock_resp))
|
||||
|
||||
return MagicMock(return_value=_AsyncCM(mock_session))
|
||||
|
||||
|
||||
def _connect_patches(mock_proc, mock_fh, mock_client_cls=None):
|
||||
"""Return a dict of common patches needed to reach the health-check loop."""
|
||||
patches = {
|
||||
"gateway.platforms.whatsapp.check_whatsapp_requirements": True,
|
||||
"gateway.platforms.whatsapp.asyncio.create_task": MagicMock(),
|
||||
}
|
||||
base = [
|
||||
patch("gateway.platforms.whatsapp.check_whatsapp_requirements", return_value=True),
|
||||
patch.object(Path, "exists", return_value=True),
|
||||
patch.object(Path, "mkdir", return_value=None),
|
||||
patch("subprocess.run", return_value=MagicMock(returncode=0)),
|
||||
patch("subprocess.Popen", return_value=mock_proc),
|
||||
patch("builtins.open", return_value=mock_fh),
|
||||
patch("gateway.platforms.whatsapp.asyncio.sleep", new_callable=AsyncMock),
|
||||
patch("gateway.platforms.whatsapp.asyncio.create_task"),
|
||||
]
|
||||
if mock_client_cls is not None:
|
||||
base.append(patch("aiohttp.ClientSession", mock_client_cls))
|
||||
return base
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _close_bridge_log() unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCloseBridgeLog:
|
||||
"""Direct tests for the _close_bridge_log() helper method."""
|
||||
|
||||
@staticmethod
|
||||
def _bare_adapter():
|
||||
from gateway.platforms.whatsapp import WhatsAppAdapter
|
||||
a = WhatsAppAdapter.__new__(WhatsAppAdapter)
|
||||
a._bridge_log_fh = None
|
||||
return a
|
||||
|
||||
def test_closes_open_handle(self):
|
||||
adapter = self._bare_adapter()
|
||||
mock_fh = MagicMock()
|
||||
adapter._bridge_log_fh = mock_fh
|
||||
|
||||
adapter._close_bridge_log()
|
||||
|
||||
mock_fh.close.assert_called_once()
|
||||
assert adapter._bridge_log_fh is None
|
||||
|
||||
def test_noop_when_no_handle(self):
|
||||
adapter = self._bare_adapter()
|
||||
|
||||
adapter._close_bridge_log() # must not raise
|
||||
|
||||
assert adapter._bridge_log_fh is None
|
||||
|
||||
def test_suppresses_close_exception(self):
|
||||
adapter = self._bare_adapter()
|
||||
mock_fh = MagicMock()
|
||||
mock_fh.close.side_effect = OSError("already closed")
|
||||
adapter._bridge_log_fh = mock_fh
|
||||
|
||||
adapter._close_bridge_log() # must not raise
|
||||
|
||||
assert adapter._bridge_log_fh is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# data variable initialization
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDataInitialized:
|
||||
"""Verify ``data = {}`` prevents NameError when resp.json() fails."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_name_error_when_json_always_fails(self):
|
||||
"""HTTP 200 sets http_ready but json() always raises.
|
||||
|
||||
Without the fix, ``data`` was never assigned and the Phase 2 check
|
||||
``data.get("status")`` raised NameError. With ``data = {}``, the
|
||||
check evaluates to ``None != "connected"`` and Phase 2 runs normally.
|
||||
"""
|
||||
adapter = _make_adapter()
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.poll.return_value = None # bridge stays alive
|
||||
|
||||
mock_client_cls = _mock_aiohttp(
|
||||
status=200, json_side_effect=ValueError("bad json"),
|
||||
)
|
||||
mock_fh = MagicMock()
|
||||
|
||||
patches = _connect_patches(mock_proc, mock_fh, mock_client_cls)
|
||||
|
||||
with patches[0], patches[1], patches[2], patches[3], patches[4], \
|
||||
patches[5], patches[6], patches[7], patches[8], \
|
||||
patch.object(type(adapter), "_poll_messages", return_value=MagicMock()):
|
||||
# Must NOT raise NameError
|
||||
result = await adapter.connect()
|
||||
|
||||
# connect() returns True (warn-and-proceed path)
|
||||
assert result is True
|
||||
assert adapter._running is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# File handle cleanup on error paths
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestFileHandleClosedOnError:
|
||||
"""Verify the bridge log file handle is closed on every failure path."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_closed_when_bridge_dies_phase1(self):
|
||||
"""Bridge process exits during Phase 1 health-check loop."""
|
||||
adapter = _make_adapter()
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.poll.return_value = 1 # dead immediately
|
||||
mock_proc.returncode = 1
|
||||
|
||||
mock_fh = MagicMock()
|
||||
patches = _connect_patches(mock_proc, mock_fh)
|
||||
|
||||
with patches[0], patches[1], patches[2], patches[3], patches[4], \
|
||||
patches[5], patches[6], patches[7]:
|
||||
result = await adapter.connect()
|
||||
|
||||
assert result is False
|
||||
mock_fh.close.assert_called_once()
|
||||
assert adapter._bridge_log_fh is None
|
||||
|
||||
|
||||
class TestBridgeRuntimeFailure:
|
||||
"""Verify runtime bridge death is surfaced as a fatal adapter error."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_marks_retryable_fatal_when_managed_bridge_exits(self):
|
||||
adapter = _make_adapter()
|
||||
fatal_handler = AsyncMock()
|
||||
adapter.set_fatal_error_handler(fatal_handler)
|
||||
adapter._running = True
|
||||
mock_fh = MagicMock()
|
||||
adapter._bridge_log_fh = mock_fh
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.poll.return_value = 7
|
||||
adapter._bridge_process = mock_proc
|
||||
|
||||
result = await adapter.send("chat-123", "hello")
|
||||
|
||||
assert result.success is False
|
||||
assert "exited unexpectedly" in result.error
|
||||
assert adapter.fatal_error_code == "whatsapp_bridge_exited"
|
||||
assert adapter.fatal_error_retryable is True
|
||||
fatal_handler.assert_awaited_once()
|
||||
mock_fh.close.assert_called_once()
|
||||
assert adapter._bridge_log_fh is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_messages_marks_retryable_fatal_when_managed_bridge_exits(self):
|
||||
adapter = _make_adapter()
|
||||
fatal_handler = AsyncMock()
|
||||
adapter.set_fatal_error_handler(fatal_handler)
|
||||
adapter._running = True
|
||||
mock_fh = MagicMock()
|
||||
adapter._bridge_log_fh = mock_fh
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.poll.return_value = 23
|
||||
adapter._bridge_process = mock_proc
|
||||
|
||||
await adapter._poll_messages()
|
||||
|
||||
assert adapter.fatal_error_code == "whatsapp_bridge_exited"
|
||||
assert adapter.fatal_error_retryable is True
|
||||
fatal_handler.assert_awaited_once()
|
||||
mock_fh.close.assert_called_once()
|
||||
assert adapter._bridge_log_fh is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_closed_when_http_not_ready(self):
|
||||
"""Health endpoint never returns 200 within 15 attempts."""
|
||||
adapter = _make_adapter()
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.poll.return_value = None # bridge alive
|
||||
|
||||
mock_client_cls = _mock_aiohttp(status=503)
|
||||
mock_fh = MagicMock()
|
||||
patches = _connect_patches(mock_proc, mock_fh, mock_client_cls)
|
||||
|
||||
with patches[0], patches[1], patches[2], patches[3], patches[4], \
|
||||
patches[5], patches[6], patches[7], patches[8]:
|
||||
result = await adapter.connect()
|
||||
|
||||
assert result is False
|
||||
mock_fh.close.assert_called_once()
|
||||
assert adapter._bridge_log_fh is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_closed_when_bridge_dies_phase2(self):
|
||||
"""Bridge alive during Phase 1 but dies during Phase 2."""
|
||||
adapter = _make_adapter()
|
||||
|
||||
# Phase 1 (15 iterations): alive. Phase 2 (iteration 16): dead.
|
||||
call_count = [0]
|
||||
|
||||
def poll_side_effect():
|
||||
call_count[0] += 1
|
||||
return None if call_count[0] <= 15 else 1
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.poll.side_effect = poll_side_effect
|
||||
mock_proc.returncode = 1
|
||||
|
||||
# Health returns 200 with status != "connected" -> triggers Phase 2
|
||||
mock_client_cls = _mock_aiohttp(
|
||||
status=200, json_data={"status": "disconnected"},
|
||||
)
|
||||
mock_fh = MagicMock()
|
||||
patches = _connect_patches(mock_proc, mock_fh, mock_client_cls)
|
||||
|
||||
with patches[0], patches[1], patches[2], patches[3], patches[4], \
|
||||
patches[5], patches[6], patches[7], patches[8]:
|
||||
result = await adapter.connect()
|
||||
|
||||
assert result is False
|
||||
mock_fh.close.assert_called_once()
|
||||
assert adapter._bridge_log_fh is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_closed_on_unexpected_exception(self):
|
||||
"""Popen raises, outer except block must still close the handle."""
|
||||
adapter = _make_adapter()
|
||||
|
||||
mock_fh = MagicMock()
|
||||
|
||||
with patch("gateway.platforms.whatsapp.check_whatsapp_requirements", return_value=True), \
|
||||
patch.object(Path, "exists", return_value=True), \
|
||||
patch.object(Path, "mkdir", return_value=None), \
|
||||
patch("subprocess.run", return_value=MagicMock(returncode=0)), \
|
||||
patch("subprocess.Popen", side_effect=OSError("spawn failed")), \
|
||||
patch("builtins.open", return_value=mock_fh):
|
||||
result = await adapter.connect()
|
||||
|
||||
assert result is False
|
||||
mock_fh.close.assert_called_once()
|
||||
assert adapter._bridge_log_fh is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _kill_port_process() cross-platform tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestKillPortProcess:
|
||||
"""Verify _kill_port_process uses platform-appropriate commands."""
|
||||
|
||||
def test_uses_netstat_and_taskkill_on_windows(self):
|
||||
from gateway.platforms.whatsapp import _kill_port_process
|
||||
|
||||
netstat_output = (
|
||||
" Proto Local Address Foreign Address State PID\n"
|
||||
" TCP 0.0.0.0:3000 0.0.0.0:0 LISTENING 12345\n"
|
||||
" TCP 0.0.0.0:3001 0.0.0.0:0 LISTENING 99999\n"
|
||||
)
|
||||
mock_netstat = MagicMock(stdout=netstat_output)
|
||||
mock_taskkill = MagicMock()
|
||||
|
||||
def run_side_effect(cmd, **kwargs):
|
||||
if cmd[0] == "netstat":
|
||||
return mock_netstat
|
||||
if cmd[0] == "taskkill":
|
||||
return mock_taskkill
|
||||
return MagicMock()
|
||||
|
||||
with patch("gateway.platforms.whatsapp._IS_WINDOWS", True), \
|
||||
patch("gateway.platforms.whatsapp.subprocess.run", side_effect=run_side_effect) as mock_run:
|
||||
_kill_port_process(3000)
|
||||
|
||||
# netstat called
|
||||
assert any(
|
||||
call.args[0][0] == "netstat" for call in mock_run.call_args_list
|
||||
)
|
||||
# taskkill called with correct PID
|
||||
assert any(
|
||||
call.args[0] == ["taskkill", "/PID", "12345", "/F"]
|
||||
for call in mock_run.call_args_list
|
||||
)
|
||||
|
||||
def test_does_not_kill_wrong_port_on_windows(self):
|
||||
from gateway.platforms.whatsapp import _kill_port_process
|
||||
|
||||
netstat_output = (
|
||||
" TCP 0.0.0.0:30000 0.0.0.0:0 LISTENING 55555\n"
|
||||
)
|
||||
mock_netstat = MagicMock(stdout=netstat_output)
|
||||
|
||||
with patch("gateway.platforms.whatsapp._IS_WINDOWS", True), \
|
||||
patch("gateway.platforms.whatsapp.subprocess.run", return_value=mock_netstat) as mock_run:
|
||||
_kill_port_process(3000)
|
||||
|
||||
# Should NOT call taskkill because port 30000 != 3000
|
||||
assert not any(
|
||||
call.args[0][0] == "taskkill"
|
||||
for call in mock_run.call_args_list
|
||||
)
|
||||
|
||||
def test_uses_fuser_on_linux(self):
|
||||
from gateway.platforms.whatsapp import _kill_port_process
|
||||
|
||||
mock_check = MagicMock(returncode=0)
|
||||
|
||||
with patch("gateway.platforms.whatsapp._IS_WINDOWS", False), \
|
||||
patch("gateway.platforms.whatsapp.subprocess.run", return_value=mock_check) as mock_run:
|
||||
_kill_port_process(3000)
|
||||
|
||||
calls = [c.args[0] for c in mock_run.call_args_list]
|
||||
assert ["fuser", "3000/tcp"] in calls
|
||||
assert ["fuser", "-k", "3000/tcp"] in calls
|
||||
|
||||
def test_skips_fuser_kill_when_port_free(self):
|
||||
from gateway.platforms.whatsapp import _kill_port_process
|
||||
|
||||
mock_check = MagicMock(returncode=1) # port not in use
|
||||
|
||||
with patch("gateway.platforms.whatsapp._IS_WINDOWS", False), \
|
||||
patch("gateway.platforms.whatsapp.subprocess.run", return_value=mock_check) as mock_run:
|
||||
_kill_port_process(3000)
|
||||
|
||||
calls = [c.args[0] for c in mock_run.call_args_list]
|
||||
assert ["fuser", "3000/tcp"] in calls
|
||||
assert ["fuser", "-k", "3000/tcp"] not in calls
|
||||
|
||||
def test_suppresses_exceptions(self):
|
||||
from gateway.platforms.whatsapp import _kill_port_process
|
||||
|
||||
with patch("gateway.platforms.whatsapp._IS_WINDOWS", True), \
|
||||
patch("gateway.platforms.whatsapp.subprocess.run", side_effect=OSError("no netstat")):
|
||||
_kill_port_process(3000) # must not raise
|
||||
121
hermes_code/tests/gateway/test_whatsapp_reply_prefix.py
Normal file
121
hermes_code/tests/gateway/test_whatsapp_reply_prefix.py
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
"""Tests for WhatsApp reply_prefix config.yaml support.
|
||||
|
||||
Covers:
|
||||
- config.yaml whatsapp.reply_prefix bridging into PlatformConfig.extra
|
||||
- WhatsAppAdapter reading reply_prefix from config.extra
|
||||
- Bridge subprocess receiving WHATSAPP_REPLY_PREFIX env var
|
||||
- Config version covers all ENV_VARS_BY_VERSION keys (regression guard)
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config bridging from config.yaml
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestConfigYamlBridging:
|
||||
"""Test that whatsapp.reply_prefix in config.yaml flows into PlatformConfig."""
|
||||
|
||||
def test_reply_prefix_bridged_from_yaml(self, tmp_path):
|
||||
"""whatsapp.reply_prefix in config.yaml sets PlatformConfig.extra."""
|
||||
config_yaml = tmp_path / "config.yaml"
|
||||
config_yaml.write_text('whatsapp:\n reply_prefix: "Custom Bot"\n')
|
||||
|
||||
with patch("gateway.config.get_hermes_home", return_value=tmp_path):
|
||||
from gateway.config import load_gateway_config
|
||||
# Need to also patch WHATSAPP_ENABLED so the platform exists
|
||||
with patch.dict("os.environ", {"WHATSAPP_ENABLED": "true"}, clear=False):
|
||||
config = load_gateway_config()
|
||||
|
||||
wa_config = config.platforms.get(Platform.WHATSAPP)
|
||||
assert wa_config is not None
|
||||
assert wa_config.extra.get("reply_prefix") == "Custom Bot"
|
||||
|
||||
def test_empty_reply_prefix_bridged(self, tmp_path):
|
||||
"""Empty string reply_prefix disables the header."""
|
||||
config_yaml = tmp_path / "config.yaml"
|
||||
config_yaml.write_text('whatsapp:\n reply_prefix: ""\n')
|
||||
|
||||
with patch("gateway.config.get_hermes_home", return_value=tmp_path):
|
||||
from gateway.config import load_gateway_config
|
||||
with patch.dict("os.environ", {"WHATSAPP_ENABLED": "true"}, clear=False):
|
||||
config = load_gateway_config()
|
||||
|
||||
wa_config = config.platforms.get(Platform.WHATSAPP)
|
||||
assert wa_config is not None
|
||||
assert wa_config.extra.get("reply_prefix") == ""
|
||||
|
||||
def test_no_whatsapp_section_no_extra(self, tmp_path):
|
||||
"""Without whatsapp section, no reply_prefix is set."""
|
||||
config_yaml = tmp_path / "config.yaml"
|
||||
config_yaml.write_text("timezone: UTC\n")
|
||||
|
||||
with patch("gateway.config.get_hermes_home", return_value=tmp_path):
|
||||
from gateway.config import load_gateway_config
|
||||
with patch.dict("os.environ", {"WHATSAPP_ENABLED": "true"}, clear=False):
|
||||
config = load_gateway_config()
|
||||
|
||||
wa_config = config.platforms.get(Platform.WHATSAPP)
|
||||
assert wa_config is not None
|
||||
assert "reply_prefix" not in wa_config.extra
|
||||
|
||||
def test_whatsapp_section_without_reply_prefix(self, tmp_path):
|
||||
"""whatsapp section present but without reply_prefix key."""
|
||||
config_yaml = tmp_path / "config.yaml"
|
||||
config_yaml.write_text("whatsapp:\n other_setting: true\n")
|
||||
|
||||
with patch("gateway.config.get_hermes_home", return_value=tmp_path):
|
||||
from gateway.config import load_gateway_config
|
||||
with patch.dict("os.environ", {"WHATSAPP_ENABLED": "true"}, clear=False):
|
||||
config = load_gateway_config()
|
||||
|
||||
wa_config = config.platforms.get(Platform.WHATSAPP)
|
||||
assert "reply_prefix" not in wa_config.extra
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WhatsAppAdapter __init__
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAdapterInit:
|
||||
"""Test that WhatsAppAdapter reads reply_prefix from config.extra."""
|
||||
|
||||
def test_reply_prefix_from_extra(self):
|
||||
from gateway.platforms.whatsapp import WhatsAppAdapter
|
||||
config = PlatformConfig(enabled=True, extra={"reply_prefix": "Bot\\n"})
|
||||
adapter = WhatsAppAdapter(config)
|
||||
assert adapter._reply_prefix == "Bot\\n"
|
||||
|
||||
def test_reply_prefix_default_none(self):
|
||||
from gateway.platforms.whatsapp import WhatsAppAdapter
|
||||
config = PlatformConfig(enabled=True)
|
||||
adapter = WhatsAppAdapter(config)
|
||||
assert adapter._reply_prefix is None
|
||||
|
||||
def test_reply_prefix_empty_string(self):
|
||||
from gateway.platforms.whatsapp import WhatsAppAdapter
|
||||
config = PlatformConfig(enabled=True, extra={"reply_prefix": ""})
|
||||
adapter = WhatsAppAdapter(config)
|
||||
assert adapter._reply_prefix == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config version regression guard
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestConfigVersionCoverage:
|
||||
"""Ensure _config_version covers all ENV_VARS_BY_VERSION keys."""
|
||||
|
||||
def test_default_config_version_covers_env_var_versions(self):
|
||||
"""_config_version must be >= the highest ENV_VARS_BY_VERSION key."""
|
||||
from hermes_cli.config import DEFAULT_CONFIG, ENV_VARS_BY_VERSION
|
||||
assert DEFAULT_CONFIG["_config_version"] >= max(ENV_VARS_BY_VERSION)
|
||||
Loading…
Add table
Add a link
Reference in a new issue