merge: resolve conflict with main (add mcp + homeassistant extras)
This commit is contained in:
commit
aefc330b8f
81 changed files with 8138 additions and 776 deletions
206
tests/gateway/test_channel_directory.py
Normal file
206
tests/gateway/test_channel_directory.py
Normal file
|
|
@ -0,0 +1,206 @@
|
|||
"""Tests for gateway/channel_directory.py — channel resolution and display."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from gateway.channel_directory import (
|
||||
resolve_channel_name,
|
||||
format_directory_for_display,
|
||||
load_directory,
|
||||
_build_from_sessions,
|
||||
DIRECTORY_PATH,
|
||||
)
|
||||
|
||||
|
||||
def _write_directory(tmp_path, platforms):
|
||||
"""Helper to write a fake channel directory."""
|
||||
data = {"updated_at": "2026-01-01T00:00:00", "platforms": platforms}
|
||||
cache_file = tmp_path / "channel_directory.json"
|
||||
cache_file.write_text(json.dumps(data))
|
||||
return cache_file
|
||||
|
||||
|
||||
class TestLoadDirectory:
    """load_directory(): behaviour for missing, valid, and corrupt cache files."""

    def test_missing_file(self, tmp_path):
        # A nonexistent cache file yields the empty default structure.
        with patch("gateway.channel_directory.DIRECTORY_PATH", tmp_path / "nope.json"):
            loaded = load_directory()
        assert loaded["updated_at"] is None
        assert loaded["platforms"] == {}

    def test_valid_file(self, tmp_path):
        cache = _write_directory(tmp_path, {
            "telegram": [{"id": "123", "name": "John", "type": "dm"}]
        })
        with patch("gateway.channel_directory.DIRECTORY_PATH", cache):
            loaded = load_directory()
        assert loaded["platforms"]["telegram"][0]["name"] == "John"

    def test_corrupt_file(self, tmp_path):
        # Unparseable JSON must degrade to the empty default, not raise.
        cache = tmp_path / "channel_directory.json"
        cache.write_text("{bad json")
        with patch("gateway.channel_directory.DIRECTORY_PATH", cache):
            loaded = load_directory()
        assert loaded["updated_at"] is None
|
||||
|
||||
|
||||
class TestResolveChannelName:
    """resolve_channel_name(): exact, case-insensitive, guild-qualified, and prefix matching."""

    def _setup(self, tmp_path, platforms):
        # Point the module's DIRECTORY_PATH at a freshly written directory file.
        return patch(
            "gateway.channel_directory.DIRECTORY_PATH",
            _write_directory(tmp_path, platforms),
        )

    def test_exact_match(self, tmp_path):
        entries = {
            "discord": [
                {"id": "111", "name": "bot-home", "guild": "MyServer", "type": "channel"},
                {"id": "222", "name": "general", "guild": "MyServer", "type": "channel"},
            ]
        }
        with self._setup(tmp_path, entries):
            # A leading '#' on the query is accepted as well.
            assert resolve_channel_name("discord", "bot-home") == "111"
            assert resolve_channel_name("discord", "#bot-home") == "111"

    def test_case_insensitive(self, tmp_path):
        entries = {
            "slack": [{"id": "C01", "name": "Engineering", "type": "channel"}]
        }
        with self._setup(tmp_path, entries):
            assert resolve_channel_name("slack", "engineering") == "C01"
            assert resolve_channel_name("slack", "ENGINEERING") == "C01"

    def test_guild_qualified_match(self, tmp_path):
        # Same channel name in two guilds: "Guild/name" disambiguates.
        entries = {
            "discord": [
                {"id": "111", "name": "general", "guild": "ServerA", "type": "channel"},
                {"id": "222", "name": "general", "guild": "ServerB", "type": "channel"},
            ]
        }
        with self._setup(tmp_path, entries):
            assert resolve_channel_name("discord", "ServerA/general") == "111"
            assert resolve_channel_name("discord", "ServerB/general") == "222"

    def test_prefix_match_unambiguous(self, tmp_path):
        entries = {
            "slack": [
                {"id": "C01", "name": "engineering-backend", "type": "channel"},
                {"id": "C02", "name": "design-team", "type": "channel"},
            ]
        }
        with self._setup(tmp_path, entries):
            # "engineering" prefix matches only one channel
            assert resolve_channel_name("slack", "engineering") == "C01"

    def test_prefix_match_ambiguous_returns_none(self, tmp_path):
        entries = {
            "slack": [
                {"id": "C01", "name": "eng-backend", "type": "channel"},
                {"id": "C02", "name": "eng-frontend", "type": "channel"},
            ]
        }
        with self._setup(tmp_path, entries):
            # Ambiguous prefixes must not guess.
            assert resolve_channel_name("slack", "eng") is None

    def test_no_channels_returns_none(self, tmp_path):
        with self._setup(tmp_path, {}):
            assert resolve_channel_name("telegram", "someone") is None

    def test_no_match_returns_none(self, tmp_path):
        entries = {
            "telegram": [{"id": "123", "name": "John", "type": "dm"}]
        }
        with self._setup(tmp_path, entries):
            assert resolve_channel_name("telegram", "nonexistent") is None
|
||||
|
||||
|
||||
class TestBuildFromSessions:
    """_build_from_sessions(): derive directory entries from sessions.json."""

    def _write_sessions(self, tmp_path, payload):
        """Write sessions.json at the path _build_from_sessions expects."""
        index = tmp_path / ".hermes" / "sessions" / "sessions.json"
        index.parent.mkdir(parents=True)
        index.write_text(json.dumps(payload))

    def test_builds_from_sessions_json(self, tmp_path):
        self._write_sessions(tmp_path, {
            "session_1": {
                "origin": {
                    "platform": "telegram",
                    "chat_id": "12345",
                    "chat_name": "Alice",
                },
                "chat_type": "dm",
            },
            "session_2": {
                "origin": {
                    "platform": "telegram",
                    "chat_id": "67890",
                    "user_name": "Bob",
                },
                "chat_type": "group",
            },
            "session_3": {
                "origin": {
                    "platform": "discord",
                    "chat_id": "99999",
                },
            },
        })

        with patch.object(Path, "home", return_value=tmp_path):
            entries = _build_from_sessions("telegram")

        # Only the two telegram sessions survive the platform filter.
        assert len(entries) == 2
        names = {entry["name"] for entry in entries}
        assert "Alice" in names
        assert "Bob" in names

    def test_missing_sessions_file(self, tmp_path):
        # No sessions.json at all: expect an empty list, not an error.
        with patch.object(Path, "home", return_value=tmp_path):
            entries = _build_from_sessions("telegram")
            assert entries == []

    def test_deduplication_by_chat_id(self, tmp_path):
        self._write_sessions(tmp_path, {
            "s1": {"origin": {"platform": "telegram", "chat_id": "123", "chat_name": "X"}},
            "s2": {"origin": {"platform": "telegram", "chat_id": "123", "chat_name": "X"}},
        })

        with patch.object(Path, "home", return_value=tmp_path):
            entries = _build_from_sessions("telegram")

        # Two sessions with the same chat_id collapse into one entry.
        assert len(entries) == 1
|
||||
|
||||
|
||||
class TestFormatDirectoryForDisplay:
    """format_directory_for_display(): human-readable directory rendering."""

    def test_empty_directory(self, tmp_path):
        with patch("gateway.channel_directory.DIRECTORY_PATH", tmp_path / "nope.json"):
            rendered = format_directory_for_display()
        assert "No messaging platforms" in rendered

    def test_telegram_display(self, tmp_path):
        cache = _write_directory(tmp_path, {
            "telegram": [
                {"id": "123", "name": "Alice", "type": "dm"},
                {"id": "456", "name": "Dev Group", "type": "group"},
            ]
        })
        with patch("gateway.channel_directory.DIRECTORY_PATH", cache):
            rendered = format_directory_for_display()

        assert "Telegram:" in rendered
        assert "telegram:Alice" in rendered
        assert "telegram:Dev Group" in rendered

    def test_discord_grouped_by_guild(self, tmp_path):
        # Discord channels should be rendered under per-guild headings.
        cache = _write_directory(tmp_path, {
            "discord": [
                {"id": "1", "name": "general", "guild": "Server1", "type": "channel"},
                {"id": "2", "name": "bot-home", "guild": "Server1", "type": "channel"},
                {"id": "3", "name": "chat", "guild": "Server2", "type": "channel"},
            ]
        })
        with patch("gateway.channel_directory.DIRECTORY_PATH", cache):
            rendered = format_directory_for_display()

        assert "Discord (Server1):" in rendered
        assert "Discord (Server2):" in rendered
        assert "discord:#general" in rendered
|
||||
213
tests/gateway/test_hooks.py
Normal file
213
tests/gateway/test_hooks.py
Normal file
|
|
@ -0,0 +1,213 @@
|
|||
"""Tests for gateway/hooks.py — event hook system."""
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.hooks import HookRegistry
|
||||
|
||||
|
||||
def _create_hook(hooks_dir, hook_name, events, handler_code):
|
||||
"""Helper to create a hook directory with HOOK.yaml and handler.py."""
|
||||
hook_dir = hooks_dir / hook_name
|
||||
hook_dir.mkdir(parents=True)
|
||||
(hook_dir / "HOOK.yaml").write_text(
|
||||
f"name: {hook_name}\n"
|
||||
f"description: Test hook\n"
|
||||
f"events: {events}\n"
|
||||
)
|
||||
(hook_dir / "handler.py").write_text(handler_code)
|
||||
return hook_dir
|
||||
|
||||
|
||||
class TestHookRegistryInit:
    """A fresh HookRegistry starts with no hooks and no handlers."""

    def test_empty_registry(self):
        registry = HookRegistry()
        assert registry.loaded_hooks == []
        assert registry._handlers == {}
|
||||
|
||||
|
||||
class TestDiscoverAndLoad:
    """HookRegistry.discover_and_load(): which hook directories get loaded."""

    def test_loads_valid_hook(self, tmp_path):
        _create_hook(tmp_path, "my-hook", '["agent:start"]',
                     "def handle(event_type, context):\n pass\n")

        registry = HookRegistry()
        with patch("gateway.hooks.HOOKS_DIR", tmp_path):
            registry.discover_and_load()

        assert len(registry.loaded_hooks) == 1
        loaded = registry.loaded_hooks[0]
        assert loaded["name"] == "my-hook"
        assert "agent:start" in loaded["events"]

    def test_skips_missing_hook_yaml(self, tmp_path):
        # handler.py without a manifest: the directory is ignored.
        hook_dir = tmp_path / "bad-hook"
        hook_dir.mkdir()
        (hook_dir / "handler.py").write_text("def handle(e, c): pass\n")

        registry = HookRegistry()
        with patch("gateway.hooks.HOOKS_DIR", tmp_path):
            registry.discover_and_load()

        assert len(registry.loaded_hooks) == 0

    def test_skips_missing_handler_py(self, tmp_path):
        # Manifest without handler.py: also ignored.
        hook_dir = tmp_path / "bad-hook"
        hook_dir.mkdir()
        (hook_dir / "HOOK.yaml").write_text("name: bad\nevents: ['agent:start']\n")

        registry = HookRegistry()
        with patch("gateway.hooks.HOOKS_DIR", tmp_path):
            registry.discover_and_load()

        assert len(registry.loaded_hooks) == 0

    def test_skips_no_events(self, tmp_path):
        # A hook subscribed to nothing is not loaded.
        hook_dir = tmp_path / "empty-hook"
        hook_dir.mkdir()
        (hook_dir / "HOOK.yaml").write_text("name: empty\nevents: []\n")
        (hook_dir / "handler.py").write_text("def handle(e, c): pass\n")

        registry = HookRegistry()
        with patch("gateway.hooks.HOOKS_DIR", tmp_path):
            registry.discover_and_load()

        assert len(registry.loaded_hooks) == 0

    def test_skips_no_handle_function(self, tmp_path):
        # handler.py must expose a `handle` callable.
        hook_dir = tmp_path / "no-handle"
        hook_dir.mkdir()
        (hook_dir / "HOOK.yaml").write_text("name: no-handle\nevents: ['agent:start']\n")
        (hook_dir / "handler.py").write_text("def something_else(): pass\n")

        registry = HookRegistry()
        with patch("gateway.hooks.HOOKS_DIR", tmp_path):
            registry.discover_and_load()

        assert len(registry.loaded_hooks) == 0

    def test_nonexistent_hooks_dir(self, tmp_path):
        # A missing hooks directory is tolerated silently.
        registry = HookRegistry()
        with patch("gateway.hooks.HOOKS_DIR", tmp_path / "nonexistent"):
            registry.discover_and_load()

        assert len(registry.loaded_hooks) == 0

    def test_multiple_hooks(self, tmp_path):
        _create_hook(tmp_path, "hook-a", '["agent:start"]',
                     "def handle(e, c): pass\n")
        _create_hook(tmp_path, "hook-b", '["session:start", "session:reset"]',
                     "def handle(e, c): pass\n")

        registry = HookRegistry()
        with patch("gateway.hooks.HOOKS_DIR", tmp_path):
            registry.discover_and_load()

        assert len(registry.loaded_hooks) == 2
|
||||
|
||||
|
||||
class TestEmit:
    """HookRegistry.emit(): sync/async dispatch, wildcards, and error isolation."""

    @pytest.mark.asyncio
    async def test_emit_calls_sync_handler(self, tmp_path):
        results = []

        _create_hook(tmp_path, "sync-hook", '["agent:start"]',
                     "results = []\n"
                     "def handle(event_type, context):\n"
                     " results.append(event_type)\n")

        reg = HookRegistry()
        with patch("gateway.hooks.HOOKS_DIR", tmp_path):
            reg.discover_and_load()

        # Inject our results list into the handler's module globals so the
        # dynamically loaded handler appends into a list this test can observe.
        handler_fn = reg._handlers["agent:start"][0]
        handler_fn.__globals__["results"] = results

        await reg.emit("agent:start", {"test": True})
        assert "agent:start" in results

    @pytest.mark.asyncio
    async def test_emit_calls_async_handler(self, tmp_path):
        results = []

        hook_dir = tmp_path / "async-hook"
        hook_dir.mkdir()
        (hook_dir / "HOOK.yaml").write_text(
            "name: async-hook\nevents: ['agent:end']\n"
        )
        (hook_dir / "handler.py").write_text(
            "import asyncio\n"
            "results = []\n"
            "async def handle(event_type, context):\n"
            " results.append(event_type)\n"
        )

        reg = HookRegistry()
        with patch("gateway.hooks.HOOKS_DIR", tmp_path):
            reg.discover_and_load()

        # Same globals-injection trick as the sync case.
        handler_fn = reg._handlers["agent:end"][0]
        handler_fn.__globals__["results"] = results

        await reg.emit("agent:end", {})
        assert "agent:end" in results

    @pytest.mark.asyncio
    async def test_wildcard_matching(self, tmp_path):
        results = []

        _create_hook(tmp_path, "wildcard-hook", '["command:*"]',
                     "results = []\n"
                     "def handle(event_type, context):\n"
                     " results.append(event_type)\n")

        reg = HookRegistry()
        with patch("gateway.hooks.HOOKS_DIR", tmp_path):
            reg.discover_and_load()

        handler_fn = reg._handlers["command:*"][0]
        handler_fn.__globals__["results"] = results

        # "command:*" must match any "command:..." event.
        await reg.emit("command:reset", {})
        assert "command:reset" in results

    @pytest.mark.asyncio
    async def test_no_handlers_for_event(self):
        # Fixed: dropped the `tmp_path` fixture parameter — it was requested
        # but never used by this test.
        reg = HookRegistry()
        # Should not raise
        await reg.emit("unknown:event", {})

    @pytest.mark.asyncio
    async def test_handler_error_does_not_propagate(self, tmp_path):
        _create_hook(tmp_path, "bad-hook", '["agent:start"]',
                     "def handle(event_type, context):\n"
                     " raise ValueError('boom')\n")

        reg = HookRegistry()
        with patch("gateway.hooks.HOOKS_DIR", tmp_path):
            reg.discover_and_load()

        # Should not raise even though handler throws
        await reg.emit("agent:start", {})

    @pytest.mark.asyncio
    async def test_emit_default_context(self, tmp_path):
        captured = []

        _create_hook(tmp_path, "ctx-hook", '["agent:start"]',
                     "captured = []\n"
                     "def handle(event_type, context):\n"
                     " captured.append(context)\n")

        reg = HookRegistry()
        with patch("gateway.hooks.HOOKS_DIR", tmp_path):
            reg.discover_and_load()

        handler_fn = reg._handlers["agent:start"][0]
        handler_fn.__globals__["captured"] = captured

        # Omitting the context argument defaults to an empty dict.
        await reg.emit("agent:start")  # no context arg
        assert captured[0] == {}
|
||||
162
tests/gateway/test_mirror.py
Normal file
162
tests/gateway/test_mirror.py
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
"""Tests for gateway/mirror.py — session mirroring."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import gateway.mirror as mirror_mod
|
||||
from gateway.mirror import (
|
||||
mirror_to_session,
|
||||
_find_session_id,
|
||||
_append_to_jsonl,
|
||||
)
|
||||
|
||||
|
||||
def _setup_sessions(tmp_path, sessions_data):
|
||||
"""Helper to write a fake sessions.json and patch module-level paths."""
|
||||
sessions_dir = tmp_path / "sessions"
|
||||
sessions_dir.mkdir(parents=True, exist_ok=True)
|
||||
index_file = sessions_dir / "sessions.json"
|
||||
index_file.write_text(json.dumps(sessions_data))
|
||||
return sessions_dir, index_file
|
||||
|
||||
|
||||
class TestFindSessionId:
    """_find_session_id(): map (platform, chat_id) to the most recent session."""

    def test_finds_matching_session(self, tmp_path):
        sessions_dir, index_file = _setup_sessions(tmp_path, {
            "agent:main:telegram:dm": {
                "session_id": "sess_abc",
                "origin": {"platform": "telegram", "chat_id": "12345"},
                "updated_at": "2026-01-01T00:00:00",
            }
        })

        with patch.object(mirror_mod, "_SESSIONS_DIR", sessions_dir), \
             patch.object(mirror_mod, "_SESSIONS_INDEX", index_file):
            found = _find_session_id("telegram", "12345")

        assert found == "sess_abc"

    def test_returns_most_recent(self, tmp_path):
        # Two sessions for the same chat: the newer updated_at wins.
        sessions_dir, index_file = _setup_sessions(tmp_path, {
            "old": {
                "session_id": "sess_old",
                "origin": {"platform": "telegram", "chat_id": "12345"},
                "updated_at": "2026-01-01T00:00:00",
            },
            "new": {
                "session_id": "sess_new",
                "origin": {"platform": "telegram", "chat_id": "12345"},
                "updated_at": "2026-02-01T00:00:00",
            },
        })

        with patch.object(mirror_mod, "_SESSIONS_DIR", sessions_dir), \
             patch.object(mirror_mod, "_SESSIONS_INDEX", index_file):
            found = _find_session_id("telegram", "12345")

        assert found == "sess_new"

    def test_no_match_returns_none(self, tmp_path):
        sessions_dir, index_file = _setup_sessions(tmp_path, {
            "sess": {
                "session_id": "sess_1",
                "origin": {"platform": "discord", "chat_id": "999"},
                "updated_at": "2026-01-01T00:00:00",
            }
        })

        with patch.object(mirror_mod, "_SESSIONS_INDEX", index_file):
            found = _find_session_id("telegram", "12345")

        assert found is None

    def test_missing_sessions_file(self, tmp_path):
        # A missing index file resolves to None, not an exception.
        with patch.object(mirror_mod, "_SESSIONS_INDEX", tmp_path / "nope.json"):
            found = _find_session_id("telegram", "12345")

        assert found is None

    def test_platform_case_insensitive(self, tmp_path):
        # Stored platform "Telegram" still matches query "telegram".
        sessions_dir, index_file = _setup_sessions(tmp_path, {
            "s1": {
                "session_id": "sess_1",
                "origin": {"platform": "Telegram", "chat_id": "123"},
                "updated_at": "2026-01-01T00:00:00",
            }
        })

        with patch.object(mirror_mod, "_SESSIONS_INDEX", index_file):
            found = _find_session_id("telegram", "123")

        assert found == "sess_1"
|
||||
|
||||
|
||||
class TestAppendToJsonl:
    """_append_to_jsonl(): one JSON object per line in <session_id>.jsonl."""

    def test_appends_message(self, tmp_path):
        sessions_dir = tmp_path / "sessions"
        sessions_dir.mkdir()

        with patch.object(mirror_mod, "_SESSIONS_DIR", sessions_dir):
            _append_to_jsonl("sess_1", {"role": "assistant", "content": "Hello"})

        transcript = sessions_dir / "sess_1.jsonl"
        recorded = transcript.read_text().strip().splitlines()
        assert len(recorded) == 1
        message = json.loads(recorded[0])
        assert message["role"] == "assistant"
        assert message["content"] == "Hello"

    def test_appends_multiple_messages(self, tmp_path):
        sessions_dir = tmp_path / "sessions"
        sessions_dir.mkdir()

        # Successive appends accumulate lines rather than overwriting.
        with patch.object(mirror_mod, "_SESSIONS_DIR", sessions_dir):
            _append_to_jsonl("sess_1", {"role": "assistant", "content": "msg1"})
            _append_to_jsonl("sess_1", {"role": "assistant", "content": "msg2"})

        transcript = sessions_dir / "sess_1.jsonl"
        recorded = transcript.read_text().strip().splitlines()
        assert len(recorded) == 2
|
||||
|
||||
|
||||
class TestMirrorToSession:
    """mirror_to_session(): end-to-end mirroring into a matched session."""

    def test_successful_mirror(self, tmp_path):
        sessions_dir, index_file = _setup_sessions(tmp_path, {
            "s1": {
                "session_id": "sess_abc",
                "origin": {"platform": "telegram", "chat_id": "12345"},
                "updated_at": "2026-01-01T00:00:00",
            }
        })

        with patch.object(mirror_mod, "_SESSIONS_DIR", sessions_dir), \
             patch.object(mirror_mod, "_SESSIONS_INDEX", index_file), \
             patch("gateway.mirror._append_to_sqlite"):
            ok = mirror_to_session("telegram", "12345", "Hello!", source_label="cli")

        assert ok is True

        # Check JSONL was written
        transcript = sessions_dir / "sess_abc.jsonl"
        assert transcript.exists()
        message = json.loads(transcript.read_text().strip())
        assert message["content"] == "Hello!"
        assert message["role"] == "assistant"
        assert message["mirror"] is True
        assert message["mirror_source"] == "cli"

    def test_no_matching_session(self, tmp_path):
        sessions_dir, index_file = _setup_sessions(tmp_path, {})

        with patch.object(mirror_mod, "_SESSIONS_DIR", sessions_dir), \
             patch.object(mirror_mod, "_SESSIONS_INDEX", index_file):
            ok = mirror_to_session("telegram", "99999", "Hello!")

        assert ok is False

    def test_error_returns_false(self, tmp_path):
        # Any internal failure is swallowed and reported as False.
        with patch("gateway.mirror._find_session_id", side_effect=Exception("boom")):
            ok = mirror_to_session("telegram", "123", "msg")

        assert ok is False
|
||||
|
|
@ -1,9 +1,13 @@
|
|||
"""Tests for gateway session management."""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock
|
||||
from gateway.config import Platform, HomeChannel, GatewayConfig, PlatformConfig
|
||||
from gateway.session import (
|
||||
SessionSource,
|
||||
SessionStore,
|
||||
build_session_context,
|
||||
build_session_context_prompt,
|
||||
)
|
||||
|
|
@ -31,6 +35,24 @@ class TestSessionSourceRoundtrip:
|
|||
assert restored.user_name == "alice"
|
||||
assert restored.thread_id == "t1"
|
||||
|
||||
def test_full_roundtrip_with_chat_topic(self):
|
||||
"""chat_topic should survive to_dict/from_dict roundtrip."""
|
||||
source = SessionSource(
|
||||
platform=Platform.DISCORD,
|
||||
chat_id="789",
|
||||
chat_name="Server / #project-planning",
|
||||
chat_type="group",
|
||||
user_id="42",
|
||||
user_name="bob",
|
||||
chat_topic="Planning and coordination for Project X",
|
||||
)
|
||||
d = source.to_dict()
|
||||
assert d["chat_topic"] == "Planning and coordination for Project X"
|
||||
|
||||
restored = SessionSource.from_dict(d)
|
||||
assert restored.chat_topic == "Planning and coordination for Project X"
|
||||
assert restored.chat_name == "Server / #project-planning"
|
||||
|
||||
def test_minimal_roundtrip(self):
|
||||
source = SessionSource(platform=Platform.LOCAL, chat_id="cli")
|
||||
d = source.to_dict()
|
||||
|
|
@ -57,6 +79,7 @@ class TestSessionSourceRoundtrip:
|
|||
assert restored.user_id is None
|
||||
assert restored.user_name is None
|
||||
assert restored.thread_id is None
|
||||
assert restored.chat_topic is None
|
||||
assert restored.chat_type == "dm"
|
||||
|
||||
def test_invalid_platform_raises(self):
|
||||
|
|
@ -174,6 +197,52 @@ class TestBuildSessionContextPrompt:
|
|||
|
||||
assert "Discord" in prompt
|
||||
|
||||
def test_discord_prompt_with_channel_topic(self):
|
||||
"""Channel topic should appear in the session context prompt."""
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.DISCORD: PlatformConfig(
|
||||
enabled=True,
|
||||
token="fake-discord-token",
|
||||
),
|
||||
},
|
||||
)
|
||||
source = SessionSource(
|
||||
platform=Platform.DISCORD,
|
||||
chat_id="guild-123",
|
||||
chat_name="Server / #project-planning",
|
||||
chat_type="group",
|
||||
user_name="alice",
|
||||
chat_topic="Planning and coordination for Project X",
|
||||
)
|
||||
ctx = build_session_context(source, config)
|
||||
prompt = build_session_context_prompt(ctx)
|
||||
|
||||
assert "Discord" in prompt
|
||||
assert "**Channel Topic:** Planning and coordination for Project X" in prompt
|
||||
|
||||
def test_prompt_omits_channel_topic_when_none(self):
|
||||
"""Channel Topic line should NOT appear when chat_topic is None."""
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.DISCORD: PlatformConfig(
|
||||
enabled=True,
|
||||
token="fake-discord-token",
|
||||
),
|
||||
},
|
||||
)
|
||||
source = SessionSource(
|
||||
platform=Platform.DISCORD,
|
||||
chat_id="guild-123",
|
||||
chat_name="Server / #general",
|
||||
chat_type="group",
|
||||
user_name="alice",
|
||||
)
|
||||
ctx = build_session_context(source, config)
|
||||
prompt = build_session_context_prompt(ctx)
|
||||
|
||||
assert "Channel Topic" not in prompt
|
||||
|
||||
def test_local_prompt_mentions_machine(self):
|
||||
config = GatewayConfig()
|
||||
source = SessionSource.local_cli()
|
||||
|
|
@ -199,3 +268,59 @@ class TestBuildSessionContextPrompt:
|
|||
prompt = build_session_context_prompt(ctx)
|
||||
|
||||
assert "WhatsApp" in prompt or "whatsapp" in prompt.lower()
|
||||
|
||||
|
||||
class TestSessionStoreRewriteTranscript:
    """Regression: /retry and /undo must persist truncated history to disk."""

    @pytest.fixture()
    def store(self, tmp_path):
        config = GatewayConfig()
        with patch("gateway.session.SessionStore._ensure_loaded"):
            s = SessionStore(sessions_dir=tmp_path, config=config)
        s._db = None  # no SQLite for these tests
        s._loaded = True
        return s

    def test_rewrite_replaces_jsonl(self, store, tmp_path):
        session_id = "test_session_1"
        # Seed a four-message transcript on disk.
        history = [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "hi"},
            {"role": "user", "content": "undo this"},
            {"role": "assistant", "content": "ok"},
        ]
        for message in history:
            store.append_to_transcript(session_id, message)

        # Rewrite with only the first exchange kept.
        store.rewrite_transcript(session_id, history[:2])

        reloaded = store.load_transcript(session_id)
        assert len(reloaded) == 2
        assert reloaded[0]["content"] == "hello"
        assert reloaded[1]["content"] == "hi"

    def test_rewrite_with_empty_list(self, store):
        session_id = "test_session_2"
        store.append_to_transcript(session_id, {"role": "user", "content": "hi"})

        # Rewriting to [] must clear the transcript entirely.
        store.rewrite_transcript(session_id, [])

        assert store.load_transcript(session_id) == []
|
||||
|
||||
|
||||
class TestSessionStoreEntriesAttribute:
    """Regression: /reset must access _entries, not _sessions."""

    def test_entries_attribute_exists(self):
        config = GatewayConfig()
        with patch("gateway.session.SessionStore._ensure_loaded"):
            store = SessionStore(sessions_dir=Path("/tmp"), config=config)
        store._loaded = True
        # The store exposes _entries; the old _sessions name must be gone.
        assert hasattr(store, "_entries")
        assert not hasattr(store, "_sessions")
|
||||
|
|
|
|||
127
tests/gateway/test_sticker_cache.py
Normal file
127
tests/gateway/test_sticker_cache.py
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
"""Tests for gateway/sticker_cache.py — sticker description cache."""
|
||||
|
||||
import json
|
||||
import time
|
||||
from unittest.mock import patch
|
||||
|
||||
from gateway.sticker_cache import (
|
||||
_load_cache,
|
||||
_save_cache,
|
||||
get_cached_description,
|
||||
cache_sticker_description,
|
||||
build_sticker_injection,
|
||||
build_animated_sticker_injection,
|
||||
STICKER_VISION_PROMPT,
|
||||
)
|
||||
|
||||
|
||||
class TestLoadSaveCache:
    """_load_cache()/_save_cache(): persistence round-trips and failure modes."""

    def test_load_missing_file(self, tmp_path):
        with patch("gateway.sticker_cache.CACHE_PATH", tmp_path / "nope.json"):
            assert _load_cache() == {}

    def test_load_corrupt_file(self, tmp_path):
        # Garbage on disk degrades to an empty cache rather than raising.
        broken = tmp_path / "bad.json"
        broken.write_text("not json{{{")
        with patch("gateway.sticker_cache.CACHE_PATH", broken):
            assert _load_cache() == {}

    def test_save_and_load_roundtrip(self, tmp_path):
        cache_file = tmp_path / "cache.json"
        entry = {"abc123": {"description": "A cat", "emoji": "", "set_name": "", "cached_at": 1.0}}
        with patch("gateway.sticker_cache.CACHE_PATH", cache_file):
            _save_cache(entry)
            assert _load_cache() == entry

    def test_save_creates_parent_dirs(self, tmp_path):
        # Saving must create any missing intermediate directories.
        cache_file = tmp_path / "sub" / "dir" / "cache.json"
        with patch("gateway.sticker_cache.CACHE_PATH", cache_file):
            _save_cache({"key": "value"})
        assert cache_file.exists()
|
||||
|
||||
|
||||
class TestCacheSticker:
    """cache_sticker_description()/get_cached_description(): CRUD behaviour."""

    def test_cache_and_retrieve(self, tmp_path):
        cache_file = tmp_path / "cache.json"
        with patch("gateway.sticker_cache.CACHE_PATH", cache_file):
            cache_sticker_description("uid_1", "A happy dog", emoji="🐕", set_name="Dogs")
            entry = get_cached_description("uid_1")

        assert entry is not None
        assert entry["description"] == "A happy dog"
        assert entry["emoji"] == "🐕"
        assert entry["set_name"] == "Dogs"
        assert "cached_at" in entry

    def test_missing_sticker_returns_none(self, tmp_path):
        cache_file = tmp_path / "cache.json"
        with patch("gateway.sticker_cache.CACHE_PATH", cache_file):
            assert get_cached_description("nonexistent") is None

    def test_overwrite_existing(self, tmp_path):
        # A second write for the same uid replaces the first.
        cache_file = tmp_path / "cache.json"
        with patch("gateway.sticker_cache.CACHE_PATH", cache_file):
            cache_sticker_description("uid_1", "Old description")
            cache_sticker_description("uid_1", "New description")
            entry = get_cached_description("uid_1")

        assert entry["description"] == "New description"

    def test_multiple_stickers(self, tmp_path):
        cache_file = tmp_path / "cache.json"
        with patch("gateway.sticker_cache.CACHE_PATH", cache_file):
            cache_sticker_description("uid_1", "Cat")
            cache_sticker_description("uid_2", "Dog")
            first = get_cached_description("uid_1")
            second = get_cached_description("uid_2")

        assert first["description"] == "Cat"
        assert second["description"] == "Dog"
|
||||
|
||||
|
||||
class TestBuildStickerInjection:
    """build_sticker_injection(): exact injected-text formats."""

    def test_exact_format_no_context(self):
        assert build_sticker_injection("A cat waving") == \
            '[The user sent a sticker~ It shows: "A cat waving" (=^.w.^=)]'

    def test_exact_format_emoji_only(self):
        assert build_sticker_injection("A cat", emoji="😀") == \
            '[The user sent a sticker 😀~ It shows: "A cat" (=^.w.^=)]'

    def test_exact_format_emoji_and_set_name(self):
        assert build_sticker_injection("A cat", emoji="😀", set_name="MyPack") == \
            '[The user sent a sticker 😀 from "MyPack"~ It shows: "A cat" (=^.w.^=)]'

    def test_set_name_without_emoji_ignored(self):
        """set_name alone (no emoji) produces no context — only emoji+set_name triggers 'from' clause."""
        rendered = build_sticker_injection("A cat", set_name="MyPack")
        assert rendered == '[The user sent a sticker~ It shows: "A cat" (=^.w.^=)]'
        assert "MyPack" not in rendered

    def test_description_with_quotes(self):
        # The description is inserted verbatim — no quote escaping occurs.
        rendered = build_sticker_injection('A "happy" dog')
        assert '"A \\"happy\\" dog"' not in rendered  # no escaping happens
        assert 'A "happy" dog' in rendered

    def test_empty_description(self):
        assert build_sticker_injection("") == \
            '[The user sent a sticker~ It shows: "" (=^.w.^=)]'
|
||||
|
||||
|
||||
class TestBuildAnimatedStickerInjection:
    """build_animated_sticker_injection(): fixed fallback texts."""

    def test_exact_format_with_emoji(self):
        expected = (
            "[The user sent an animated sticker 🎉~ "
            "I can't see animated ones yet, but the emoji suggests: 🎉]"
        )
        assert build_animated_sticker_injection(emoji="🎉") == expected

    def test_exact_format_without_emoji(self):
        expected = "[The user sent an animated sticker~ I can't see animated ones yet]"
        assert build_animated_sticker_injection() == expected

    def test_empty_emoji_same_as_no_emoji(self):
        # An empty-string emoji behaves like omitting the argument.
        assert build_animated_sticker_injection(emoji="") == build_animated_sticker_injection()
|
||||
Loading…
Add table
Add a link
Reference in a new issue