merge: resolve conflicts with origin/main

This commit is contained in:
teknium1 2026-03-17 04:30:37 -07:00
commit 0897e4350e
100 changed files with 11637 additions and 1337 deletions

View file

@ -113,11 +113,13 @@ class TestDefaultContextLengths:
def test_gpt4_models_128k_or_1m(self):
    """Check the default context length of every ``gpt-4*`` entry.

    ``gpt-4.1`` variants are expected to be 1047576 tokens (1M); every
    other ``gpt-4*`` key is expected to be 128000.
    """
    for key, value in DEFAULT_CONTEXT_LENGTHS.items():
        if "gpt-4" not in key:
            continue
        if "gpt-4.1" in key:
            assert value == 1047576, f"{key} should be 1047576 (1M)"
        else:
            assert value == 128000, f"{key} should be 128000"
def test_gpt41_models_1m(self):
    """Every ``gpt-4.1*`` key must default to 1047576 tokens."""
    gpt41_entries = {
        name: length
        for name, length in DEFAULT_CONTEXT_LENGTHS.items()
        if "gpt-4.1" in name
    }
    for key, value in gpt41_entries.items():
        assert value == 1047576, f"{key} should be 1047576"
def test_gemini_models_1m(self):
for key, value in DEFAULT_CONTEXT_LENGTHS.items():

View file

@ -0,0 +1,101 @@
from types import SimpleNamespace
from agent.usage_pricing import (
CanonicalUsage,
estimate_usage_cost,
get_pricing_entry,
normalize_usage,
)
def test_normalize_usage_anthropic_keeps_cache_buckets_separate():
    """Anthropic usage is normalized with cache read/write kept in their own fields."""
    raw_usage = SimpleNamespace(
        input_tokens=1000,
        output_tokens=500,
        cache_read_input_tokens=2000,
        cache_creation_input_tokens=400,
    )

    result = normalize_usage(
        raw_usage, provider="anthropic", api_mode="anthropic_messages"
    )

    expected_fields = {
        "input_tokens": 1000,
        "output_tokens": 500,
        "cache_read_tokens": 2000,
        "cache_write_tokens": 400,
        # 1000 input + 2000 cache read + 400 cache write
        "prompt_tokens": 3400,
    }
    for field, expected in expected_fields.items():
        assert getattr(result, field) == expected
def test_normalize_usage_openai_subtracts_cached_prompt_tokens():
    """OpenAI prompt tokens are split into uncached input and cache-read tokens."""
    cached_details = SimpleNamespace(cached_tokens=1800)
    raw_usage = SimpleNamespace(
        prompt_tokens=3000,
        completion_tokens=700,
        prompt_tokens_details=cached_details,
    )

    result = normalize_usage(raw_usage, provider="openai", api_mode="chat_completions")

    # 3000 prompt tokens minus the 1800 cached ones
    assert result.input_tokens == 1200
    assert result.cache_read_tokens == 1800
    assert result.output_tokens == 700
def test_openrouter_models_api_pricing_is_converted_from_per_token_to_per_million(monkeypatch):
    """Per-token prices from the stubbed metadata come back as per-million rates."""
    model_id = "anthropic/claude-opus-4.6"

    def fake_metadata():
        return {
            model_id: {
                "pricing": {
                    "prompt": "0.000005",
                    "completion": "0.000025",
                    "input_cache_read": "0.0000005",
                    "input_cache_write": "0.00000625",
                }
            }
        }

    monkeypatch.setattr("agent.usage_pricing.fetch_model_metadata", fake_metadata)

    entry = get_pricing_entry(
        model_id,
        provider="openrouter",
        base_url="https://openrouter.ai/api/v1",
    )

    assert float(entry.input_cost_per_million) == 5.0
    assert float(entry.output_cost_per_million) == 25.0
    assert float(entry.cache_read_cost_per_million) == 0.5
    assert float(entry.cache_write_cost_per_million) == 6.25
def test_estimate_usage_cost_marks_subscription_routes_included():
    """The openai-codex route is reported as "included" with a zero USD amount."""
    usage = CanonicalUsage(input_tokens=1000, output_tokens=500)

    estimate = estimate_usage_cost(
        "gpt-5.3-codex",
        usage,
        provider="openai-codex",
        base_url="https://chatgpt.com/backend-api/codex",
    )

    assert estimate.status == "included"
    assert float(estimate.amount_usd) == 0.0
def test_estimate_usage_cost_refuses_cache_pricing_without_official_cache_rate(monkeypatch):
    """Cache-read usage with no cache rate in the metadata yields status "unknown"."""
    model_id = "google/gemini-2.5-pro"

    def fake_metadata():
        # Only prompt/completion prices are provided; no cache rates.
        return {
            model_id: {
                "pricing": {
                    "prompt": "0.00000125",
                    "completion": "0.00001",
                }
            }
        }

    monkeypatch.setattr("agent.usage_pricing.fetch_model_metadata", fake_metadata)

    usage = CanonicalUsage(input_tokens=1000, output_tokens=500, cache_read_tokens=100)
    estimate = estimate_usage_cost(
        model_id,
        usage,
        provider="openrouter",
        base_url="https://openrouter.ai/api/v1",
    )

    assert estimate.status == "unknown"

View file

@ -50,13 +50,16 @@ def _build_runner(monkeypatch, tmp_path, mode: str) -> GatewayRunner:
return runner
def _watcher_dict(session_id="proc_test"):
return {
def _watcher_dict(session_id="proc_test", thread_id=""):
d = {
"session_id": session_id,
"check_interval": 0,
"platform": "telegram",
"chat_id": "123",
}
if thread_id:
d["thread_id"] = thread_id
return d
# ---------------------------------------------------------------------------
@ -196,3 +199,47 @@ async def test_run_process_watcher_respects_notification_mode(
if expected_fragment is not None:
sent_message = adapter.send.await_args.args[1]
assert expected_fragment in sent_message
@pytest.mark.asyncio
async def test_thread_id_passed_to_send(monkeypatch, tmp_path):
    """A watcher dict carrying thread_id results in send() metadata containing it."""
    import tools.process_registry as pr_module

    finished_session = SimpleNamespace(output_buffer="done\n", exited=True, exit_code=0)
    monkeypatch.setattr(
        pr_module, "process_registry", _FakeRegistry([finished_session])
    )

    async def _instant_sleep(*_a, **_kw):
        # Replaces asyncio.sleep so the watcher loop does not actually wait.
        pass

    monkeypatch.setattr(asyncio, "sleep", _instant_sleep)

    runner = _build_runner(monkeypatch, tmp_path, "all")
    telegram_adapter = runner.adapters[Platform.TELEGRAM]

    await runner._run_process_watcher(_watcher_dict(thread_id="42"))

    assert telegram_adapter.send.await_count == 1
    send_kwargs = telegram_adapter.send.call_args.kwargs
    assert send_kwargs["metadata"] == {"thread_id": "42"}
@pytest.mark.asyncio
async def test_no_thread_id_sends_no_metadata(monkeypatch, tmp_path):
    """Without a thread_id in the watcher dict, send() receives metadata=None."""
    import tools.process_registry as pr_module

    finished_session = SimpleNamespace(output_buffer="done\n", exited=True, exit_code=0)
    monkeypatch.setattr(
        pr_module, "process_registry", _FakeRegistry([finished_session])
    )

    async def _instant_sleep(*_a, **_kw):
        # Replaces asyncio.sleep so the watcher loop does not actually wait.
        pass

    monkeypatch.setattr(asyncio, "sleep", _instant_sleep)

    runner = _build_runner(monkeypatch, tmp_path, "all")
    telegram_adapter = runner.adapters[Platform.TELEGRAM]

    await runner._run_process_watcher(_watcher_dict())

    assert telegram_adapter.send.await_count == 1
    send_kwargs = telegram_adapter.send.call_args.kwargs
    assert send_kwargs["metadata"] is None

View file

@ -0,0 +1,274 @@
"""Tests for DingTalk platform adapter."""
import asyncio
import json
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch, PropertyMock
import pytest
from gateway.config import Platform, PlatformConfig
# ---------------------------------------------------------------------------
# Requirements check
# ---------------------------------------------------------------------------
class TestDingTalkRequirements:
    """Exercise ``check_dingtalk_requirements`` under different flag/env combinations."""

    # Dotted targets patched by the tests below.
    _SDK_FLAG = "gateway.platforms.dingtalk.DINGTALK_STREAM_AVAILABLE"
    _HTTPX_FLAG = "gateway.platforms.dingtalk.HTTPX_AVAILABLE"

    def test_returns_false_when_sdk_missing(self, monkeypatch):
        """With the SDK flag forced to False the check reports False."""
        hidden_sdk = {"dingtalk_stream": None}
        with patch.dict("sys.modules", hidden_sdk):
            monkeypatch.setattr(self._SDK_FLAG, False)
            from gateway.platforms.dingtalk import check_dingtalk_requirements

            assert check_dingtalk_requirements() is False

    def test_returns_false_when_env_vars_missing(self, monkeypatch):
        """Both flags on but no credential env vars: the check reports False."""
        monkeypatch.setattr(self._SDK_FLAG, True)
        monkeypatch.setattr(self._HTTPX_FLAG, True)
        for env_name in ("DINGTALK_CLIENT_ID", "DINGTALK_CLIENT_SECRET"):
            monkeypatch.delenv(env_name, raising=False)
        from gateway.platforms.dingtalk import check_dingtalk_requirements

        assert check_dingtalk_requirements() is False

    def test_returns_true_when_all_available(self, monkeypatch):
        """Both flags on and both credential env vars set: the check reports True."""
        monkeypatch.setattr(self._SDK_FLAG, True)
        monkeypatch.setattr(self._HTTPX_FLAG, True)
        monkeypatch.setenv("DINGTALK_CLIENT_ID", "test-id")
        monkeypatch.setenv("DINGTALK_CLIENT_SECRET", "test-secret")
        from gateway.platforms.dingtalk import check_dingtalk_requirements

        assert check_dingtalk_requirements() is True
# ---------------------------------------------------------------------------
# Adapter construction
# ---------------------------------------------------------------------------
class TestDingTalkAdapterInit:
    """Check where ``DingTalkAdapter`` takes its client credentials from."""

    def test_reads_config_from_extra(self):
        """Credentials given in ``PlatformConfig.extra`` end up on the adapter."""
        from gateway.platforms.dingtalk import DingTalkAdapter

        credentials = {"client_id": "cfg-id", "client_secret": "cfg-secret"}
        adapter = DingTalkAdapter(PlatformConfig(enabled=True, extra=credentials))

        assert adapter._client_id == "cfg-id"
        assert adapter._client_secret == "cfg-secret"
        assert adapter.name == "Dingtalk"  # base class uses .title()

    def test_falls_back_to_env_vars(self, monkeypatch):
        """With no ``extra`` supplied, the DINGTALK_* env vars are used."""
        monkeypatch.setenv("DINGTALK_CLIENT_ID", "env-id")
        monkeypatch.setenv("DINGTALK_CLIENT_SECRET", "env-secret")
        from gateway.platforms.dingtalk import DingTalkAdapter

        adapter = DingTalkAdapter(PlatformConfig(enabled=True))

        assert adapter._client_id == "env-id"
        assert adapter._client_secret == "env-secret"
# ---------------------------------------------------------------------------
# Message text extraction
# ---------------------------------------------------------------------------
class TestExtractText:
    """Cover ``DingTalkAdapter._extract_text`` for the message shapes used below."""

    @staticmethod
    def _message(text, rich_text):
        """Return a MagicMock message with ``text`` and ``rich_text`` set."""
        message = MagicMock()
        message.text = text
        message.rich_text = rich_text
        return message

    def test_extracts_dict_text(self):
        """A dict ``text`` yields its ``content`` value with outer whitespace removed."""
        from gateway.platforms.dingtalk import DingTalkAdapter

        message = self._message({"content": " hello world "}, None)
        assert DingTalkAdapter._extract_text(message) == "hello world"

    def test_extracts_string_text(self):
        """A plain string ``text`` is returned as-is."""
        from gateway.platforms.dingtalk import DingTalkAdapter

        message = self._message("plain text", None)
        assert DingTalkAdapter._extract_text(message) == "plain text"

    def test_falls_back_to_rich_text(self):
        """Empty ``text`` falls back to the space-joined ``text`` parts of ``rich_text``."""
        from gateway.platforms.dingtalk import DingTalkAdapter

        parts = [{"text": "part1"}, {"text": "part2"}, {"image": "url"}]
        message = self._message("", parts)
        assert DingTalkAdapter._extract_text(message) == "part1 part2"

    def test_returns_empty_for_no_content(self):
        """Empty ``text`` and no ``rich_text`` gives an empty string."""
        from gateway.platforms.dingtalk import DingTalkAdapter

        message = self._message("", None)
        assert DingTalkAdapter._extract_text(message) == ""
# ---------------------------------------------------------------------------
# Deduplication
# ---------------------------------------------------------------------------
class TestDeduplication:
    """Cover ``DingTalkAdapter._is_duplicate`` and its ``_seen_messages`` cache."""

    @staticmethod
    def _fresh_adapter():
        """Return a new adapter built from a minimal enabled config."""
        from gateway.platforms.dingtalk import DingTalkAdapter

        return DingTalkAdapter(PlatformConfig(enabled=True))

    def test_first_message_not_duplicate(self):
        """An id seen for the first time is not a duplicate."""
        adapter = self._fresh_adapter()
        assert adapter._is_duplicate("msg-1") is False

    def test_second_same_message_is_duplicate(self):
        """The same id submitted twice is a duplicate the second time."""
        adapter = self._fresh_adapter()
        adapter._is_duplicate("msg-1")
        assert adapter._is_duplicate("msg-1") is True

    def test_different_messages_not_duplicate(self):
        """A different id is not a duplicate of an earlier one."""
        adapter = self._fresh_adapter()
        adapter._is_duplicate("msg-1")
        assert adapter._is_duplicate("msg-2") is False

    def test_cache_cleanup_on_overflow(self):
        """Feed more than DEDUP_MAX_SIZE ids and bound the cache size."""
        from gateway.platforms.dingtalk import DingTalkAdapter, DEDUP_MAX_SIZE

        adapter = DingTalkAdapter(PlatformConfig(enabled=True))
        total_ids = DEDUP_MAX_SIZE + 10
        # Fill beyond max
        for index in range(total_ids):
            adapter._is_duplicate(f"msg-{index}")
        # NOTE(review): exactly DEDUP_MAX_SIZE + 10 distinct ids were inserted,
        # so this bound also holds when nothing is pruned. Tighten it once the
        # adapter's pruning rule is confirmed.
        assert len(adapter._seen_messages) <= DEDUP_MAX_SIZE + 10
# ---------------------------------------------------------------------------
# Send
# ---------------------------------------------------------------------------
class TestSend:
    """Cover ``DingTalkAdapter.send`` against a mocked HTTP client."""

    @staticmethod
    def _adapter_with_response(status_code, text=None):
        """Return ``(adapter, client)`` where ``client.post`` yields a canned response.

        ``text`` is only set on the response mock when it is provided.
        """
        from gateway.platforms.dingtalk import DingTalkAdapter

        adapter = DingTalkAdapter(PlatformConfig(enabled=True))
        response = MagicMock()
        response.status_code = status_code
        if text is not None:
            response.text = text
        client = AsyncMock()
        client.post = AsyncMock(return_value=response)
        adapter._http_client = client
        return adapter, client

    @pytest.mark.asyncio
    async def test_send_posts_to_webhook(self):
        """send() posts a markdown payload to the webhook given in metadata."""
        adapter, client = self._adapter_with_response(200, "OK")

        result = await adapter.send(
            "chat-123",
            "Hello!",
            metadata={"session_webhook": "https://dingtalk.example/webhook"},
        )

        assert result.success is True
        client.post.assert_called_once()
        post_call = client.post.call_args
        assert post_call.args[0] == "https://dingtalk.example/webhook"
        payload = post_call.kwargs["json"]
        assert payload["msgtype"] == "markdown"
        assert payload["markdown"]["title"] == "Hermes"
        assert payload["markdown"]["text"] == "Hello!"

    @pytest.mark.asyncio
    async def test_send_fails_without_webhook(self):
        """With no webhook in metadata or cache, send() fails and names session_webhook."""
        from gateway.platforms.dingtalk import DingTalkAdapter

        adapter = DingTalkAdapter(PlatformConfig(enabled=True))
        adapter._http_client = AsyncMock()

        result = await adapter.send("chat-123", "Hello!")

        assert result.success is False
        assert "session_webhook" in result.error

    @pytest.mark.asyncio
    async def test_send_uses_cached_webhook(self):
        """A webhook stored in ``_session_webhooks`` is used when metadata is absent."""
        adapter, client = self._adapter_with_response(200)
        adapter._session_webhooks["chat-123"] = "https://cached.example/webhook"

        result = await adapter.send("chat-123", "Hello!")

        assert result.success is True
        assert client.post.call_args.args[0] == "https://cached.example/webhook"

    @pytest.mark.asyncio
    async def test_send_handles_http_error(self):
        """A 400 response makes send() fail with the status code in the error."""
        adapter, _client = self._adapter_with_response(400, "Bad Request")

        result = await adapter.send(
            "chat-123",
            "Hello!",
            metadata={"session_webhook": "https://example/webhook"},
        )

        assert result.success is False
        assert "400" in result.error
# ---------------------------------------------------------------------------
# Connect / disconnect
# ---------------------------------------------------------------------------
class TestConnect:
    """Cover ``DingTalkAdapter.connect`` failure paths and ``disconnect`` cleanup."""

    @pytest.mark.asyncio
    async def test_connect_fails_without_sdk(self, monkeypatch):
        """connect() returns False when the SDK availability flag is False."""
        monkeypatch.setattr(
            "gateway.platforms.dingtalk.DINGTALK_STREAM_AVAILABLE", False
        )
        from gateway.platforms.dingtalk import DingTalkAdapter

        adapter = DingTalkAdapter(PlatformConfig(enabled=True))

        assert await adapter.connect() is False

    @pytest.mark.asyncio
    async def test_connect_fails_without_credentials(self):
        """connect() returns False when both credentials are blanked out."""
        from gateway.platforms.dingtalk import DingTalkAdapter

        adapter = DingTalkAdapter(PlatformConfig(enabled=True))
        adapter._client_id = adapter._client_secret = ""

        assert await adapter.connect() is False

    @pytest.mark.asyncio
    async def test_disconnect_cleans_up(self):
        """disconnect() empties both caches and drops the HTTP client."""
        from gateway.platforms.dingtalk import DingTalkAdapter

        adapter = DingTalkAdapter(PlatformConfig(enabled=True))
        # Seed state that disconnect() is expected to clear.
        adapter._session_webhooks["a"] = "http://x"
        adapter._seen_messages["b"] = 1.0
        adapter._http_client = AsyncMock()
        adapter._stream_task = None

        await adapter.disconnect()

        assert len(adapter._session_webhooks) == 0
        assert len(adapter._seen_messages) == 0
        assert adapter._http_client is None
# ---------------------------------------------------------------------------
# Platform enum
# ---------------------------------------------------------------------------
class TestPlatformEnum:
    """Check the DingTalk member of the ``Platform`` enum."""

    def test_dingtalk_in_platform_enum(self):
        """``Platform.DINGTALK`` carries the string value "dingtalk"."""
        dingtalk_value = Platform.DINGTALK.value
        assert dingtalk_value == "dingtalk"

View file

@ -0,0 +1,448 @@
"""Tests for Matrix platform adapter."""
import json
import re
import pytest
from unittest.mock import MagicMock, patch, AsyncMock
from gateway.config import Platform, PlatformConfig
# ---------------------------------------------------------------------------
# Platform & Config
# ---------------------------------------------------------------------------
class TestMatrixPlatformEnum:
    """Check the Matrix member of the ``Platform`` enum."""

    def test_matrix_enum_exists(self):
        """``Platform.MATRIX`` carries the string value "matrix"."""
        assert Platform.MATRIX.value == "matrix"

    def test_matrix_in_platform_list(self):
        """"matrix" appears among the values of all ``Platform`` members."""
        all_values = {member.value for member in Platform}
        assert "matrix" in all_values
class TestMatrixConfigLoading:
    """Check how MATRIX_* environment variables populate ``GatewayConfig``."""

    @staticmethod
    def _config_from_env():
        """Return a fresh ``GatewayConfig`` after ``_apply_env_overrides`` ran on it."""
        from gateway.config import GatewayConfig, _apply_env_overrides

        config = GatewayConfig()
        _apply_env_overrides(config)
        return config

    @staticmethod
    def _set_token_and_homeserver(monkeypatch):
        """Set the access-token and homeserver env vars shared by several tests."""
        monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_abc123")
        monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org")

    def test_apply_env_overrides_with_access_token(self, monkeypatch):
        """Token + homeserver env vars enable the Matrix platform entry."""
        self._set_token_and_homeserver(monkeypatch)

        config = self._config_from_env()

        assert Platform.MATRIX in config.platforms
        matrix_cfg = config.platforms[Platform.MATRIX]
        assert matrix_cfg.enabled is True
        assert matrix_cfg.token == "syt_abc123"
        assert matrix_cfg.extra.get("homeserver") == "https://matrix.example.org"

    def test_apply_env_overrides_with_password(self, monkeypatch):
        """Password login env vars (no token) also enable the Matrix entry."""
        monkeypatch.delenv("MATRIX_ACCESS_TOKEN", raising=False)
        monkeypatch.setenv("MATRIX_PASSWORD", "secret123")
        monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org")
        monkeypatch.setenv("MATRIX_USER_ID", "@bot:example.org")

        config = self._config_from_env()

        assert Platform.MATRIX in config.platforms
        matrix_cfg = config.platforms[Platform.MATRIX]
        assert matrix_cfg.enabled is True
        assert matrix_cfg.extra.get("password") == "secret123"
        assert matrix_cfg.extra.get("user_id") == "@bot:example.org"

    def test_matrix_not_loaded_without_creds(self, monkeypatch):
        """With token, password and homeserver unset, no Matrix entry is created."""
        for env_name in ("MATRIX_ACCESS_TOKEN", "MATRIX_PASSWORD", "MATRIX_HOMESERVER"):
            monkeypatch.delenv(env_name, raising=False)

        config = self._config_from_env()

        assert Platform.MATRIX not in config.platforms

    def test_matrix_encryption_flag(self, monkeypatch):
        """MATRIX_ENCRYPTION="true" is stored as ``extra["encryption"] is True``."""
        self._set_token_and_homeserver(monkeypatch)
        monkeypatch.setenv("MATRIX_ENCRYPTION", "true")

        matrix_cfg = self._config_from_env().platforms[Platform.MATRIX]

        assert matrix_cfg.extra.get("encryption") is True

    def test_matrix_encryption_default_off(self, monkeypatch):
        """With MATRIX_ENCRYPTION unset, ``extra["encryption"]`` is False."""
        self._set_token_and_homeserver(monkeypatch)
        monkeypatch.delenv("MATRIX_ENCRYPTION", raising=False)

        matrix_cfg = self._config_from_env().platforms[Platform.MATRIX]

        assert matrix_cfg.extra.get("encryption") is False

    def test_matrix_home_room(self, monkeypatch):
        """MATRIX_HOME_ROOM / _NAME become the Matrix home channel."""
        self._set_token_and_homeserver(monkeypatch)
        monkeypatch.setenv("MATRIX_HOME_ROOM", "!room123:example.org")
        monkeypatch.setenv("MATRIX_HOME_ROOM_NAME", "Bot Room")

        home = self._config_from_env().get_home_channel(Platform.MATRIX)

        assert home is not None
        assert home.chat_id == "!room123:example.org"
        assert home.name == "Bot Room"

    def test_matrix_user_id_stored_in_extra(self, monkeypatch):
        """MATRIX_USER_ID is copied into ``extra["user_id"]``."""
        self._set_token_and_homeserver(monkeypatch)
        monkeypatch.setenv("MATRIX_USER_ID", "@hermes:example.org")

        matrix_cfg = self._config_from_env().platforms[Platform.MATRIX]

        assert matrix_cfg.extra.get("user_id") == "@hermes:example.org"
# ---------------------------------------------------------------------------
# Adapter helpers
# ---------------------------------------------------------------------------
def _make_adapter():
    """Build a ``MatrixAdapter`` from a fixed test token, homeserver and user id."""
    from gateway.platforms.matrix import MatrixAdapter

    connection_details = {
        "homeserver": "https://matrix.example.org",
        "user_id": "@bot:example.org",
    }
    return MatrixAdapter(
        PlatformConfig(
            enabled=True,
            token="syt_test_token",
            extra=connection_details,
        )
    )
# ---------------------------------------------------------------------------
# mxc:// URL conversion
# ---------------------------------------------------------------------------
class TestMatrixMxcToHttp:
    """Cover ``MatrixAdapter._mxc_to_http`` URL conversion."""

    def setup_method(self):
        """Give every test its own adapter."""
        self.adapter = _make_adapter()

    def test_basic_mxc_conversion(self):
        """mxc://server/media_id maps to the homeserver's download URL."""
        converted = self.adapter._mxc_to_http("mxc://matrix.org/abc123")
        assert converted == "https://matrix.example.org/_matrix/client/v1/media/download/matrix.org/abc123"

    def test_mxc_with_different_server(self):
        """Media hosted on another server is still fetched via our homeserver."""
        converted = self.adapter._mxc_to_http("mxc://other.server/media456")
        assert converted.startswith("https://matrix.example.org/")
        assert "other.server/media456" in converted

    def test_non_mxc_url_passthrough(self):
        """A URL that is not mxc:// comes back unchanged."""
        plain_url = "https://example.com/image.png"
        assert self.adapter._mxc_to_http(plain_url) == plain_url

    def test_mxc_uses_client_v1_endpoint(self):
        """The client v1 download path is used, not the media v3 one."""
        converted = self.adapter._mxc_to_http("mxc://example.com/test123")
        assert "/_matrix/client/v1/media/download/" in converted
        assert "/_matrix/media/v3/download/" not in converted
# ---------------------------------------------------------------------------
# DM detection
# ---------------------------------------------------------------------------
class TestMatrixDmDetection:
    """Cover the adapter's ``_dm_rooms`` cache and ``_refresh_dm_cache``."""

    def setup_method(self):
        """Give every test its own adapter."""
        self.adapter = _make_adapter()

    def test_room_in_m_direct_is_dm(self):
        """A room flagged True in the DM cache reads back as a DM."""
        # NOTE(review): this test only reads back values it assigned itself;
        # it does not call any adapter logic.
        dm_room, group_room = "!dm_room:ex.org", "!group_room:ex.org"
        self.adapter._joined_rooms = {dm_room, group_room}
        self.adapter._dm_rooms = {dm_room: True, group_room: False}

        assert self.adapter._dm_rooms.get("!dm_room:ex.org") is True
        assert self.adapter._dm_rooms.get("!group_room:ex.org") is False

    def test_unknown_room_not_in_cache(self):
        """Looking up a room absent from the cache yields None."""
        self.adapter._dm_rooms = {}
        assert self.adapter._dm_rooms.get("!unknown:ex.org") is None

    @pytest.mark.asyncio
    async def test_refresh_dm_cache_with_m_direct(self):
        """_refresh_dm_cache marks rooms found in the account data as DMs."""
        self.adapter._joined_rooms = {"!room_a:ex.org", "!room_b:ex.org", "!room_c:ex.org"}

        account_data = MagicMock()
        account_data.content = {
            "@alice:ex.org": ["!room_a:ex.org"],
            "@bob:ex.org": ["!room_b:ex.org"],
        }
        fake_client = MagicMock()
        fake_client.get_account_data = AsyncMock(return_value=account_data)
        self.adapter._client = fake_client

        await self.adapter._refresh_dm_cache()

        dm_cache = self.adapter._dm_rooms
        assert dm_cache["!room_a:ex.org"] is True
        assert dm_cache["!room_b:ex.org"] is True
        # room_c is joined but not listed for any user above
        assert dm_cache["!room_c:ex.org"] is False
# ---------------------------------------------------------------------------
# Reply fallback stripping
# ---------------------------------------------------------------------------
class TestMatrixReplyFallbackStripping:
    """Check that leading '> ' reply-fallback lines are removed from reply bodies."""

    def setup_method(self):
        """Create an adapter and preset the attributes listed below."""
        adapter = _make_adapter()
        adapter._user_id = "@bot:example.org"
        adapter._startup_ts = 0.0
        adapter._dm_rooms = {}
        adapter._message_handler = AsyncMock()
        self.adapter = adapter

    def _strip_fallback(self, body: str, has_reply: bool = True) -> str:
        """Local stand-in for the stripping logic in ``_on_room_message``.

        Drops the leading run of quoted lines ('> ...' or a bare '>') and a
        single blank separator that follows, but only when ``has_reply`` is
        truthy and ``body`` starts with '> '. If nothing would remain, the
        original body is returned.
        """
        if not has_reply or not body.startswith("> "):
            return body

        lines = body.split("\n")
        cursor = 0
        # Skip the quoted fallback prefix.
        while cursor < len(lines) and (
            lines[cursor].startswith("> ") or lines[cursor] == ">"
        ):
            cursor += 1
        # Skip the one blank line separating fallback from the real reply.
        if cursor < len(lines) and lines[cursor] == "":
            cursor += 1

        remainder = lines[cursor:]
        return "\n".join(remainder) if remainder else body

    def test_simple_reply_fallback(self):
        """One quoted line plus separator leaves just the reply."""
        quoted = "> <@alice:ex.org> Original message\n\nActual reply"
        assert self._strip_fallback(quoted) == "Actual reply"

    def test_multiline_reply_fallback(self):
        """Several quoted lines are all removed."""
        quoted = "> <@alice:ex.org> Line 1\n> Line 2\n\nMy response"
        assert self._strip_fallback(quoted) == "My response"

    def test_no_reply_fallback_preserved(self):
        """A plain message that is not a reply is untouched."""
        plain = "Just a normal message"
        assert self._strip_fallback(plain, has_reply=False) == "Just a normal message"

    def test_quote_without_reply_preserved(self):
        """'> ' lines without a reply_to context should be preserved."""
        blockquote = "> This is a blockquote"
        assert self._strip_fallback(blockquote, has_reply=False) == "> This is a blockquote"

    def test_empty_fallback_separator(self):
        """A bare '>' line and the blank separator are both stripped."""
        quoted = "> <@alice:ex.org> hi\n>\n\nResponse"
        assert self._strip_fallback(quoted) == "Response"

    def test_multiline_response_after_fallback(self):
        """Everything after the fallback is kept, newlines included."""
        quoted = "> <@alice:ex.org> Original\n\nLine 1\nLine 2\nLine 3"
        assert self._strip_fallback(quoted) == "Line 1\nLine 2\nLine 3"
# ---------------------------------------------------------------------------
# Thread detection
# ---------------------------------------------------------------------------
class TestMatrixThreadDetection:
    """Check how a thread id is derived from an ``m.relates_to`` payload.

    These tests do not call the adapter; they run a local copy of the
    extraction rule, kept in one helper so all cases exercise the same code.
    """

    @staticmethod
    def _thread_id_from(relates_to):
        """Simulate the extraction logic from _on_room_message.

        Args:
            relates_to: The ``m.relates_to`` dict of an event.

        Returns:
            The dict's ``event_id`` when ``rel_type`` is ``"m.thread"``,
            otherwise ``None``.
        """
        if relates_to.get("rel_type") == "m.thread":
            return relates_to.get("event_id")
        return None

    def test_thread_id_from_m_relates_to(self):
        """m.relates_to with rel_type=m.thread should extract the event_id."""
        relates_to = {
            "rel_type": "m.thread",
            "event_id": "$thread_root_event",
            "is_falling_back": True,
            "m.in_reply_to": {"event_id": "$some_event"},
        }
        assert self._thread_id_from(relates_to) == "$thread_root_event"

    def test_no_thread_for_reply(self):
        """m.in_reply_to without m.thread should not set thread_id."""
        relates_to = {
            "m.in_reply_to": {"event_id": "$reply_event"},
        }
        assert self._thread_id_from(relates_to) is None

    def test_no_thread_for_edit(self):
        """m.replace relation should not set thread_id."""
        relates_to = {
            "rel_type": "m.replace",
            "event_id": "$edited_event",
        }
        assert self._thread_id_from(relates_to) is None

    def test_empty_relates_to(self):
        """Empty m.relates_to should not set thread_id."""
        assert self._thread_id_from({}) is None
# ---------------------------------------------------------------------------
# Format message
# ---------------------------------------------------------------------------
class TestMatrixFormatMessage:
    """Cover ``MatrixAdapter.format_message``."""

    def setup_method(self):
        """Give every test its own adapter."""
        self.adapter = _make_adapter()

    def test_image_markdown_stripped(self):
        """![alt](url) should be converted to just the URL."""
        formatted = self.adapter.format_message("![cat](https://img.example.com/cat.png)")
        assert formatted == "https://img.example.com/cat.png"

    def test_regular_markdown_preserved(self):
        """Bold, italic and inline code pass through unchanged."""
        markdown = "**bold** and *italic* and `code`"
        assert self.adapter.format_message(markdown) == markdown

    def test_plain_text_unchanged(self):
        """Plain text passes through unchanged."""
        text = "Hello, world!"
        assert self.adapter.format_message(text) == text

    def test_multiple_images_stripped(self):
        """Every image in the message is reduced to its URL."""
        formatted = self.adapter.format_message(
            "![a](http://a.com/1.png) and ![b](http://b.com/2.png)"
        )
        assert "![" not in formatted
        for image_url in ("http://a.com/1.png", "http://b.com/2.png"):
            assert image_url in formatted
# ---------------------------------------------------------------------------
# Markdown to HTML conversion
# ---------------------------------------------------------------------------
class TestMatrixMarkdownToHtml:
    """Cover ``MatrixAdapter._markdown_to_html``."""

    def setup_method(self):
        """Give every test its own adapter."""
        self.adapter = _make_adapter()

    def test_bold_conversion(self):
        """**bold** yields a <strong> or <b> tag and keeps the word."""
        html = self.adapter._markdown_to_html("**bold**")
        assert "<strong>" in html or "<b>" in html
        assert "bold" in html

    def test_italic_conversion(self):
        """*italic* yields an <em> or <i> tag."""
        html = self.adapter._markdown_to_html("*italic*")
        assert "<em>" in html or "<i>" in html

    def test_inline_code(self):
        """`code` yields a <code> tag."""
        html = self.adapter._markdown_to_html("`code`")
        assert "<code>" in html

    def test_plain_text_returns_html(self):
        """Plain text is still present in the output, whatever wraps it."""
        html = self.adapter._markdown_to_html("Hello world")
        assert "Hello world" in html
# ---------------------------------------------------------------------------
# Helper: display name extraction
# ---------------------------------------------------------------------------
class TestMatrixDisplayName:
    """Cover ``MatrixAdapter._get_display_name``."""

    def setup_method(self):
        """Give every test its own adapter."""
        self.adapter = _make_adapter()

    def test_get_display_name_from_room_users(self):
        """The display name comes from the room's ``users`` mapping when present."""
        alice = MagicMock()
        alice.display_name = "Alice"
        room = MagicMock()
        room.users = {"@alice:ex.org": alice}

        assert self.adapter._get_display_name(room, "@alice:ex.org") == "Alice"

    def test_get_display_name_fallback_to_localpart(self):
        """A user missing from the room falls back to the localpart of the id."""
        room = MagicMock()
        room.users = {}

        assert self.adapter._get_display_name(room, "@bob:example.org") == "bob"

    def test_get_display_name_no_room(self):
        """A ``None`` room also falls back to the localpart."""
        assert self.adapter._get_display_name(None, "@charlie:ex.org") == "charlie"
# ---------------------------------------------------------------------------
# Requirements check
# ---------------------------------------------------------------------------
class TestMatrixRequirements:
    """Cover ``check_matrix_requirements`` for different env var combinations."""

    def test_check_requirements_with_token(self, monkeypatch):
        """With token and homeserver set, the result mirrors whether ``nio`` imports."""
        monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_test")
        monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org")
        from gateway.platforms.matrix import check_matrix_requirements

        try:
            import nio  # noqa: F401
        except ImportError:
            nio_installed = False
        else:
            nio_installed = True

        assert check_matrix_requirements() is nio_installed

    def test_check_requirements_without_creds(self, monkeypatch):
        """With token, password and homeserver unset, the check reports False."""
        for env_name in ("MATRIX_ACCESS_TOKEN", "MATRIX_PASSWORD", "MATRIX_HOMESERVER"):
            monkeypatch.delenv(env_name, raising=False)
        from gateway.platforms.matrix import check_matrix_requirements

        assert check_matrix_requirements() is False

    def test_check_requirements_without_homeserver(self, monkeypatch):
        """A token alone, with no homeserver, is not enough."""
        monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_test")
        monkeypatch.delenv("MATRIX_HOMESERVER", raising=False)
        from gateway.platforms.matrix import check_matrix_requirements

        assert check_matrix_requirements() is False

View file

@ -0,0 +1,574 @@
"""Tests for Mattermost platform adapter."""
import json
import time
import pytest
from unittest.mock import MagicMock, patch, AsyncMock
from gateway.config import Platform, PlatformConfig
# ---------------------------------------------------------------------------
# Platform & Config
# ---------------------------------------------------------------------------
class TestMattermostPlatformEnum:
    """Check the Mattermost member of the ``Platform`` enum."""

    def test_mattermost_enum_exists(self):
        """``Platform.MATTERMOST`` carries the string value "mattermost"."""
        assert Platform.MATTERMOST.value == "mattermost"

    def test_mattermost_in_platform_list(self):
        """"mattermost" appears among the values of all ``Platform`` members."""
        all_values = {member.value for member in Platform}
        assert "mattermost" in all_values
class TestMattermostConfigLoading:
    """Check how MATTERMOST_* environment variables populate ``GatewayConfig``."""

    @staticmethod
    def _config_from_env():
        """Return a fresh ``GatewayConfig`` after ``_apply_env_overrides`` ran on it."""
        from gateway.config import GatewayConfig, _apply_env_overrides

        config = GatewayConfig()
        _apply_env_overrides(config)
        return config

    @staticmethod
    def _set_token_and_url(monkeypatch):
        """Set the token and URL env vars shared by several tests."""
        monkeypatch.setenv("MATTERMOST_TOKEN", "mm-tok-abc123")
        monkeypatch.setenv("MATTERMOST_URL", "https://mm.example.com")

    def test_apply_env_overrides_mattermost(self, monkeypatch):
        """Token + URL env vars enable the Mattermost platform entry."""
        self._set_token_and_url(monkeypatch)

        config = self._config_from_env()

        assert Platform.MATTERMOST in config.platforms
        mattermost_cfg = config.platforms[Platform.MATTERMOST]
        assert mattermost_cfg.enabled is True
        assert mattermost_cfg.token == "mm-tok-abc123"
        assert mattermost_cfg.extra.get("url") == "https://mm.example.com"

    def test_mattermost_not_loaded_without_token(self, monkeypatch):
        """With token and URL unset, no Mattermost entry is created."""
        for env_name in ("MATTERMOST_TOKEN", "MATTERMOST_URL"):
            monkeypatch.delenv(env_name, raising=False)

        config = self._config_from_env()

        assert Platform.MATTERMOST not in config.platforms

    def test_connected_platforms_includes_mattermost(self, monkeypatch):
        """A configured Mattermost shows up in ``get_connected_platforms()``."""
        self._set_token_and_url(monkeypatch)

        connected = self._config_from_env().get_connected_platforms()

        assert Platform.MATTERMOST in connected

    def test_mattermost_home_channel(self, monkeypatch):
        """MATTERMOST_HOME_CHANNEL / _NAME become the Mattermost home channel."""
        self._set_token_and_url(monkeypatch)
        monkeypatch.setenv("MATTERMOST_HOME_CHANNEL", "ch_abc123")
        monkeypatch.setenv("MATTERMOST_HOME_CHANNEL_NAME", "General")

        home = self._config_from_env().get_home_channel(Platform.MATTERMOST)

        assert home is not None
        assert home.chat_id == "ch_abc123"
        assert home.name == "General"

    def test_mattermost_url_warning_without_url(self, monkeypatch):
        """MATTERMOST_TOKEN set but MATTERMOST_URL missing should still load."""
        monkeypatch.setenv("MATTERMOST_TOKEN", "mm-tok-abc123")
        monkeypatch.delenv("MATTERMOST_URL", raising=False)

        config = self._config_from_env()

        assert Platform.MATTERMOST in config.platforms
        # The missing URL is stored as an empty string rather than omitted.
        assert config.platforms[Platform.MATTERMOST].extra.get("url") == ""
# ---------------------------------------------------------------------------
# Adapter format / truncate
# ---------------------------------------------------------------------------
def _make_adapter():
    """Build a MattermostAdapter from a minimal, enabled PlatformConfig.

    The config uses a placeholder token and stores the server URL under
    ``extra["url"]``.
    """
    from gateway.platforms.mattermost import MattermostAdapter

    return MattermostAdapter(
        PlatformConfig(
            enabled=True,
            token="test-token",
            extra={"url": "https://mm.example.com"},
        )
    )
class TestMattermostFormatMessage:
    """format_message(): image markdown becomes a bare URL, the rest is kept."""

    def setup_method(self):
        self.adapter = _make_adapter()

    def _assert_unchanged(self, text):
        """Assert that format_message returns ``text`` exactly as given."""
        assert self.adapter.format_message(text) == text

    def test_image_markdown_to_url(self):
        """![alt](url) should be converted to just the URL."""
        formatted = self.adapter.format_message(
            "![cat](https://img.example.com/cat.png)"
        )
        assert formatted == "https://img.example.com/cat.png"

    def test_image_markdown_strips_alt_text(self):
        """An image embedded in a sentence loses its markdown wrapper."""
        formatted = self.adapter.format_message(
            "Here: ![my image](https://x.com/a.jpg) done"
        )
        assert "![" not in formatted
        assert "https://x.com/a.jpg" in formatted

    def test_regular_markdown_preserved(self):
        """Regular markdown (bold, italic, code) should be kept as-is."""
        self._assert_unchanged("**bold** and *italic* and `code`")

    def test_regular_links_preserved(self):
        """Non-image links should be preserved."""
        self._assert_unchanged("[click](https://example.com)")

    def test_plain_text_unchanged(self):
        """Text without any markdown passes through untouched."""
        self._assert_unchanged("Hello, world!")

    def test_multiple_images(self):
        """Every image in the message is rewritten, not just the first."""
        formatted = self.adapter.format_message(
            "![a](http://a.com/1.png) text ![b](http://b.com/2.png)"
        )
        assert "![" not in formatted
        for url in ("http://a.com/1.png", "http://b.com/2.png"):
            assert url in formatted
class TestMattermostTruncateMessage:
    """truncate_message(): chunk counts and per-chunk length limits."""

    def setup_method(self):
        self.adapter = _make_adapter()

    def test_short_message_single_chunk(self):
        """A message below the limit comes back as one unchanged chunk."""
        text = "Hello, world!"
        parts = self.adapter.truncate_message(text, 4000)
        assert len(parts) == 1
        assert parts[0] == text

    def test_long_message_splits(self):
        """A 5000-character message is split into chunks of at most 4000."""
        text = "a " * 2500  # 5000 chars
        parts = self.adapter.truncate_message(text, 4000)
        assert len(parts) >= 2
        for part in parts:
            assert len(part) <= 4000

    def test_custom_max_length(self):
        """The ``max_length`` keyword is honoured for every chunk."""
        parts = self.adapter.truncate_message("Hello " * 20, max_length=50)
        assert all(len(part) <= 50 for part in parts)

    def test_exactly_at_limit(self):
        """A message whose length equals the limit is not split."""
        parts = self.adapter.truncate_message("x" * 4000, 4000)
        assert len(parts) == 1
# ---------------------------------------------------------------------------
# Send
# ---------------------------------------------------------------------------
class TestMattermostSend:
    """Exercise MattermostAdapter.send() against a mocked HTTP session."""

    def setup_method(self):
        self.adapter = _make_adapter()
        # Replace the HTTP session so no network traffic is attempted.
        self.adapter._session = MagicMock()

    def _stub_post(self, status, body, text=""):
        """Make ``_session.post`` return a canned response mock.

        The response exposes ``status``, awaitable ``json()`` / ``text()``
        and the async context-manager protocol (``__aenter__`` returns the
        response itself). The installed ``post`` mock is a plain MagicMock
        so its call arguments can be inspected afterwards.
        """
        resp = AsyncMock()
        resp.status = status
        resp.json = AsyncMock(return_value=body)
        resp.text = AsyncMock(return_value=text)
        resp.__aenter__ = AsyncMock(return_value=resp)
        resp.__aexit__ = AsyncMock(return_value=False)
        self.adapter._session.post = MagicMock(return_value=resp)
        return resp

    @pytest.mark.asyncio
    async def test_send_calls_api_post(self):
        """send() should POST to /api/v4/posts with channel_id and message."""
        self._stub_post(200, {"id": "post123"})

        result = await self.adapter.send("channel_1", "Hello!")

        assert result.success is True
        assert result.message_id == "post123"
        # Verify post was called with correct URL
        call_args = self.adapter._session.post.call_args
        assert "/api/v4/posts" in call_args[0][0]
        # Verify payload
        payload = call_args[1]["json"]
        assert payload["channel_id"] == "channel_1"
        assert payload["message"] == "Hello!"

    @pytest.mark.asyncio
    async def test_send_empty_content_succeeds(self):
        """Empty content should return success without calling the API."""
        result = await self.adapter.send("channel_1", "")

        assert result.success is True
        # Pin the documented "no API call" half of the contract as well;
        # previously only the success flag was checked.
        self.adapter._session.post.assert_not_called()

    @pytest.mark.asyncio
    async def test_send_with_thread_reply(self):
        """When reply_mode is 'thread', reply_to should become root_id."""
        self.adapter._reply_mode = "thread"
        self._stub_post(200, {"id": "post456"})

        result = await self.adapter.send("channel_1", "Reply!", reply_to="root_post")

        assert result.success is True
        payload = self.adapter._session.post.call_args[1]["json"]
        assert payload["root_id"] == "root_post"

    @pytest.mark.asyncio
    async def test_send_without_thread_no_root_id(self):
        """When reply_mode is 'off', reply_to should NOT set root_id."""
        self.adapter._reply_mode = "off"
        self._stub_post(200, {"id": "post789"})

        result = await self.adapter.send("channel_1", "Reply!", reply_to="root_post")

        assert result.success is True
        payload = self.adapter._session.post.call_args[1]["json"]
        assert "root_id" not in payload

    @pytest.mark.asyncio
    async def test_send_api_failure(self):
        """When API returns error, send should return failure."""
        self._stub_post(500, {}, text="Internal Server Error")

        result = await self.adapter.send("channel_1", "Hello!")

        assert result.success is False
# ---------------------------------------------------------------------------
# WebSocket event parsing
# ---------------------------------------------------------------------------
class TestMattermostWebSocketParsing:
    """_handle_ws_event(): which events reach handle_message and with what."""

    def setup_method(self):
        self.adapter = _make_adapter()
        self.adapter._bot_user_id = "bot_user_id"
        # Mock handle_message to capture the MessageEvent without processing
        self.adapter.handle_message = AsyncMock()

    @staticmethod
    def _posted(post, channel_type, sender_name=None):
        """Build a 'posted' websocket event around ``post``.

        ``data.post`` is a JSON *string* (double-encoded), and
        ``sender_name`` is only included when one is supplied.
        """
        data = {"post": json.dumps(post), "channel_type": channel_type}
        if sender_name is not None:
            data["sender_name"] = sender_name
        return {"event": "posted", "data": data}

    def _captured_event(self):
        """Return the first positional argument handle_message was called with."""
        assert self.adapter.handle_message.called
        return self.adapter.handle_message.call_args[0][0]

    @pytest.mark.asyncio
    async def test_parse_posted_event(self):
        """'posted' events should extract message from double-encoded post JSON."""
        post = {
            "id": "post_abc",
            "user_id": "user_123",
            "channel_id": "chan_456",
            "message": "Hello from Matrix!",
        }

        await self.adapter._handle_ws_event(self._posted(post, "O", "@alice"))

        received = self._captured_event()
        assert received.text == "Hello from Matrix!"
        assert received.message_id == "post_abc"

    @pytest.mark.asyncio
    async def test_ignore_own_messages(self):
        """Messages from the bot's own user_id should be ignored."""
        post = {
            "id": "post_self",
            "user_id": "bot_user_id",  # same as bot
            "channel_id": "chan_456",
            "message": "Bot echo",
        }

        await self.adapter._handle_ws_event(self._posted(post, "O"))

        assert not self.adapter.handle_message.called

    @pytest.mark.asyncio
    async def test_ignore_non_posted_events(self):
        """Non-'posted' events should be ignored."""
        typing_event = {
            "event": "typing",
            "data": {"user_id": "user_123"},
        }

        await self.adapter._handle_ws_event(typing_event)

        assert not self.adapter.handle_message.called

    @pytest.mark.asyncio
    async def test_ignore_system_posts(self):
        """Posts with a 'type' field (system messages) should be ignored."""
        post = {
            "id": "sys_post",
            "user_id": "user_123",
            "channel_id": "chan_456",
            "message": "user joined",
            "type": "system_join_channel",
        }

        await self.adapter._handle_ws_event(self._posted(post, "O"))

        assert not self.adapter.handle_message.called

    @pytest.mark.asyncio
    async def test_channel_type_mapping(self):
        """channel_type 'D' should map to 'dm'."""
        post = {
            "id": "post_dm",
            "user_id": "user_123",
            "channel_id": "chan_dm",
            "message": "DM message",
        }

        await self.adapter._handle_ws_event(self._posted(post, "D", "@bob"))

        assert self._captured_event().source.chat_type == "dm"

    @pytest.mark.asyncio
    async def test_thread_id_from_root_id(self):
        """Post with root_id should have thread_id set."""
        post = {
            "id": "post_reply",
            "user_id": "user_123",
            "channel_id": "chan_456",
            "message": "Thread reply",
            "root_id": "root_post_123",
        }

        await self.adapter._handle_ws_event(self._posted(post, "O", "@alice"))

        assert self._captured_event().source.thread_id == "root_post_123"

    @pytest.mark.asyncio
    async def test_invalid_post_json_ignored(self):
        """Invalid JSON in data.post should be silently ignored."""
        # Built by hand: the helper would JSON-encode the payload.
        broken_event = {
            "event": "posted",
            "data": {
                "post": "not-valid-json{{{",
                "channel_type": "O",
            },
        }

        await self.adapter._handle_ws_event(broken_event)

        assert not self.adapter.handle_message.called
# ---------------------------------------------------------------------------
# File upload (send_image)
# ---------------------------------------------------------------------------
class TestMattermostFileUpload:
    """send_image() against a session whose GET and POST calls are mocked."""

    def setup_method(self):
        self.adapter = _make_adapter()
        self.adapter._session = MagicMock()

    @staticmethod
    def _as_async_context(resp):
        """Give ``resp`` async context-manager hooks that yield itself."""
        resp.__aenter__ = AsyncMock(return_value=resp)
        resp.__aexit__ = AsyncMock(return_value=False)
        return resp

    @pytest.mark.asyncio
    async def test_send_image_downloads_and_uploads(self):
        """send_image should download the URL, upload via /api/v4/files, then post."""
        # Response for the image download (GET)
        download = AsyncMock()
        download.status = 200
        download.read = AsyncMock(return_value=b"\x89PNG\x00fake-image-data")
        download.content_type = "image/png"
        self._as_async_context(download)

        # Response for the file upload (POST to /files)
        upload = AsyncMock()
        upload.status = 200
        upload.json = AsyncMock(return_value={
            "file_infos": [{"id": "file_abc123"}]
        })
        upload.text = AsyncMock(return_value="")
        self._as_async_context(upload)

        # Response for the post creation (POST to /posts)
        created = AsyncMock()
        created.status = 200
        created.json = AsyncMock(return_value={"id": "post_with_file"})
        created.text = AsyncMock(return_value="")
        self._as_async_context(created)

        self.adapter._session.get = MagicMock(return_value=download)

        # POST calls are answered in order: upload first, then post creation.
        # Any further call keeps receiving the last response.
        queued = (upload, created)
        post_calls = []

        def next_post_response(*args, **kwargs):
            position = min(len(post_calls), len(queued) - 1)
            post_calls.append(args)
            return queued[position]

        self.adapter._session.post = MagicMock(side_effect=next_post_response)

        result = await self.adapter.send_image(
            "channel_1", "https://img.example.com/cat.png", caption="A cat"
        )

        assert result.success is True
        assert result.message_id == "post_with_file"
# ---------------------------------------------------------------------------
# Dedup cache
# ---------------------------------------------------------------------------
class TestMattermostDedup:
    """Duplicate-post suppression backed by the adapter's _seen_posts cache."""

    def setup_method(self):
        self.adapter = _make_adapter()
        self.adapter._bot_user_id = "bot_user_id"
        # Mock handle_message to capture calls without processing
        self.adapter.handle_message = AsyncMock()

    @staticmethod
    def _posted(post_id, message):
        """Build a 'posted' event from @alice in an open channel."""
        post = {
            "id": post_id,
            "user_id": "user_123",
            "channel_id": "chan_456",
            "message": message,
        }
        return {
            "event": "posted",
            "data": {
                "post": json.dumps(post),
                "channel_type": "O",
                "sender_name": "@alice",
            },
        }

    @pytest.mark.asyncio
    async def test_duplicate_post_ignored(self):
        """The same post_id within the TTL window should be ignored."""
        event = self._posted("post_dup", "Hello!")

        # First delivery is processed.
        await self.adapter._handle_ws_event(event)
        assert self.adapter.handle_message.call_count == 1

        # Redelivery of the same post_id must not be processed again.
        await self.adapter._handle_ws_event(event)
        assert self.adapter.handle_message.call_count == 1  # still 1

    @pytest.mark.asyncio
    async def test_different_post_ids_both_processed(self):
        """Different post IDs should both be processed."""
        for index, post_id in enumerate(["post_a", "post_b"]):
            await self.adapter._handle_ws_event(
                self._posted(post_id, f"Message {index}")
            )

        assert self.adapter.handle_message.call_count == 2

    def test_prune_seen_clears_expired(self):
        """_prune_seen should remove entries older than _SEEN_TTL."""
        now = time.time()
        stale = now - 600  # 10 min ago
        seen = self.adapter._seen_posts

        # Overfill the cache with expired entries, then add one fresh entry.
        for index in range(self.adapter._SEEN_MAX + 10):
            seen[f"old_{index}"] = stale
        seen["fresh"] = now

        self.adapter._prune_seen()

        # Re-read the attribute in case pruning rebinds it.
        assert "fresh" in self.adapter._seen_posts
        assert len(self.adapter._seen_posts) < self.adapter._SEEN_MAX

    def test_seen_cache_tracks_post_ids(self):
        """Posts are tracked in _seen_posts dict."""
        self.adapter._seen_posts["test_post"] = time.time()
        assert "test_post" in self.adapter._seen_posts
# ---------------------------------------------------------------------------
# Requirements check
# ---------------------------------------------------------------------------
class TestMattermostRequirements:
    """check_mattermost_requirements() under different env combinations."""

    @staticmethod
    def _requirements_met():
        """Import the checker lazily and return its result."""
        from gateway.platforms.mattermost import check_mattermost_requirements

        return check_mattermost_requirements()

    def test_check_requirements_with_token_and_url(self, monkeypatch):
        """Both token and URL present: requirements are satisfied."""
        monkeypatch.setenv("MATTERMOST_TOKEN", "test-token")
        monkeypatch.setenv("MATTERMOST_URL", "https://mm.example.com")
        assert self._requirements_met() is True

    def test_check_requirements_without_token(self, monkeypatch):
        """Neither variable present: requirements are not satisfied."""
        monkeypatch.delenv("MATTERMOST_TOKEN", raising=False)
        monkeypatch.delenv("MATTERMOST_URL", raising=False)
        assert self._requirements_met() is False

    def test_check_requirements_without_url(self, monkeypatch):
        """A token alone is not enough; the URL is required as well."""
        monkeypatch.setenv("MATTERMOST_TOKEN", "test-token")
        monkeypatch.delenv("MATTERMOST_URL", raising=False)
        assert self._requirements_met() is False

View file

@ -703,5 +703,15 @@ class TestLastPromptTokens:
store.update_session("k1", model="openai/gpt-5.4")
store._db.update_token_counts.assert_called_once_with(
"s1", 0, 0, model="openai/gpt-5.4"
"s1",
input_tokens=0,
output_tokens=0,
cache_read_tokens=0,
cache_write_tokens=0,
estimated_cost_usd=None,
cost_status=None,
cost_source=None,
billing_provider=None,
billing_base_url=None,
model="openai/gpt-5.4",
)

View file

@ -1,240 +1,215 @@
"""Tests for SMS (Telnyx) platform adapter."""
import json
"""Tests for SMS (Twilio) platform integration.
Covers config loading, format/truncate, echo prevention,
requirements check, and toolset verification.
"""
import os
from unittest.mock import patch
import pytest
from unittest.mock import MagicMock, patch, AsyncMock
from gateway.config import Platform, PlatformConfig
from gateway.config import Platform, PlatformConfig, HomeChannel
# ---------------------------------------------------------------------------
# Platform & Config
# ---------------------------------------------------------------------------
class TestSmsPlatformEnum:
def test_sms_enum_exists(self):
assert Platform.SMS.value == "sms"
def test_sms_in_platform_list(self):
platforms = [p.value for p in Platform]
assert "sms" in platforms
# ── Config loading ──────────────────────────────────────────────────
class TestSmsConfigLoading:
def test_apply_env_overrides_sms(self, monkeypatch):
monkeypatch.setenv("TELNYX_API_KEY", "KEY_test123")
"""Verify _apply_env_overrides wires SMS correctly."""
from gateway.config import GatewayConfig, _apply_env_overrides
config = GatewayConfig()
_apply_env_overrides(config)
def test_sms_platform_enum_exists(self):
assert Platform.SMS.value == "sms"
assert Platform.SMS in config.platforms
sc = config.platforms[Platform.SMS]
assert sc.enabled is True
assert sc.api_key == "KEY_test123"
def test_env_overrides_create_sms_config(self):
from gateway.config import load_gateway_config
def test_sms_not_loaded_without_key(self, monkeypatch):
monkeypatch.delenv("TELNYX_API_KEY", raising=False)
env = {
"TWILIO_ACCOUNT_SID": "ACtest123",
"TWILIO_AUTH_TOKEN": "token_abc",
"TWILIO_PHONE_NUMBER": "+15551234567",
}
with patch.dict(os.environ, env, clear=False):
config = load_gateway_config()
assert Platform.SMS in config.platforms
pc = config.platforms[Platform.SMS]
assert pc.enabled is True
assert pc.api_key == "token_abc"
from gateway.config import GatewayConfig, _apply_env_overrides
config = GatewayConfig()
_apply_env_overrides(config)
def test_env_overrides_set_home_channel(self):
from gateway.config import load_gateway_config
assert Platform.SMS not in config.platforms
env = {
"TWILIO_ACCOUNT_SID": "ACtest123",
"TWILIO_AUTH_TOKEN": "token_abc",
"TWILIO_PHONE_NUMBER": "+15551234567",
"SMS_HOME_CHANNEL": "+15559876543",
"SMS_HOME_CHANNEL_NAME": "My Phone",
}
with patch.dict(os.environ, env, clear=False):
config = load_gateway_config()
hc = config.platforms[Platform.SMS].home_channel
assert hc is not None
assert hc.chat_id == "+15559876543"
assert hc.name == "My Phone"
assert hc.platform == Platform.SMS
def test_connected_platforms_includes_sms(self, monkeypatch):
monkeypatch.setenv("TELNYX_API_KEY", "KEY_test123")
def test_sms_in_connected_platforms(self):
from gateway.config import load_gateway_config
from gateway.config import GatewayConfig, _apply_env_overrides
config = GatewayConfig()
_apply_env_overrides(config)
connected = config.get_connected_platforms()
assert Platform.SMS in connected
def test_sms_home_channel(self, monkeypatch):
monkeypatch.setenv("TELNYX_API_KEY", "KEY_test123")
monkeypatch.setenv("SMS_HOME_CHANNEL", "+15559876543")
monkeypatch.setenv("SMS_HOME_CHANNEL_NAME", "Owner")
from gateway.config import GatewayConfig, _apply_env_overrides
config = GatewayConfig()
_apply_env_overrides(config)
home = config.get_home_channel(Platform.SMS)
assert home is not None
assert home.chat_id == "+15559876543"
assert home.name == "Owner"
env = {
"TWILIO_ACCOUNT_SID": "ACtest123",
"TWILIO_AUTH_TOKEN": "token_abc",
}
with patch.dict(os.environ, env, clear=False):
config = load_gateway_config()
connected = config.get_connected_platforms()
assert Platform.SMS in connected
# ---------------------------------------------------------------------------
# Adapter format / truncate
# ---------------------------------------------------------------------------
# ── Format / truncate ───────────────────────────────────────────────
class TestSmsFormatMessage:
def setup_method(self):
class TestSmsFormatAndTruncate:
"""Test SmsAdapter.format_message strips markdown."""
def _make_adapter(self):
from gateway.platforms.sms import SmsAdapter
config = PlatformConfig(enabled=True, api_key="test_key")
with patch.dict("os.environ", {"TELNYX_API_KEY": "test_key"}):
self.adapter = SmsAdapter(config)
def test_strip_bold(self):
assert self.adapter.format_message("**bold**") == "bold"
env = {
"TWILIO_ACCOUNT_SID": "ACtest",
"TWILIO_AUTH_TOKEN": "tok",
"TWILIO_PHONE_NUMBER": "+15550001111",
}
with patch.dict(os.environ, env):
pc = PlatformConfig(enabled=True, api_key="tok")
adapter = object.__new__(SmsAdapter)
adapter.config = pc
adapter._platform = Platform.SMS
adapter._account_sid = "ACtest"
adapter._auth_token = "tok"
adapter._from_number = "+15550001111"
return adapter
def test_strip_italic(self):
assert self.adapter.format_message("*italic*") == "italic"
def test_strips_bold(self):
adapter = self._make_adapter()
assert adapter.format_message("**hello**") == "hello"
def test_strip_code_block(self):
result = self.adapter.format_message("```python\ncode\n```")
def test_strips_italic(self):
adapter = self._make_adapter()
assert adapter.format_message("*world*") == "world"
def test_strips_code_blocks(self):
adapter = self._make_adapter()
result = adapter.format_message("```python\nprint('hi')\n```")
assert "```" not in result
assert "code" in result
assert "print('hi')" in result
def test_strip_inline_code(self):
assert self.adapter.format_message("`code`") == "code"
def test_strips_inline_code(self):
adapter = self._make_adapter()
assert adapter.format_message("`code`") == "code"
def test_strip_headers(self):
assert self.adapter.format_message("## Header") == "Header"
def test_strips_headers(self):
adapter = self._make_adapter()
assert adapter.format_message("## Title") == "Title"
def test_strip_links(self):
assert self.adapter.format_message("[click](http://example.com)") == "click"
def test_strips_links(self):
adapter = self._make_adapter()
assert adapter.format_message("[click](https://example.com)") == "click"
def test_collapse_newlines(self):
result = self.adapter.format_message("a\n\n\n\nb")
def test_collapses_newlines(self):
adapter = self._make_adapter()
result = adapter.format_message("a\n\n\n\nb")
assert result == "a\n\nb"
class TestSmsTruncateMessage:
def setup_method(self):
# ── Echo prevention ────────────────────────────────────────────────
class TestSmsEchoPrevention:
"""Adapter should ignore messages from its own number."""
def test_own_number_detection(self):
"""The adapter stores _from_number for echo prevention."""
from gateway.platforms.sms import SmsAdapter
config = PlatformConfig(enabled=True, api_key="test_key")
with patch.dict("os.environ", {"TELNYX_API_KEY": "test_key"}):
self.adapter = SmsAdapter(config)
def test_short_message_single_chunk(self):
msg = "Hello, world!"
chunks = self.adapter.truncate_message(msg)
assert len(chunks) == 1
assert chunks[0] == msg
def test_long_message_splits(self):
msg = "a " * 1000 # 2000 chars
chunks = self.adapter.truncate_message(msg)
assert len(chunks) >= 2
for chunk in chunks:
assert len(chunk) <= 1600
def test_custom_max_length(self):
msg = "Hello " * 20
chunks = self.adapter.truncate_message(msg, max_length=50)
assert all(len(c) <= 50 for c in chunks)
env = {
"TWILIO_ACCOUNT_SID": "ACtest",
"TWILIO_AUTH_TOKEN": "tok",
"TWILIO_PHONE_NUMBER": "+15550001111",
}
with patch.dict(os.environ, env):
pc = PlatformConfig(enabled=True, api_key="tok")
adapter = SmsAdapter(pc)
assert adapter._from_number == "+15550001111"
# ---------------------------------------------------------------------------
# Echo loop prevention
# ---------------------------------------------------------------------------
class TestSmsEchoLoop:
def test_own_number_ignored(self):
from gateway.platforms.sms import SmsAdapter
config = PlatformConfig(enabled=True, api_key="test_key")
with patch.dict("os.environ", {
"TELNYX_API_KEY": "test_key",
"TELNYX_FROM_NUMBERS": "+15551234567,+15559876543",
}):
adapter = SmsAdapter(config)
assert "+15551234567" in adapter._from_numbers
assert "+15559876543" in adapter._from_numbers
# ---------------------------------------------------------------------------
# Auth maps
# ---------------------------------------------------------------------------
class TestSmsAuthMaps:
def test_sms_in_allowed_users_map(self):
"""SMS should be in the platform auth maps in run.py."""
# Verify the env var names are consistent
import os
os.environ.setdefault("SMS_ALLOWED_USERS", "+15551234567")
assert os.getenv("SMS_ALLOWED_USERS") == "+15551234567"
def test_sms_allow_all_env_var(self):
"""SMS_ALLOW_ALL_USERS should be recognized."""
import os
os.environ.setdefault("SMS_ALLOW_ALL_USERS", "true")
assert os.getenv("SMS_ALLOW_ALL_USERS") == "true"
# ---------------------------------------------------------------------------
# Requirements check
# ---------------------------------------------------------------------------
# ── Requirements check ─────────────────────────────────────────────
class TestSmsRequirements:
def test_check_sms_requirements_with_key(self, monkeypatch):
monkeypatch.setenv("TELNYX_API_KEY", "KEY_test123")
def test_check_sms_requirements_missing_sid(self):
from gateway.platforms.sms import check_sms_requirements
# aiohttp is available in test environment
assert check_sms_requirements() is True
def test_check_sms_requirements_without_key(self, monkeypatch):
monkeypatch.delenv("TELNYX_API_KEY", raising=False)
env = {"TWILIO_AUTH_TOKEN": "tok"}
with patch.dict(os.environ, env, clear=True):
assert check_sms_requirements() is False
def test_check_sms_requirements_missing_token(self):
from gateway.platforms.sms import check_sms_requirements
assert check_sms_requirements() is False
env = {"TWILIO_ACCOUNT_SID": "ACtest"}
with patch.dict(os.environ, env, clear=True):
assert check_sms_requirements() is False
def test_check_sms_requirements_both_set(self):
from gateway.platforms.sms import check_sms_requirements
env = {
"TWILIO_ACCOUNT_SID": "ACtest",
"TWILIO_AUTH_TOKEN": "tok",
}
with patch.dict(os.environ, env, clear=False):
# Only returns True if aiohttp is also importable
result = check_sms_requirements()
try:
import aiohttp # noqa: F401
assert result is True
except ImportError:
assert result is False
# ---------------------------------------------------------------------------
# Toolset & integration points
# ---------------------------------------------------------------------------
# ── Toolset verification ───────────────────────────────────────────
class TestSmsToolset:
def test_hermes_sms_toolset_exists(self):
from toolsets import get_toolset
ts = get_toolset("hermes-sms")
assert ts is not None
assert "hermes-sms" in ts.get("description", "").lower() or "sms" in ts.get("description", "").lower()
assert "tools" in ts
def test_hermes_gateway_includes_sms(self):
def test_hermes_sms_in_gateway_includes(self):
from toolsets import get_toolset
gw = get_toolset("hermes-gateway")
assert gw is not None
assert "hermes-sms" in gw["includes"]
class TestSmsPlatformHints:
def test_sms_in_platform_hints(self):
def test_sms_platform_hint_exists(self):
from agent.prompt_builder import PLATFORM_HINTS
assert "sms" in PLATFORM_HINTS
assert "SMS" in PLATFORM_HINTS["sms"] or "sms" in PLATFORM_HINTS["sms"].lower()
assert "concise" in PLATFORM_HINTS["sms"].lower()
class TestSmsCronDelivery:
def test_sms_in_cron_platform_map(self):
"""Verify the cron scheduler can resolve 'sms' platform."""
# The platform_map in _deliver_result should include sms
from gateway.config import Platform
def test_sms_in_scheduler_platform_map(self):
"""Verify cron scheduler recognizes 'sms' as a valid platform."""
# Just check the Platform enum has SMS — the scheduler imports it dynamically
assert Platform.SMS.value == "sms"
class TestSmsSendMessageTool:
def test_sms_in_send_message_platform_map(self):
"""The send_message tool should recognize 'sms' as a valid platform."""
# We verify by checking that SMS is in the Platform enum
# and the code path exists
from gateway.config import Platform
"""Verify send_message_tool recognizes 'sms'."""
# The platform_map is built inside _handle_send; verify SMS enum exists
assert hasattr(Platform, "SMS")
class TestSmsChannelDirectory:
def test_sms_in_session_discovery(self):
"""Verify SMS is included in session-based channel discovery."""
import inspect
from gateway.channel_directory import build_channel_directory
source = inspect.getsource(build_channel_directory)
assert '"sms"' in source
class TestSmsStatus:
def test_sms_in_status_platforms(self):
"""Verify SMS appears in the status command platforms dict."""
import inspect
from hermes_cli.status import show_status
source = inspect.getsource(show_status)
assert '"SMS"' in source or "'SMS'" in source
def test_sms_in_cronjob_deliver_description(self):
"""Verify cronjob_tools mentions sms in deliver description."""
from tools.cronjob_tools import CRONJOB_SCHEMA
deliver_desc = CRONJOB_SCHEMA["parameters"]["properties"]["deliver"]["description"]
assert "sms" in deliver_desc.lower()

View file

@ -128,6 +128,13 @@ async def test_handle_message_persists_agent_token_counts(monkeypatch):
session_entry.session_key,
input_tokens=120,
output_tokens=45,
cache_read_tokens=0,
cache_write_tokens=0,
last_prompt_tokens=80,
model="openai/test-model",
estimated_cost_usd=None,
cost_status=None,
cost_source=None,
provider=None,
base_url=None,
)

View file

@ -316,6 +316,38 @@ class TestSanitizeEnvLines:
assert fixes == 0
class TestOptionalEnvVarsRegistry:
    """Verify that key env vars are registered in OPTIONAL_ENV_VARS."""

    @staticmethod
    def _tavily_entry():
        """Look up the TAVILY_API_KEY record in OPTIONAL_ENV_VARS."""
        from hermes_cli.config import OPTIONAL_ENV_VARS

        return OPTIONAL_ENV_VARS["TAVILY_API_KEY"]

    def test_tavily_api_key_registered(self):
        """TAVILY_API_KEY is listed in OPTIONAL_ENV_VARS."""
        from hermes_cli.config import OPTIONAL_ENV_VARS

        assert "TAVILY_API_KEY" in OPTIONAL_ENV_VARS

    def test_tavily_api_key_is_tool_category(self):
        """TAVILY_API_KEY is in the 'tool' category."""
        assert self._tavily_entry()["category"] == "tool"

    def test_tavily_api_key_is_password(self):
        """TAVILY_API_KEY is marked as password."""
        assert self._tavily_entry()["password"] is True

    def test_tavily_api_key_has_url(self):
        """TAVILY_API_KEY has a URL."""
        assert self._tavily_entry()["url"] == "https://app.tavily.com/home"

    def test_tavily_in_env_vars_by_version(self):
        """TAVILY_API_KEY is listed in ENV_VARS_BY_VERSION."""
        from hermes_cli.config import ENV_VARS_BY_VERSION

        # Flatten the per-version lists into a single list of names.
        flattened = [
            name
            for names in ENV_VARS_BY_VERSION.values()
            for name in names
        ]
        assert "TAVILY_API_KEY" in flattened
class TestAnthropicTokenMigration:
"""Test that config version 8→9 clears ANTHROPIC_TOKEN."""

View file

@ -0,0 +1,291 @@
"""Tests for MCP tools interactive configuration in hermes_cli.tools_config."""
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from hermes_cli.tools_config import _configure_mcp_tools_interactive
# Patch targets: imports happen inside the function body, so patch at source
# Dotted path patched in the tests below to stub the MCP server tool probe.
_PROBE = "tools.mcp_tool.probe_mcp_server_tools"
# Dotted path patched to stub the interactive checklist and fix its result.
_CHECKLIST = "hermes_cli.curses_ui.curses_checklist"
# Dotted path patched so the tests can observe whether config was saved.
_SAVE = "hermes_cli.tools_config.save_config"
def test_no_mcp_servers_prints_info(capsys):
    """An empty config produces the "No MCP servers configured" notice."""
    _configure_mcp_tools_interactive({})

    printed = capsys.readouterr().out
    assert "No MCP servers configured" in printed
def test_all_servers_disabled_prints_info(capsys):
    """A "disabled" notice is printed when every server has enabled=false.

    One server uses the boolean ``False`` and the other the string
    ``"false"``; both forms are expected to count as disabled.
    """
    servers = {
        "github": {"command": "npx", "enabled": False},
        "slack": {"command": "npx", "enabled": "false"},
    }

    _configure_mcp_tools_interactive({"mcp_servers": servers})

    printed = capsys.readouterr().out
    assert "disabled" in printed
def test_probe_failure_shows_warning(capsys):
    """An empty probe result leads to a "Could not discover" message."""
    cfg = {"mcp_servers": {"github": {"command": "npx"}}}

    # The probe reports no tools for any server.
    with patch(_PROBE, return_value={}):
        _configure_mcp_tools_interactive(cfg)

    printed = capsys.readouterr().out
    assert "Could not discover" in printed
def test_probe_exception_shows_error(capsys):
    """A probe that raises is reported with a "Failed to probe" message."""
    servers = {"github": {"command": "npx"}}
    failing_probe = patch(_PROBE, side_effect=RuntimeError("MCP not installed"))
    with failing_probe:
        _configure_mcp_tools_interactive({"mcp_servers": servers})
    printed = capsys.readouterr().out
    assert "Failed to probe" in printed
def test_no_changes_when_checklist_cancelled(capsys):
    """Cancel (ESC) path: nothing is saved and a "no changes" notice is printed."""
    github = {"command": "npx", "args": ["-y", "server-github"]}
    config = {"mcp_servers": {"github": github}}
    discovered = [
        ("create_issue", "Create an issue"),
        ("search_repos", "Search repos"),
    ]

    probe = patch(_PROBE, return_value={"github": discovered})
    # The stubbed checklist hands back both indices.
    checklist = patch(_CHECKLIST, return_value={0, 1})
    with probe, checklist, patch(_SAVE) as mock_save:
        _configure_mcp_tools_interactive(config)

    mock_save.assert_not_called()
    printed = capsys.readouterr().out
    assert "no changes" in printed.lower()
def test_disabling_tool_writes_exclude_list(capsys):
    """Leaving one tool unchecked records it under the server's "exclude" key."""
    config = {"mcp_servers": {"github": {"command": "npx"}}}
    discovered = [
        ("create_issue", "Create an issue"),
        ("delete_repo", "Delete a repo"),
        ("search_repos", "Search repos"),
    ]

    probe = patch(_PROBE, return_value={"github": discovered})
    # Indices 0 and 2 stay checked, so delete_repo (index 1) is the one unchecked.
    checklist = patch(_CHECKLIST, return_value={0, 2})
    with probe, checklist, patch(_SAVE) as mock_save:
        _configure_mcp_tools_interactive(config)

    mock_save.assert_called_once()
    saved_filters = config["mcp_servers"]["github"]["tools"]
    assert saved_filters["exclude"] == ["delete_repo"]
    assert "include" not in saved_filters
def test_enabling_all_clears_filters(capsys):
    """Selecting every tool removes both the include and the exclude list."""
    existing_filters = {"exclude": ["delete_repo"], "include": ["create_issue"]}
    config = {
        "mcp_servers": {
            "github": {"command": "npx", "tools": existing_filters},
        }
    }
    discovered = [("create_issue", "Create"), ("delete_repo", "Delete")]

    probe = patch(_PROBE, return_value={"github": discovered})
    # User checks all tools — pre_selected would be {0} (include mode),
    # so returning {0, 1} is a change.
    checklist = patch(_CHECKLIST, return_value={0, 1})
    with probe, checklist, patch(_SAVE) as mock_save:
        _configure_mcp_tools_interactive(config)

    mock_save.assert_called_once()
    saved_filters = config["mcp_servers"]["github"]["tools"]
    assert "exclude" not in saved_filters
    assert "include" not in saved_filters
def test_pre_selection_respects_existing_exclude(capsys):
    """Tools named in an existing exclude list are offered unchecked."""
    github = {"command": "npx", "tools": {"exclude": ["delete_repo"]}}
    config = {"mcp_servers": {"github": github}}
    discovered = [
        ("create_issue", "Create"),
        ("delete_repo", "Delete"),
        ("search", "Search"),
    ]
    seen = {}

    def fake_checklist(title, labels, pre_selected, **kwargs):
        # Snapshot the initial selection, then echo it back so nothing changes.
        seen["value"] = set(pre_selected)
        return pre_selected

    probe = patch(_PROBE, return_value={"github": discovered})
    with probe, patch(_CHECKLIST, side_effect=fake_checklist), patch(_SAVE):
        _configure_mcp_tools_interactive(config)

    # create_issue (0) and search (2) start checked; delete_repo (1) does not.
    assert seen["value"] == {0, 2}
def test_pre_selection_respects_existing_include(capsys):
    """With an include list present, only the listed tools are offered checked."""
    github = {"command": "npx", "tools": {"include": ["search"]}}
    config = {"mcp_servers": {"github": github}}
    discovered = [
        ("create_issue", "Create"),
        ("delete_repo", "Delete"),
        ("search", "Search"),
    ]
    seen = {}

    def fake_checklist(title, labels, pre_selected, **kwargs):
        # Snapshot the initial selection, then echo it back so nothing changes.
        seen["value"] = set(pre_selected)
        return pre_selected

    probe = patch(_PROBE, return_value={"github": discovered})
    with probe, patch(_CHECKLIST, side_effect=fake_checklist), patch(_SAVE):
        _configure_mcp_tools_interactive(config)

    # Only search (index 2) starts checked.
    assert seen["value"] == {2}
def test_multiple_servers_each_get_checklist(capsys):
    """Two configured servers result in two checklist calls, one titled per server."""
    config = {
        "mcp_servers": {
            "github": {"command": "npx"},
            "slack": {"url": "https://mcp.example.com"},
        }
    }
    probe_result = {
        "github": [("create_issue", "Create")],
        "slack": [("send_message", "Send")],
    }
    titles = []

    def fake_checklist(title, labels, pre_selected, **kwargs):
        # Record which checklist was shown and leave the selection untouched.
        titles.append(title)
        return pre_selected

    probe = patch(_PROBE, return_value=probe_result)
    with probe, patch(_CHECKLIST, side_effect=fake_checklist), patch(_SAVE):
        _configure_mcp_tools_interactive(config)

    assert len(titles) == 2
    for server_name in ("github", "slack"):
        assert any(server_name in title for title in titles)
def test_failed_server_shows_warning(capsys):
    """A configured server missing from the probe result is named in the output."""
    config = {
        "mcp_servers": {
            "github": {"command": "npx"},
            "broken": {"command": "nonexistent"},
        }
    }
    # The probe only reports tools for github; "broken" is absent.
    probe = patch(_PROBE, return_value={"github": [("create_issue", "Create")]})
    with probe, patch(_CHECKLIST, return_value={0}), patch(_SAVE):
        _configure_mcp_tools_interactive(config)

    printed = capsys.readouterr().out
    assert "broken" in printed
def test_description_truncation_in_labels():
    """Long descriptions are truncated in checklist labels.

    The checklist is stubbed so the labels passed to it can be inspected.
    """
    config = {
        "mcp_servers": {
            "github": {"command": "npx"},
        }
    }
    long_desc = "A" * 100
    captured_labels = {}

    def fake_checklist(title, labels, pre_selected, **kwargs):
        # Keep the labels for inspection and leave the selection unchanged.
        captured_labels["value"] = labels
        return pre_selected

    with patch(
        _PROBE, return_value={"github": [("my_tool", long_desc)]},
    ), patch(_CHECKLIST, side_effect=fake_checklist), \
            patch(_SAVE):
        _configure_mcp_tools_interactive(config)

    label = captured_labels["value"][0]
    assert "..." in label
    # The length bound below is also met by an untruncated label, so check
    # directly that the full description did not survive.
    assert long_desc not in label
    assert len(label) < len(long_desc) + 30  # truncated + tool name + parens
def test_switching_from_include_to_exclude(capsys):
    """A changed selection replaces an existing include list with an exclude list."""
    github = {"command": "npx", "tools": {"include": ["create_issue"]}}
    config = {"mcp_servers": {"github": github}}
    discovered = [
        ("create_issue", "Create"),
        ("search", "Search"),
        ("delete", "Delete"),
    ]

    probe = patch(_PROBE, return_value={"github": discovered})
    # User selects create_issue and search (deselects delete).
    # pre_selected would be {0} (only create_issue from include), so {0, 1} is a change.
    checklist = patch(_CHECKLIST, return_value={0, 1})
    with probe, checklist, patch(_SAVE):
        _configure_mcp_tools_interactive(config)

    saved_filters = config["mcp_servers"]["github"]["tools"]
    assert saved_filters["exclude"] == ["delete"]
    assert "include" not in saved_filters
def test_empty_tools_server_skipped(capsys):
    """A server whose probe returns no tools gets a notice and no checklist."""
    config = {"mcp_servers": {"empty": {"command": "npx"}}}
    titles = []

    def fake_checklist(title, labels, pre_selected, **kwargs):
        # Any call is recorded so the test can assert none happened.
        titles.append(title)
        return pre_selected

    probe = patch(_PROBE, return_value={"empty": []})
    with probe, patch(_CHECKLIST, side_effect=fake_checklist), patch(_SAVE):
        _configure_mcp_tools_interactive(config)

    assert len(titles) == 0
    printed = capsys.readouterr().out
    assert "no tools found" in printed

View file

@ -5,6 +5,13 @@ from hermes_cli.config import load_config, save_config
from hermes_cli.setup import setup_model_provider
def _maybe_keep_current_tts(question, choices):
if question != "Select TTS provider:":
return None
assert choices[-1].startswith("Keep current (")
return len(choices) - 1
def _clear_provider_env(monkeypatch):
for key in (
"NOUS_API_KEY",
@ -25,16 +32,22 @@ def test_nous_oauth_setup_keeps_current_model_when_syncing_disk_provider(
config = load_config()
# Provider selection always comes first. Depending on available vision
# backends, setup may either skip the optional vision step or prompt for
# it before the default-model choice. Provide enough selections for both
# paths while still ending on "keep current model".
prompt_choices = iter([0, 2, 2])
monkeypatch.setattr(
"hermes_cli.setup.prompt_choice",
lambda *args, **kwargs: next(prompt_choices),
)
def fake_prompt_choice(question, choices, default=0):
if question == "Select your inference provider:":
return 0
if question == "Configure vision:":
return len(choices) - 1
if question == "Select default model:":
assert choices[-1] == "Keep current (anthropic/claude-opus-4.6)"
return len(choices) - 1
tts_idx = _maybe_keep_current_tts(question, choices)
if tts_idx is not None:
return tts_idx
raise AssertionError(f"Unexpected prompt_choice call: {question}")
monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice)
monkeypatch.setattr("hermes_cli.setup.prompt", lambda *args, **kwargs: "")
monkeypatch.setattr("hermes_cli.auth.detect_external_credentials", lambda: [])
def _fake_login_nous(*args, **kwargs):
auth_path = tmp_path / "auth.json"
@ -53,7 +66,6 @@ def test_nous_oauth_setup_keeps_current_model_when_syncing_disk_provider(
"hermes_cli.auth.fetch_nous_models",
lambda *args, **kwargs: ["gemini-3-flash"],
)
monkeypatch.setattr("hermes_cli.setup._setup_tts_provider", lambda config: None)
setup_model_provider(config)
save_config(config)
@ -75,21 +87,29 @@ def test_custom_setup_clears_active_oauth_provider(tmp_path, monkeypatch):
config = load_config()
monkeypatch.setattr("hermes_cli.setup.prompt_choice", lambda *args, **kwargs: 3)
def fake_prompt_choice(question, choices, default=0):
if question == "Select your inference provider:":
return 3
tts_idx = _maybe_keep_current_tts(question, choices)
if tts_idx is not None:
return tts_idx
raise AssertionError(f"Unexpected prompt_choice call: {question}")
monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice)
prompt_values = iter(
[
"https://custom.example/v1",
"custom-api-key",
"custom/model",
"",
]
)
monkeypatch.setattr(
"hermes_cli.setup.prompt",
lambda *args, **kwargs: next(prompt_values),
)
monkeypatch.setattr("hermes_cli.setup._setup_tts_provider", lambda config: None)
monkeypatch.setattr("hermes_cli.setup.prompt_yes_no", lambda *args, **kwargs: False)
monkeypatch.setattr("hermes_cli.auth.detect_external_credentials", lambda: [])
setup_model_provider(config)
save_config(config)
@ -111,11 +131,17 @@ def test_codex_setup_uses_runtime_access_token_for_live_model_list(tmp_path, mon
config = load_config()
prompt_choices = iter([1, 0])
monkeypatch.setattr(
"hermes_cli.setup.prompt_choice",
lambda *args, **kwargs: next(prompt_choices),
)
def fake_prompt_choice(question, choices, default=0):
if question == "Select your inference provider:":
return 1
if question == "Select default model:":
return 0
tts_idx = _maybe_keep_current_tts(question, choices)
if tts_idx is not None:
return tts_idx
raise AssertionError(f"Unexpected prompt_choice call: {question}")
monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice)
monkeypatch.setattr("hermes_cli.setup.prompt", lambda *args, **kwargs: "")
monkeypatch.setattr("hermes_cli.auth.detect_external_credentials", lambda: [])
monkeypatch.setattr("hermes_cli.auth._login_openai_codex", lambda *args, **kwargs: None)
@ -137,7 +163,6 @@ def test_codex_setup_uses_runtime_access_token_for_live_model_list(tmp_path, mon
"hermes_cli.codex_models.get_codex_model_ids",
_fake_get_codex_model_ids,
)
monkeypatch.setattr("hermes_cli.setup._setup_tts_provider", lambda config: None)
setup_model_provider(config)
save_config(config)

View file

@ -6,6 +6,13 @@ from hermes_cli.config import load_config, save_config, save_env_value
from hermes_cli.setup import _print_setup_summary, setup_model_provider
def _maybe_keep_current_tts(question, choices):
if question != "Select TTS provider:":
return None
assert choices[-1].startswith("Keep current (")
return len(choices) - 1
def _read_env(home):
env_path = home / ".env"
data = {}
@ -50,19 +57,18 @@ def test_setup_keep_current_custom_from_config_does_not_fall_through(tmp_path, m
}
save_config(config)
calls = {"count": 0}
def fake_prompt_choice(question, choices, default=0):
calls["count"] += 1
if calls["count"] == 1:
if question == "Select your inference provider:":
assert choices[-1] == "Keep current (Custom: https://example.invalid/v1)"
return len(choices) - 1
tts_idx = _maybe_keep_current_tts(question, choices)
if tts_idx is not None:
return tts_idx
raise AssertionError("Model menu should not appear for keep-current custom")
monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice)
monkeypatch.setattr("hermes_cli.setup.prompt", lambda *args, **kwargs: "")
monkeypatch.setattr("hermes_cli.setup.prompt_yes_no", lambda *args, **kwargs: False)
monkeypatch.setattr("hermes_cli.setup._setup_tts_provider", lambda config: None)
monkeypatch.setattr("hermes_cli.auth.get_active_provider", lambda: None)
monkeypatch.setattr("hermes_cli.auth.detect_external_credentials", lambda: [])
@ -73,7 +79,6 @@ def test_setup_keep_current_custom_from_config_does_not_fall_through(tmp_path, m
assert reloaded["model"]["provider"] == "custom"
assert reloaded["model"]["default"] == "custom/model"
assert reloaded["model"]["base_url"] == "https://example.invalid/v1"
assert calls["count"] == 1
def test_setup_custom_endpoint_saves_working_v1_base_url(tmp_path, monkeypatch):
@ -87,8 +92,9 @@ def test_setup_custom_endpoint_saves_working_v1_base_url(tmp_path, monkeypatch):
return 3 # Custom endpoint
if question == "Configure vision:":
return len(choices) - 1 # Skip
if question == "Select TTS provider:":
return len(choices) - 1 # Keep current
tts_idx = _maybe_keep_current_tts(question, choices)
if tts_idx is not None:
return tts_idx
raise AssertionError(f"Unexpected prompt_choice call: {question}")
def fake_prompt(message, current=None, **kwargs):
@ -103,7 +109,6 @@ def test_setup_custom_endpoint_saves_working_v1_base_url(tmp_path, monkeypatch):
monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice)
monkeypatch.setattr("hermes_cli.setup.prompt", fake_prompt)
monkeypatch.setattr("hermes_cli.setup.prompt_yes_no", lambda *args, **kwargs: False)
monkeypatch.setattr("hermes_cli.setup._setup_tts_provider", lambda config: None)
monkeypatch.setattr("hermes_cli.auth.get_active_provider", lambda: None)
monkeypatch.setattr("hermes_cli.auth.detect_external_credentials", lambda: [])
monkeypatch.setattr("agent.auxiliary_client.get_available_vision_backends", lambda: [])
@ -144,25 +149,23 @@ def test_setup_keep_current_config_provider_uses_provider_specific_model_menu(tm
save_config(config)
captured = {"provider_choices": None, "model_choices": None}
calls = {"count": 0}
def fake_prompt_choice(question, choices, default=0):
calls["count"] += 1
if calls["count"] == 1:
if question == "Select your inference provider:":
captured["provider_choices"] = list(choices)
assert choices[-1] == "Keep current (Anthropic)"
return len(choices) - 1
if calls["count"] == 2:
if question == "Configure vision:":
assert question == "Configure vision:"
assert choices[-1] == "Skip for now"
return len(choices) - 1
if calls["count"] == 3:
if question == "Select default model:":
captured["model_choices"] = list(choices)
return len(choices) - 1 # keep current model
if calls["count"] == 4:
assert question == "Select TTS provider:"
return len(choices) - 1 # Keep current
raise AssertionError("Unexpected extra prompt_choice call")
tts_idx = _maybe_keep_current_tts(question, choices)
if tts_idx is not None:
return tts_idx
raise AssertionError(f"Unexpected prompt_choice call: {question}")
monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice)
monkeypatch.setattr("hermes_cli.setup.prompt", lambda *args, **kwargs: "")
@ -179,7 +182,6 @@ def test_setup_keep_current_config_provider_uses_provider_specific_model_menu(tm
assert captured["model_choices"] is not None
assert captured["model_choices"][0] == "claude-opus-4-6"
assert "anthropic/claude-opus-4.6 (recommended)" not in captured["model_choices"]
assert calls["count"] == 4 # provider, vision, model, TTS
def test_setup_keep_current_anthropic_can_configure_openai_vision_default(tmp_path, monkeypatch):
@ -193,15 +195,24 @@ def test_setup_keep_current_anthropic_can_configure_openai_vision_default(tmp_pa
}
save_config(config)
picks = iter([
10, # keep current provider (shifted +1 by kilocode insertion)
1, # configure vision with OpenAI
5, # use default gpt-4o-mini vision model
4, # keep current Anthropic model
4, # TTS: Keep current
])
def fake_prompt_choice(question, choices, default=0):
if question == "Select your inference provider:":
assert choices[-1] == "Keep current (Anthropic)"
return len(choices) - 1
if question == "Configure vision:":
return 1
if question == "Select vision model:":
assert choices[-1] == "Use default (gpt-4o-mini)"
return len(choices) - 1
if question == "Select default model:":
assert choices[-1] == "Keep current (claude-opus-4-6)"
return len(choices) - 1
tts_idx = _maybe_keep_current_tts(question, choices)
if tts_idx is not None:
return tts_idx
raise AssertionError(f"Unexpected prompt_choice call: {question}")
monkeypatch.setattr("hermes_cli.setup.prompt_choice", lambda *args, **kwargs: next(picks))
monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice)
monkeypatch.setattr(
"hermes_cli.setup.prompt",
lambda message, *args, **kwargs: "sk-openai" if "OpenAI API key" in message else "",
@ -237,8 +248,17 @@ def test_setup_switch_custom_to_codex_clears_custom_endpoint_and_updates_config(
}
save_config(config)
picks = iter([1, 0, 4]) # provider, model; 4 = TTS Keep current
monkeypatch.setattr("hermes_cli.setup.prompt_choice", lambda *args, **kwargs: next(picks))
def fake_prompt_choice(question, choices, default=0):
if question == "Select your inference provider:":
return 1
if question == "Select default model:":
return 0
tts_idx = _maybe_keep_current_tts(question, choices)
if tts_idx is not None:
return tts_idx
raise AssertionError(f"Unexpected prompt_choice call: {question}")
monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice)
monkeypatch.setattr("hermes_cli.setup.prompt", lambda *args, **kwargs: "")
monkeypatch.setattr("hermes_cli.setup.prompt_yes_no", lambda *args, **kwargs: False)
monkeypatch.setattr("hermes_cli.auth.get_active_provider", lambda: None)

View file

@ -0,0 +1,14 @@
from types import SimpleNamespace
from hermes_cli.status import show_status
def test_show_status_includes_tavily_key(monkeypatch, capsys, tmp_path):
    """show_status prints a Tavily entry with the key shown as "tvly...cdef"."""
    environment = {
        "HERMES_HOME": str(tmp_path),
        "TAVILY_API_KEY": "tvly-1234567890abcdef",
    }
    for name, value in environment.items():
        monkeypatch.setenv(name, value)

    show_status(SimpleNamespace(all=False, deep=False))

    printed = capsys.readouterr().out
    assert "Tavily" in printed
    assert "tvly...cdef" in printed

View file

@ -4,6 +4,7 @@ from types import SimpleNamespace
import pytest
from hermes_cli import config as hermes_config
from hermes_cli import main as hermes_main
@ -235,3 +236,82 @@ def test_stash_local_changes_if_needed_raises_when_stash_ref_missing(monkeypatch
with pytest.raises(CalledProcessError):
hermes_main._stash_local_changes_if_needed(["git"], Path(tmp_path))
# ---------------------------------------------------------------------------
# Update uses .[all] with fallback to .
# ---------------------------------------------------------------------------
def _setup_update_mocks(monkeypatch, tmp_path):
    """Shared fixture work for the cmd_update tests.

    Creates a ``.git`` directory under ``tmp_path``, points
    ``hermes_main.PROJECT_ROOT`` at it, and replaces the stash helpers and the
    config checks with fixed-value stubs.
    """
    (tmp_path / ".git").mkdir()

    main_stubs = {
        "PROJECT_ROOT": tmp_path,
        "_stash_local_changes_if_needed": lambda *a, **kw: None,
        "_restore_stashed_changes": lambda *a, **kw: True,
    }
    config_stubs = {
        "get_missing_env_vars": lambda required_only=True: [],
        "get_missing_config_fields": lambda: [],
        "check_config_version": lambda: (5, 5),
        "migrate_config": lambda **kw: {"env_added": [], "config_added": []},
    }
    for module, stubs in ((hermes_main, main_stubs), (hermes_config, config_stubs)):
        for attribute, replacement in stubs.items():
            monkeypatch.setattr(module, attribute, replacement)
def test_cmd_update_tries_extras_first_then_falls_back(monkeypatch, tmp_path):
    """A failing ``.[all]`` install is followed by a plain ``.`` install."""
    _setup_update_mocks(monkeypatch, tmp_path)
    monkeypatch.setattr("shutil.which", lambda name: "/usr/bin/uv" if name == "uv" else None)

    # Canned stdout for each git command the update is expected to issue.
    git_replies = (
        (["git", "fetch", "origin"], ""),
        (["git", "rev-parse", "--abbrev-ref", "HEAD"], "main\n"),
        (["git", "rev-list", "HEAD..origin/main", "--count"], "1\n"),
        (["git", "pull", "origin", "main"], "Updating\n"),
    )
    recorded = []

    def fake_run(cmd, **kwargs):
        recorded.append(cmd)
        for expected_cmd, stdout in git_replies:
            if cmd == expected_cmd:
                return SimpleNamespace(stdout=stdout, stderr="", returncode=0)
        if ".[all]" in cmd:
            # Simulate the extras install failing.
            raise CalledProcessError(returncode=1, cmd=cmd)
        # Everything else, including the bare "." install, succeeds.
        return SimpleNamespace(returncode=0)

    monkeypatch.setattr(hermes_main.subprocess, "run", fake_run)
    hermes_main.cmd_update(SimpleNamespace())

    installs = [cmd for cmd in recorded if "pip" in cmd and "install" in cmd]
    assert len(installs) == 2
    first_attempt, fallback = installs
    assert ".[all]" in first_attempt
    assert "." in fallback and ".[all]" not in fallback
def test_cmd_update_succeeds_with_extras(monkeypatch, tmp_path):
    """A successful ``.[all]`` install is the only install command issued."""
    _setup_update_mocks(monkeypatch, tmp_path)
    monkeypatch.setattr("shutil.which", lambda name: "/usr/bin/uv" if name == "uv" else None)

    # Canned stdout for each git command the update is expected to issue.
    git_replies = (
        (["git", "fetch", "origin"], ""),
        (["git", "rev-parse", "--abbrev-ref", "HEAD"], "main\n"),
        (["git", "rev-list", "HEAD..origin/main", "--count"], "1\n"),
        (["git", "pull", "origin", "main"], "Updating\n"),
    )
    recorded = []

    def fake_run(cmd, **kwargs):
        recorded.append(cmd)
        for expected_cmd, stdout in git_replies:
            if cmd == expected_cmd:
                return SimpleNamespace(stdout=stdout, stderr="", returncode=0)
        # Every non-git command, the install included, succeeds.
        return SimpleNamespace(returncode=0)

    monkeypatch.setattr(hermes_main.subprocess, "run", fake_run)
    hermes_main.cmd_update(SimpleNamespace())

    installs = [cmd for cmd in recorded if "pip" in cmd and "install" in cmd]
    assert len(installs) == 1
    assert ".[all]" in installs[0]

View file

@ -63,11 +63,13 @@ class TestFromEnv:
class TestFromGlobalConfig:
def test_missing_config_falls_back_to_env(self, tmp_path):
config = HonchoClientConfig.from_global_config(
config_path=tmp_path / "nonexistent.json"
)
with patch.dict(os.environ, {}, clear=True):
config = HonchoClientConfig.from_global_config(
config_path=tmp_path / "nonexistent.json"
)
# Should fall back to from_env
assert config.enabled is True or config.api_key is None # depends on env
assert config.enabled is False
assert config.api_key is None
def test_reads_full_config(self, tmp_path):
config_file = tmp_path / "config.json"

View file

@ -3,7 +3,7 @@
Comprehensive Test Suite for Web Tools Module
This script tests all web tools functionality to ensure they work correctly.
Run this after any updates to the web_tools.py module or Firecrawl library.
Run this after any updates to the web_tools.py module or backend libraries.
Usage:
python test_web_tools.py # Run all tests
@ -11,7 +11,7 @@ Usage:
python test_web_tools.py --verbose # Show detailed output
Requirements:
- FIRECRAWL_API_KEY environment variable must be set
- PARALLEL_API_KEY or FIRECRAWL_API_KEY environment variable must be set
- An auxiliary LLM provider (OPENROUTER_API_KEY or Nous Portal auth) (optional, for LLM tests)
"""
@ -28,12 +28,14 @@ from typing import List
# Import the web tools to test (updated path after moving tools/)
from tools.web_tools import (
web_search_tool,
web_extract_tool,
web_search_tool,
web_extract_tool,
web_crawl_tool,
check_firecrawl_api_key,
check_web_api_key,
check_auxiliary_model,
get_debug_session_info
get_debug_session_info,
_get_backend,
)
@ -121,12 +123,13 @@ class WebToolsTester:
"""Test environment setup and API keys"""
print_section("Environment Check")
# Check Firecrawl API key
if not check_firecrawl_api_key():
self.log_result("Firecrawl API Key", "failed", "FIRECRAWL_API_KEY not set")
# Check web backend API key (Parallel or Firecrawl)
if not check_web_api_key():
self.log_result("Web Backend API Key", "failed", "PARALLEL_API_KEY or FIRECRAWL_API_KEY not set")
return False
else:
self.log_result("Firecrawl API Key", "passed", "Found")
backend = _get_backend()
self.log_result("Web Backend API Key", "passed", f"Using {backend} backend")
# Check auxiliary LLM provider (optional)
if not check_auxiliary_model():
@ -578,7 +581,9 @@ class WebToolsTester:
},
"results": self.test_results,
"environment": {
"web_backend": _get_backend() if check_web_api_key() else None,
"firecrawl_api_key": check_firecrawl_api_key(),
"parallel_api_key": bool(os.getenv("PARALLEL_API_KEY")),
"auxiliary_model": check_auxiliary_model(),
"debug_mode": get_debug_session_info()["enabled"]
}

View file

@ -24,6 +24,7 @@ def main() -> int:
parent._interrupt_requested = False
parent._interrupt_message = None
parent._active_children = []
parent._active_children_lock = threading.Lock()
parent.quiet_mode = True
parent.model = "test/model"
parent.base_url = "http://localhost:1"

View file

@ -0,0 +1,263 @@
"""Unit tests for AIAgent pre/post-LLM-call guardrails.
Covers three static methods on AIAgent (inspired by PR #1321 — @alireza78a):
- _sanitize_api_messages() Phase 1: orphaned tool pair repair
- _cap_delegate_task_calls() Phase 2a: subagent concurrency limit
- _deduplicate_tool_calls() Phase 2b: identical call deduplication
"""
import types
from run_agent import AIAgent
from tools.delegate_tool import MAX_CONCURRENT_CHILDREN
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def make_tc(name: str, arguments: str = "{}") -> types.SimpleNamespace:
    """Build a bare tool_call stand-in exposing ``.function.name`` and ``.function.arguments``."""
    function = types.SimpleNamespace(name=name, arguments=arguments)
    return types.SimpleNamespace(function=function)
def tool_result(call_id: str, content: str = "ok") -> dict:
    """Build a ``role: tool`` message answering the tool call ``call_id``."""
    message = dict(role="tool", tool_call_id=call_id, content=content)
    return message
def assistant_dict_call(call_id: str, name: str = "terminal") -> dict:
    """Build a tool_call in plain-dict form (the shape kept in message history)."""
    function = {"name": name, "arguments": "{}"}
    return {"id": call_id, "function": function}
# ---------------------------------------------------------------------------
# Phase 1 — _sanitize_api_messages
# ---------------------------------------------------------------------------
class TestSanitizeApiMessages:
    """Checks for AIAgent._sanitize_api_messages on paired and unpaired tool messages."""

    def test_orphaned_result_removed(self):
        """A tool result whose call id matches no assistant call is dropped."""
        history = [
            {"role": "assistant", "tool_calls": [assistant_dict_call("c1")]},
            tool_result("c1"),
            tool_result("c_ORPHAN"),
        ]
        cleaned = AIAgent._sanitize_api_messages(history)
        assert len(cleaned) == 2
        assert "c_ORPHAN" not in [m.get("tool_call_id") for m in cleaned]

    def test_orphaned_call_gets_stub_result(self):
        """An assistant call with no result gets a non-empty stub tool message."""
        history = [
            {"role": "assistant", "tool_calls": [assistant_dict_call("c2")]},
        ]
        cleaned = AIAgent._sanitize_api_messages(history)
        assert len(cleaned) == 2
        stub = cleaned[1]
        assert stub["role"] == "tool"
        assert stub["tool_call_id"] == "c2"
        assert stub["content"]

    def test_clean_messages_pass_through(self):
        """A history with matched call/result pairs compares equal afterwards."""
        history = [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "tool_calls": [assistant_dict_call("c3")]},
            tool_result("c3"),
            {"role": "assistant", "content": "done"},
        ]
        assert AIAgent._sanitize_api_messages(history) == history

    def test_mixed_orphaned_result_and_orphaned_call(self):
        """A dangling result is removed and an unanswered call gains a result."""
        two_calls = {
            "role": "assistant",
            "tool_calls": [assistant_dict_call("c4"), assistant_dict_call("c5")],
        }
        history = [two_calls, tool_result("c4"), tool_result("c_DANGLING")]
        cleaned = AIAgent._sanitize_api_messages(history)
        result_ids = [
            m.get("tool_call_id") for m in cleaned if m.get("role") == "tool"
        ]
        assert "c_DANGLING" not in result_ids
        assert "c4" in result_ids
        assert "c5" in result_ids

    def test_empty_list_is_safe(self):
        """An empty history yields an empty list."""
        assert AIAgent._sanitize_api_messages([]) == []

    def test_no_tool_messages(self):
        """A history without tool calls or results compares equal afterwards."""
        history = [
            {"role": "user", "content": "hi"},
            {"role": "assistant", "content": "hello"},
        ]
        assert AIAgent._sanitize_api_messages(history) == history

    def test_sdk_object_tool_calls(self):
        """Attribute-style tool calls are handled like dict-style ones."""
        function = types.SimpleNamespace(name="terminal", arguments="{}")
        sdk_call = types.SimpleNamespace(id="c6", function=function)
        history = [
            {"role": "assistant", "tool_calls": [sdk_call]},
        ]
        cleaned = AIAgent._sanitize_api_messages(history)
        assert len(cleaned) == 2
        assert cleaned[1]["tool_call_id"] == "c6"
# ---------------------------------------------------------------------------
# Phase 2a — _cap_delegate_task_calls
# ---------------------------------------------------------------------------
class TestCapDelegateTaskCalls:
    """Checks for AIAgent._cap_delegate_task_calls against MAX_CONCURRENT_CHILDREN."""

    def test_excess_delegates_truncated(self):
        """More delegate_task calls than the limit are cut down to the limit."""
        calls = [make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN + 2)]
        capped = AIAgent._cap_delegate_task_calls(calls)
        kept = [tc for tc in capped if tc.function.name == "delegate_task"]
        assert len(kept) == MAX_CONCURRENT_CHILDREN

    def test_non_delegate_calls_preserved(self):
        """Other tools survive even when delegate calls exceed the limit."""
        delegates = [make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN + 1)]
        others = [make_tc("terminal"), make_tc("web_search")]
        capped = AIAgent._cap_delegate_task_calls(delegates + others)
        kept_names = [tc.function.name for tc in capped]
        assert "terminal" in kept_names
        assert "web_search" in kept_names

    def test_at_limit_passes_through(self):
        """Exactly the limit returns the very same list object."""
        calls = [make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN)]
        assert AIAgent._cap_delegate_task_calls(calls) is calls

    def test_below_limit_passes_through(self):
        """One under the limit returns the very same list object."""
        calls = [make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN - 1)]
        assert AIAgent._cap_delegate_task_calls(calls) is calls

    def test_no_delegate_calls_unchanged(self):
        """A list without delegate_task calls is returned as the same object."""
        calls = [make_tc("terminal"), make_tc("web_search")]
        assert AIAgent._cap_delegate_task_calls(calls) is calls

    def test_empty_list_safe(self):
        """An empty list yields an empty list."""
        assert AIAgent._cap_delegate_task_calls([]) == []

    def test_original_list_not_mutated(self):
        """Capping leaves the length of the input list unchanged."""
        calls = [make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN + 2)]
        length_before = len(calls)
        AIAgent._cap_delegate_task_calls(calls)
        assert len(calls) == length_before

    def test_interleaved_order_preserved(self):
        """Surviving calls keep their relative order and identity."""
        delegates = [
            make_tc("delegate_task", f'{{"task":"{i}"}}')
            for i in range(MAX_CONCURRENT_CHILDREN + 1)
        ]
        terminal_call = make_tc("terminal", '{"cmd":"ls"}')
        search_call = make_tc("web_search", '{"q":"x"}')
        head = [delegates[0], terminal_call, delegates[1], search_call]

        capped = AIAgent._cap_delegate_task_calls(head + delegates[2:])

        wanted = head + delegates[2:MAX_CONCURRENT_CHILDREN]
        assert len(capped) == len(wanted)
        for position, (got, want) in enumerate(zip(capped, wanted)):
            assert got is want, f"mismatch at index {position}"
# ---------------------------------------------------------------------------
# Phase 2b — _deduplicate_tool_calls
# ---------------------------------------------------------------------------
class TestDeduplicateToolCalls:
    """Checks for AIAgent._deduplicate_tool_calls on repeated (name, arguments) pairs."""

    def test_duplicate_pair_deduplicated(self):
        """Two identical calls collapse to one."""
        call_args = ("web_search", '{"query":"foo"}')
        calls = [make_tc(*call_args), make_tc(*call_args)]
        unique = AIAgent._deduplicate_tool_calls(calls)
        assert len(unique) == 1

    def test_multiple_duplicates(self):
        """Each repeated pair collapses independently; distinct calls stay."""
        calls = [
            make_tc("web_search", '{"q":"a"}'),
            make_tc("web_search", '{"q":"a"}'),
            make_tc("terminal", '{"cmd":"ls"}'),
            make_tc("terminal", '{"cmd":"ls"}'),
            make_tc("terminal", '{"cmd":"pwd"}'),
        ]
        unique = AIAgent._deduplicate_tool_calls(calls)
        assert len(unique) == 3

    def test_same_tool_different_args_kept(self):
        """Same tool with different arguments returns the same list object."""
        calls = [
            make_tc("terminal", '{"cmd":"ls"}'),
            make_tc("terminal", '{"cmd":"pwd"}'),
        ]
        assert AIAgent._deduplicate_tool_calls(calls) is calls

    def test_different_tools_same_args_kept(self):
        """Different tools with equal arguments return the same list object."""
        calls = [
            make_tc("tool_a", '{"x":1}'),
            make_tc("tool_b", '{"x":1}'),
        ]
        assert AIAgent._deduplicate_tool_calls(calls) is calls

    def test_clean_list_unchanged(self):
        """A list with no duplicates is returned as the same object."""
        calls = [
            make_tc("web_search", '{"q":"x"}'),
            make_tc("terminal", '{"cmd":"ls"}'),
        ]
        assert AIAgent._deduplicate_tool_calls(calls) is calls

    def test_empty_list_safe(self):
        """An empty list yields an empty list."""
        assert AIAgent._deduplicate_tool_calls([]) == []

    def test_first_occurrence_kept(self):
        """Of two identical calls, the first object is the survivor."""
        first = make_tc("terminal", '{"cmd":"ls"}')
        second = make_tc("terminal", '{"cmd":"ls"}')
        unique = AIAgent._deduplicate_tool_calls([first, second])
        assert len(unique) == 1
        assert unique[0] is first

    def test_original_list_not_mutated(self):
        """Deduplication leaves the length of the input list unchanged."""
        calls = [
            make_tc("web_search", '{"q":"dup"}'),
            make_tc("web_search", '{"q":"dup"}'),
        ]
        length_before = len(calls)
        AIAgent._deduplicate_tool_calls(calls)
        assert len(calls) == length_before
# ---------------------------------------------------------------------------
# _get_tool_call_id_static
# ---------------------------------------------------------------------------
class TestGetToolCallIdStatic:
    """Checks ``AIAgent._get_tool_call_id_static`` for dict and attribute-style tool calls."""

    def test_dict_with_valid_id(self):
        """A dict carrying a string id yields that id."""
        payload = {"id": "call_123"}
        assert AIAgent._get_tool_call_id_static(payload) == "call_123"

    def test_dict_with_none_id(self):
        """A dict whose id is None yields the empty string."""
        payload = {"id": None}
        assert AIAgent._get_tool_call_id_static(payload) == ""

    def test_dict_without_id_key(self):
        """A dict lacking the id key yields the empty string."""
        payload = {"function": {}}
        assert AIAgent._get_tool_call_id_static(payload) == ""

    def test_object_with_valid_id(self):
        """An object exposing a string ``id`` attribute yields that id."""
        tool_call = types.SimpleNamespace(id="call_456")
        resolved = AIAgent._get_tool_call_id_static(tool_call)
        assert resolved == "call_456"

    def test_object_with_none_id(self):
        """An object whose ``id`` attribute is None yields the empty string."""
        tool_call = types.SimpleNamespace(id=None)
        resolved = AIAgent._get_tool_call_id_static(tool_call)
        assert resolved == ""

    def test_object_without_id_attr(self):
        """An object with no ``id`` attribute yields the empty string."""
        tool_call = types.SimpleNamespace()
        resolved = AIAgent._get_tool_call_id_static(tool_call)
        assert resolved == ""

View file

@ -98,11 +98,14 @@ class TestProviderRegistry:
# =============================================================================
PROVIDER_ENV_VARS = (
"OPENROUTER_API_KEY", "OPENAI_API_KEY", "ANTHROPIC_API_KEY",
"OPENROUTER_API_KEY", "OPENAI_API_KEY", "ANTHROPIC_API_KEY", "ANTHROPIC_TOKEN",
"CLAUDE_CODE_OAUTH_TOKEN",
"GLM_API_KEY", "ZAI_API_KEY", "Z_AI_API_KEY",
"KIMI_API_KEY", "KIMI_BASE_URL", "MINIMAX_API_KEY", "MINIMAX_CN_API_KEY",
"AI_GATEWAY_API_KEY", "AI_GATEWAY_BASE_URL",
"KILOCODE_API_KEY", "KILOCODE_BASE_URL",
"DASHSCOPE_API_KEY", "OPENCODE_ZEN_API_KEY", "OPENCODE_GO_API_KEY",
"NOUS_API_KEY",
"OPENAI_BASE_URL",
)
@ -111,6 +114,7 @@ PROVIDER_ENV_VARS = (
def _clear_provider_env(monkeypatch):
for key in PROVIDER_ENV_VARS:
monkeypatch.delenv(key, raising=False)
monkeypatch.setattr("hermes_cli.auth._load_auth_store", lambda: {})
class TestResolveProvider:

View file

@ -43,6 +43,7 @@ class TestCLISubagentInterrupt(unittest.TestCase):
parent._interrupt_requested = False
parent._interrupt_message = None
parent._active_children = []
parent._active_children_lock = threading.Lock()
parent.quiet_mode = True
parent.model = "test/model"
parent.base_url = "http://localhost:1"
@ -112,21 +113,21 @@ class TestCLISubagentInterrupt(unittest.TestCase):
mock_instance._interrupt_requested = False
mock_instance._interrupt_message = None
mock_instance._active_children = []
mock_instance._active_children_lock = threading.Lock()
mock_instance.quiet_mode = True
mock_instance.run_conversation = mock_child_run_conversation
mock_instance.interrupt = lambda msg=None: setattr(mock_instance, '_interrupt_requested', True) or setattr(mock_instance, '_interrupt_message', msg)
mock_instance.tools = []
MockAgent.return_value = mock_instance
# Register child manually (normally done by _build_child_agent)
parent._active_children.append(mock_instance)
result = _run_single_child(
task_index=0,
goal="Do something slow",
context=None,
toolsets=["terminal"],
model=None,
max_iterations=50,
child=mock_instance,
parent_agent=parent,
task_count=1,
)
delegate_result[0] = result
except Exception as e:

View file

@ -16,6 +16,10 @@ def _make_cli(model: str = "anthropic/claude-sonnet-4-20250514"):
def _attach_agent(
cli_obj,
*,
input_tokens: int | None = None,
output_tokens: int | None = None,
cache_read_tokens: int = 0,
cache_write_tokens: int = 0,
prompt_tokens: int,
completion_tokens: int,
total_tokens: int,
@ -26,6 +30,12 @@ def _attach_agent(
):
cli_obj.agent = SimpleNamespace(
model=cli_obj.model,
provider="anthropic" if cli_obj.model.startswith("anthropic/") else None,
base_url="",
session_input_tokens=input_tokens if input_tokens is not None else prompt_tokens,
session_output_tokens=output_tokens if output_tokens is not None else completion_tokens,
session_cache_read_tokens=cache_read_tokens,
session_cache_write_tokens=cache_write_tokens,
session_prompt_tokens=prompt_tokens,
session_completion_tokens=completion_tokens,
session_total_tokens=total_tokens,
@ -68,20 +78,19 @@ class TestCLIStatusBar:
assert "$0.06" not in text # cost hidden by default
assert "15m" in text
def test_build_status_bar_text_shows_cost_when_enabled(self):
def test_build_status_bar_text_no_cost_in_status_bar(self):
cli_obj = _attach_agent(
_make_cli(),
prompt_tokens=10000,
completion_tokens=2400,
total_tokens=12400,
completion_tokens=5000,
total_tokens=15000,
api_calls=7,
context_tokens=12400,
context_tokens=50000,
context_length=200_000,
)
cli_obj.show_cost = True
text = cli_obj._build_status_bar_text(width=120)
assert "$" in text # cost is shown when enabled
assert "$" not in text # cost is never shown in status bar
def test_build_status_bar_text_collapses_for_narrow_terminal(self):
cli_obj = _attach_agent(
@ -128,8 +137,8 @@ class TestCLIUsageReport:
output = capsys.readouterr().out
assert "Model:" in output
assert "Input cost:" in output
assert "Output cost:" in output
assert "Cost status:" in output
assert "Cost source:" in output
assert "Total cost:" in output
assert "$" in output
assert "0.064" in output

View file

@ -657,7 +657,7 @@ class TestSchemaInit:
def test_schema_version(self, db):
cursor = db._conn.execute("SELECT version FROM schema_version")
version = cursor.fetchone()[0]
assert version == 4
assert version == 5
def test_title_column_exists(self, db):
"""Verify the title column was created in the sessions table."""
@ -713,12 +713,12 @@ class TestSchemaInit:
conn.commit()
conn.close()
# Open with SessionDB — should migrate to v4
# Open with SessionDB — should migrate to v5
migrated_db = SessionDB(db_path=db_path)
# Verify migration
cursor = migrated_db._conn.execute("SELECT version FROM schema_version")
assert cursor.fetchone()[0] == 4
assert cursor.fetchone()[0] == 5
# Verify title column exists and is NULL for existing sessions
session = migrated_db.get_session("existing")

View file

@ -123,28 +123,16 @@ def populated_db(db):
# =========================================================================
class TestPricing:
def test_exact_match(self):
pricing = _get_pricing("gpt-4o")
assert pricing["input"] == 2.50
assert pricing["output"] == 10.00
def test_provider_prefix_stripped(self):
pricing = _get_pricing("anthropic/claude-sonnet-4-20250514")
assert pricing["input"] == 3.00
assert pricing["output"] == 15.00
def test_prefix_match(self):
pricing = _get_pricing("claude-3-5-sonnet-20241022")
assert pricing["input"] == 3.00
def test_keyword_heuristic_opus(self):
def test_unknown_models_do_not_use_heuristics(self):
pricing = _get_pricing("some-new-opus-model")
assert pricing["input"] == 15.00
assert pricing["output"] == 75.00
def test_keyword_heuristic_haiku(self):
assert pricing == _DEFAULT_PRICING
pricing = _get_pricing("anthropic/claude-haiku-future")
assert pricing["input"] == 0.80
assert pricing == _DEFAULT_PRICING
def test_unknown_model_returns_zero_cost(self):
"""Unknown/custom models should NOT have fabricated costs."""
@ -168,40 +156,12 @@ class TestPricing:
pricing = _get_pricing("")
assert pricing == _DEFAULT_PRICING
def test_deepseek_heuristic(self):
pricing = _get_pricing("deepseek-v3")
assert pricing["input"] == 0.14
def test_gemini_heuristic(self):
pricing = _get_pricing("gemini-3.0-ultra")
assert pricing["input"] == 0.15
def test_dated_model_gpt4o_mini(self):
"""gpt-4o-mini-2024-07-18 should match gpt-4o-mini, NOT gpt-4o."""
pricing = _get_pricing("gpt-4o-mini-2024-07-18")
assert pricing["input"] == 0.15 # gpt-4o-mini price, not gpt-4o's 2.50
def test_dated_model_o3_mini(self):
"""o3-mini-2025-01-31 should match o3-mini, NOT o3."""
pricing = _get_pricing("o3-mini-2025-01-31")
assert pricing["input"] == 1.10 # o3-mini price, not o3's 10.00
def test_dated_model_gpt41_mini(self):
"""gpt-4.1-mini-2025-04-14 should match gpt-4.1-mini, NOT gpt-4.1."""
pricing = _get_pricing("gpt-4.1-mini-2025-04-14")
assert pricing["input"] == 0.40 # gpt-4.1-mini, not gpt-4.1's 2.00
def test_dated_model_gpt41_nano(self):
"""gpt-4.1-nano-2025-04-14 should match gpt-4.1-nano, NOT gpt-4.1."""
pricing = _get_pricing("gpt-4.1-nano-2025-04-14")
assert pricing["input"] == 0.10 # gpt-4.1-nano, not gpt-4.1's 2.00
class TestHasKnownPricing:
def test_known_commercial_model(self):
assert _has_known_pricing("gpt-4o") is True
assert _has_known_pricing("gpt-4o", provider="openai") is True
assert _has_known_pricing("anthropic/claude-sonnet-4-20250514") is True
assert _has_known_pricing("deepseek-chat") is True
assert _has_known_pricing("gpt-4.1", provider="openai") is True
def test_unknown_custom_model(self):
assert _has_known_pricing("FP16_Hermes_4.5") is False
@ -210,26 +170,39 @@ class TestHasKnownPricing:
assert _has_known_pricing("") is False
assert _has_known_pricing(None) is False
def test_heuristic_matched_models(self):
"""Models matched by keyword heuristics should be considered known."""
assert _has_known_pricing("some-opus-model") is True
assert _has_known_pricing("future-sonnet-v2") is True
def test_heuristic_matched_models_are_not_considered_known(self):
assert _has_known_pricing("some-opus-model") is False
assert _has_known_pricing("future-sonnet-v2") is False
class TestEstimateCost:
def test_basic_cost(self):
# gpt-4o: 2.50/M input, 10.00/M output
cost = _estimate_cost("gpt-4o", 1_000_000, 1_000_000)
assert cost == pytest.approx(12.50, abs=0.01)
cost, status = _estimate_cost(
"anthropic/claude-sonnet-4-20250514",
1_000_000,
1_000_000,
provider="anthropic",
)
assert status == "estimated"
assert cost == pytest.approx(18.0, abs=0.01)
def test_zero_tokens(self):
cost = _estimate_cost("gpt-4o", 0, 0)
cost, status = _estimate_cost("gpt-4o", 0, 0, provider="openai")
assert status == "estimated"
assert cost == 0.0
def test_small_usage(self):
cost = _estimate_cost("gpt-4o", 1000, 500)
# 1000 * 2.50/1M + 500 * 10.00/1M = 0.0025 + 0.005 = 0.0075
assert cost == pytest.approx(0.0075, abs=0.0001)
def test_cache_aware_usage(self):
cost, status = _estimate_cost(
"anthropic/claude-sonnet-4-20250514",
1000,
500,
cache_read_tokens=2000,
cache_write_tokens=400,
provider="anthropic",
)
assert status == "estimated"
expected = (1000 * 3.0 + 500 * 15.0 + 2000 * 0.30 + 400 * 3.75) / 1_000_000
assert cost == pytest.approx(expected, abs=0.0001)
# =========================================================================
@ -660,8 +633,13 @@ class TestEdgeCases:
def test_mixed_commercial_and_custom_models(self, db):
"""Mix of commercial and custom models: only commercial ones get costs."""
db.create_session(session_id="s1", source="cli", model="gpt-4o")
db.update_token_counts("s1", input_tokens=10000, output_tokens=5000)
db.create_session(session_id="s1", source="cli", model="anthropic/claude-sonnet-4-20250514")
db.update_token_counts(
"s1",
input_tokens=10000,
output_tokens=5000,
billing_provider="anthropic",
)
db.create_session(session_id="s2", source="cli", model="my-local-llama")
db.update_token_counts("s2", input_tokens=10000, output_tokens=5000)
db._conn.commit()
@ -672,13 +650,13 @@ class TestEdgeCases:
# Cost should only come from the priced commercial model, not from the custom model
overview = report["overview"]
assert overview["estimated_cost"] > 0
assert "gpt-4o" in overview["models_with_pricing"] # list now, not set
assert "claude-sonnet-4-20250514" in overview["models_with_pricing"] # list now, not set
assert "my-local-llama" in overview["models_without_pricing"]
# Verify individual model entries
gpt = next(m for m in report["models"] if m["model"] == "gpt-4o")
assert gpt["has_pricing"] is True
assert gpt["cost"] > 0
claude = next(m for m in report["models"] if m["model"] == "claude-sonnet-4-20250514")
assert claude["has_pricing"] is True
assert claude["cost"] > 0
llama = next(m for m in report["models"] if m["model"] == "my-local-llama")
assert llama["has_pricing"] is False

View file

@ -57,6 +57,7 @@ def main() -> int:
parent._interrupt_requested = False
parent._interrupt_message = None
parent._active_children = []
parent._active_children_lock = threading.Lock()
parent.quiet_mode = True
parent.model = "test/model"
parent.base_url = "http://localhost:1"

View file

@ -30,12 +30,14 @@ class TestInterruptPropagationToChild(unittest.TestCase):
parent._interrupt_requested = False
parent._interrupt_message = None
parent._active_children = []
parent._active_children_lock = threading.Lock()
parent.quiet_mode = True
child = AIAgent.__new__(AIAgent)
child._interrupt_requested = False
child._interrupt_message = None
child._active_children = []
child._active_children_lock = threading.Lock()
child.quiet_mode = True
parent._active_children.append(child)
@ -60,6 +62,7 @@ class TestInterruptPropagationToChild(unittest.TestCase):
child._interrupt_message = "msg"
child.quiet_mode = True
child._active_children = []
child._active_children_lock = threading.Lock()
# Global is set
set_interrupt(True)
@ -78,6 +81,7 @@ class TestInterruptPropagationToChild(unittest.TestCase):
child._interrupt_requested = False
child._interrupt_message = None
child._active_children = []
child._active_children_lock = threading.Lock()
child.quiet_mode = True
child.api_mode = "chat_completions"
child.log_prefix = ""
@ -119,12 +123,14 @@ class TestInterruptPropagationToChild(unittest.TestCase):
parent._interrupt_requested = False
parent._interrupt_message = None
parent._active_children = []
parent._active_children_lock = threading.Lock()
parent.quiet_mode = True
child = AIAgent.__new__(AIAgent)
child._interrupt_requested = False
child._interrupt_message = None
child._active_children = []
child._active_children_lock = threading.Lock()
child.quiet_mode = True
# Register child (simulating what _run_single_child does)

View file

@ -47,6 +47,28 @@ class TestCLIQuickCommands:
args = cli.console.print.call_args[0][0]
assert "no output" in args.lower()
def test_alias_command_routes_to_target(self):
    """An alias quick command ends up invoking ``process_command`` with its target."""
    quick_commands = {"shortcut": {"type": "alias", "target": "/help"}}
    cli = self._make_cli(quick_commands)
    real_process = cli.process_command
    with patch.object(cli, "process_command", wraps=real_process) as spy:
        cli.process_command("/shortcut")
        # The alias is expected to re-enter process_command with "/help".
        spy.assert_any_call("/help")
def test_alias_command_passes_args(self):
    """Arguments typed after an alias are appended to the target command."""
    quick_commands = {"sc": {"type": "alias", "target": "/context"}}
    cli = self._make_cli(quick_commands)
    real_process = cli.process_command
    with patch.object(cli, "process_command", wraps=real_process) as spy:
        cli.process_command("/sc some args")
        spy.assert_any_call("/context some args")
def test_alias_no_target_shows_error(self):
    """An alias with an empty target prints exactly one 'no target defined' message."""
    quick_commands = {"broken": {"type": "alias", "target": ""}}
    cli = self._make_cli(quick_commands)
    cli.process_command("/broken")
    printer = cli.console.print
    printer.assert_called_once()
    # First positional argument of the single console.print call.
    message = printer.call_args.args[0]
    assert "no target defined" in message.lower()
def test_unsupported_type_shows_error(self):
cli = self._make_cli({"bad": {"type": "prompt", "command": "echo hi"}})
cli.process_command("/bad")

View file

@ -55,6 +55,7 @@ class TestRealSubagentInterrupt(unittest.TestCase):
parent._interrupt_requested = False
parent._interrupt_message = None
parent._active_children = []
parent._active_children_lock = threading.Lock()
parent.quiet_mode = True
parent.model = "test/model"
parent.base_url = "http://localhost:1"
@ -103,19 +104,28 @@ class TestRealSubagentInterrupt(unittest.TestCase):
return original_run(self_agent, *args, **kwargs)
with patch.object(AIAgent, 'run_conversation', patched_run):
# Build a real child agent (AIAgent is NOT patched here,
# only run_conversation and _build_system_prompt are)
child = AIAgent(
base_url="http://localhost:1",
api_key="test-key",
model="test/model",
provider="test",
api_mode="chat_completions",
max_iterations=5,
enabled_toolsets=["terminal"],
quiet_mode=True,
skip_context_files=True,
skip_memory=True,
platform="cli",
)
child._delegate_depth = 1
parent._active_children.append(child)
result = _run_single_child(
task_index=0,
goal="Test task",
context=None,
toolsets=["terminal"],
model="test/model",
max_iterations=5,
child=child,
parent_agent=parent,
task_count=1,
override_provider="test",
override_base_url="http://localhost:1",
override_api_key="test",
override_api_mode="chat_completions",
)
result_holder[0] = result
except Exception as e:

View file

@ -12,6 +12,7 @@ Run with: python -m pytest tests/test_delegate.py -v
import json
import os
import sys
import threading
import unittest
from unittest.mock import MagicMock, patch
@ -44,6 +45,7 @@ def _make_mock_parent(depth=0):
parent._session_db = None
parent._delegate_depth = depth
parent._active_children = []
parent._active_children_lock = threading.Lock()
return parent
@ -722,7 +724,12 @@ class TestDelegationProviderIntegration(unittest.TestCase):
}
parent = _make_mock_parent(depth=0)
with patch("tools.delegate_tool._run_single_child") as mock_run:
# Patch _build_child_agent since credentials are now passed there
# (agents are built in the main thread before being handed to workers)
with patch("tools.delegate_tool._build_child_agent") as mock_build, \
patch("tools.delegate_tool._run_single_child") as mock_run:
mock_child = MagicMock()
mock_build.return_value = mock_child
mock_run.return_value = {
"task_index": 0, "status": "completed",
"summary": "Done", "api_calls": 1, "duration_seconds": 1.0
@ -731,7 +738,8 @@ class TestDelegationProviderIntegration(unittest.TestCase):
tasks = [{"goal": "Task A"}, {"goal": "Task B"}]
delegate_task(tasks=tasks, parent_agent=parent)
for call in mock_run.call_args_list:
self.assertEqual(mock_build.call_count, 2)
for call in mock_build.call_args_list:
self.assertEqual(call.kwargs.get("model"), "meta-llama/llama-4-scout")
self.assertEqual(call.kwargs.get("override_provider"), "openrouter")
self.assertEqual(call.kwargs.get("override_base_url"), "https://openrouter.ai/api/v1")

View file

@ -0,0 +1,210 @@
"""Tests for probe_mcp_server_tools() in tools.mcp_tool."""
import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@pytest.fixture(autouse=True)
def _reset_mcp_state():
    """Snapshot the MCP module's loop, thread and server map; restore them after the test."""
    import tools.mcp_tool as mcp

    saved_servers = dict(mcp._servers)
    saved_loop, saved_thread = mcp._mcp_loop, mcp._mcp_thread

    yield

    # Restore the server mapping in place so the same dict object stays bound.
    mcp._servers.clear()
    mcp._servers.update(saved_servers)
    mcp._mcp_loop, mcp._mcp_thread = saved_loop, saved_thread
class TestProbeMcpServerTools:
    """Tests for the lightweight probe_mcp_server_tools function."""

    @staticmethod
    def _drive_coro(coro, timeout=120):
        """Run *coro* to completion on a throwaway event loop.

        Used as the ``side_effect`` of the patched ``_run_on_mcp_loop``;
        ``timeout`` is accepted for signature compatibility and ignored.
        """
        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(coro)
        finally:
            loop.close()

    @staticmethod
    def _fake_server(*tools):
        """Build a mock server exposing *tools* via ``_tools`` and an awaitable ``shutdown``."""
        server = MagicMock()
        server._tools = list(tools)
        server.shutdown = AsyncMock()
        return server

    def _probe_with_fake_loop(self, config, connect):
        """Call ``probe_mcp_server_tools`` with config, connection and loop plumbing patched.

        ``_load_mcp_config`` returns *config*, ``_connect_server`` delegates to
        *connect*, and ``_run_on_mcp_loop`` is routed through ``_drive_coro``.
        """
        with patch("tools.mcp_tool._load_mcp_config", return_value=config), \
             patch("tools.mcp_tool._connect_server", side_effect=connect), \
             patch("tools.mcp_tool._ensure_mcp_loop"), \
             patch("tools.mcp_tool._run_on_mcp_loop") as loop_runner, \
             patch("tools.mcp_tool._stop_mcp_loop"):
            loop_runner.side_effect = self._drive_coro
            from tools.mcp_tool import probe_mcp_server_tools
            return probe_mcp_server_tools()

    def test_returns_empty_when_mcp_not_available(self):
        """With ``_MCP_AVAILABLE`` patched to False the probe yields an empty dict."""
        with patch("tools.mcp_tool._MCP_AVAILABLE", False):
            from tools.mcp_tool import probe_mcp_server_tools
            probed = probe_mcp_server_tools()
        assert probed == {}

    def test_returns_empty_when_no_config(self):
        """An empty MCP config yields an empty dict."""
        with patch("tools.mcp_tool._load_mcp_config", return_value={}):
            from tools.mcp_tool import probe_mcp_server_tools
            probed = probe_mcp_server_tools()
        assert probed == {}

    def test_returns_empty_when_all_servers_disabled(self):
        """A config whose servers are all disabled yields an empty dict."""
        all_disabled = {
            "github": {"command": "npx", "enabled": False},
            "slack": {"command": "npx", "enabled": "off"},
        }
        with patch("tools.mcp_tool._load_mcp_config", return_value=all_disabled):
            from tools.mcp_tool import probe_mcp_server_tools
            probed = probe_mcp_server_tools()
        assert probed == {}

    def test_returns_tools_from_successful_server(self):
        """A connected server contributes (name, description) pairs and is shut down once."""
        server = self._fake_server(
            SimpleNamespace(name="create_issue", description="Create a new issue"),
            SimpleNamespace(name="search_repos", description="Search repositories"),
        )

        async def connect(name, cfg):
            return server

        probed = self._probe_with_fake_loop(
            {"github": {"command": "npx", "connect_timeout": 5}},
            connect,
        )

        assert "github" in probed
        github_tools = probed["github"]
        assert len(github_tools) == 2
        assert github_tools[0] == ("create_issue", "Create a new issue")
        assert github_tools[1] == ("search_repos", "Search repositories")
        server.shutdown.assert_awaited_once()

    def test_failed_server_omitted_from_results(self):
        """A server whose connection raises is absent from the result."""
        server = self._fake_server(
            SimpleNamespace(name="create_issue", description="Create"),
        )

        async def connect(name, cfg):
            if name == "broken":
                raise ConnectionError("Server not found")
            return server

        probed = self._probe_with_fake_loop(
            {
                "github": {"command": "npx", "connect_timeout": 5},
                "broken": {"command": "nonexistent", "connect_timeout": 5},
            },
            connect,
        )

        assert "github" in probed
        assert "broken" not in probed

    def test_handles_tool_without_description(self):
        """A tool object lacking ``description`` is reported with an empty string."""
        # SimpleNamespace built without a description attribute on purpose.
        server = self._fake_server(SimpleNamespace(name="my_tool"))

        async def connect(name, cfg):
            return server

        probed = self._probe_with_fake_loop(
            {"github": {"command": "npx", "connect_timeout": 5}},
            connect,
        )

        assert probed["github"][0] == ("my_tool", "")

    def test_cleanup_called_even_on_failure(self):
        """``_stop_mcp_loop`` is called once even when running the probe raises."""
        config = {"github": {"command": "npx", "connect_timeout": 5}}
        with patch("tools.mcp_tool._load_mcp_config", return_value=config), \
             patch("tools.mcp_tool._ensure_mcp_loop"), \
             patch("tools.mcp_tool._run_on_mcp_loop", side_effect=RuntimeError("boom")), \
             patch("tools.mcp_tool._stop_mcp_loop") as stop_loop:
            from tools.mcp_tool import probe_mcp_server_tools
            probed = probe_mcp_server_tools()
        assert probed == {}
        stop_loop.assert_called_once()

    def test_skips_disabled_servers(self):
        """A disabled server is neither connected to nor present in the result."""
        server = self._fake_server(
            SimpleNamespace(name="create_issue", description="Create"),
        )
        attempted = []

        async def connect(name, cfg):
            attempted.append(name)
            return server

        probed = self._probe_with_fake_loop(
            {
                "github": {"command": "npx", "connect_timeout": 5},
                "disabled_one": {"command": "npx", "enabled": False},
            },
            connect,
        )

        assert "github" in probed
        assert "disabled_one" not in probed
        assert "disabled_one" not in attempted

View file

@ -2596,17 +2596,19 @@ class TestMCPSelectiveToolLoading:
async def run():
with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
patch.dict("tools.mcp_tool._servers", {}, clear=True), \
patch("tools.registry.registry", mock_registry), \
patch("toolsets.create_custom_toolset"):
return await _discover_and_register_server(
registered = await _discover_and_register_server(
"ink_existing",
{"url": "https://mcp.example.com", "tools": {"include": ["create_service"]}},
)
return registered, _existing_tool_names()
try:
registered = asyncio.run(run())
registered, existing = asyncio.run(run())
assert registered == ["mcp_ink_existing_create_service"]
assert _existing_tool_names() == ["mcp_ink_existing_create_service"]
assert existing == ["mcp_ink_existing_create_service"]
finally:
_servers.pop("ink_existing", None)

View file

@ -294,6 +294,61 @@ class TestCheckpoint:
recovered = registry.recover_from_checkpoint()
assert recovered == 0
def test_write_checkpoint_includes_watcher_metadata(self, registry, tmp_path):
    """Watcher attributes set on a running session appear in the written checkpoint."""
    checkpoint_file = tmp_path / "procs.json"
    watcher_fields = {
        "watcher_platform": "telegram",
        "watcher_chat_id": "999",
        "watcher_thread_id": "42",
        "watcher_interval": 60,
    }
    with patch("tools.process_registry.CHECKPOINT_PATH", checkpoint_file):
        session = _make_session()
        for attr, value in watcher_fields.items():
            setattr(session, attr, value)
        registry._running[session.id] = session
        registry._write_checkpoint()

    entries = json.loads(checkpoint_file.read_text())
    assert len(entries) == 1
    saved = entries[0]
    for key, expected in watcher_fields.items():
        assert saved[key] == expected
def test_recover_enqueues_watchers(self, registry, tmp_path):
    """Recovering an entry with watcher metadata queues one pending watcher."""
    entry = {
        "session_id": "proc_live",
        "command": "sleep 999",
        "pid": os.getpid(),  # current process — guaranteed alive
        "task_id": "t1",
        "session_key": "sk1",
        "watcher_platform": "telegram",
        "watcher_chat_id": "123",
        "watcher_thread_id": "42",
        "watcher_interval": 60,
    }
    checkpoint_file = tmp_path / "procs.json"
    checkpoint_file.write_text(json.dumps([entry]))

    with patch("tools.process_registry.CHECKPOINT_PATH", checkpoint_file):
        recovered_count = registry.recover_from_checkpoint()

    assert recovered_count == 1
    assert len(registry.pending_watchers) == 1
    watcher = registry.pending_watchers[0]
    assert watcher["session_id"] == "proc_live"
    assert watcher["platform"] == "telegram"
    assert watcher["chat_id"] == "123"
    assert watcher["thread_id"] == "42"
    assert watcher["check_interval"] == 60
def test_recover_skips_watcher_when_no_interval(self, registry, tmp_path):
    """An entry with ``watcher_interval`` of 0 is recovered without queuing a watcher."""
    entry = {
        "session_id": "proc_live",
        "command": "sleep 999",
        "pid": os.getpid(),
        "task_id": "t1",
        "watcher_interval": 0,
    }
    checkpoint_file = tmp_path / "procs.json"
    checkpoint_file.write_text(json.dumps([entry]))

    with patch("tools.process_registry.CHECKPOINT_PATH", checkpoint_file):
        recovered_count = registry.recover_from_checkpoint()

    assert recovered_count == 1
    assert len(registry.pending_watchers) == 0
# =========================================================================
# Kill process

View file

@ -25,7 +25,7 @@ def _make_config():
def _install_telegram_mock(monkeypatch, bot):
parse_mode = SimpleNamespace(MARKDOWN_V2="MarkdownV2")
parse_mode = SimpleNamespace(MARKDOWN_V2="MarkdownV2", HTML="HTML")
constants_mod = SimpleNamespace(ParseMode=parse_mode)
telegram_mod = SimpleNamespace(Bot=lambda token: bot, constants=constants_mod)
monkeypatch.setitem(sys.modules, "telegram", telegram_mod)
@ -391,3 +391,97 @@ class TestSendToPlatformChunking:
assert len(sent_calls) >= 3
assert all(call == [] for call in sent_calls[:-1])
assert sent_calls[-1] == media
# ---------------------------------------------------------------------------
# HTML auto-detection in Telegram send
# ---------------------------------------------------------------------------
class TestSendTelegramHtmlDetection:
    """Check which ``parse_mode`` ``_send_telegram`` passes to ``bot.send_message``.

    Text containing HTML-looking tags is expected to go out with ``HTML``;
    everything else with ``MarkdownV2``.
    """

    def _make_bot(self):
        """Return a mock bot whose send_* methods are all awaitable."""
        bot = MagicMock()
        bot.send_message = AsyncMock(return_value=SimpleNamespace(message_id=1))
        for sender in ("send_photo", "send_video", "send_voice", "send_audio", "send_document"):
            setattr(bot, sender, AsyncMock())
        return bot

    def _deliver(self, monkeypatch, text, bot=None):
        """Install the telegram mock, send *text* to chat "123", return ``(bot, result)``."""
        if bot is None:
            bot = self._make_bot()
        _install_telegram_mock(monkeypatch, bot)
        outcome = asyncio.run(_send_telegram("tok", "123", text))
        return bot, outcome

    def test_html_message_uses_html_parse_mode(self, monkeypatch):
        """A message with a ``<b>`` tag is sent once, unmodified, in HTML mode."""
        bot, _ = self._deliver(monkeypatch, "<b>Hello</b> world")

        bot.send_message.assert_awaited_once()
        sent = bot.send_message.await_args.kwargs
        assert sent["parse_mode"] == "HTML"
        assert sent["text"] == "<b>Hello</b> world"

    def test_plain_text_uses_markdown_v2(self, monkeypatch):
        """Tag-free text is sent once in MarkdownV2 mode."""
        bot, _ = self._deliver(monkeypatch, "Just plain text, no tags")

        bot.send_message.assert_awaited_once()
        sent = bot.send_message.await_args.kwargs
        assert sent["parse_mode"] == "MarkdownV2"

    def test_html_with_code_and_pre_tags(self, monkeypatch):
        """``<pre>`` and ``<code>`` tags select HTML mode."""
        markup = "<pre>code block</pre> and <code>inline</code>"
        bot, _ = self._deliver(monkeypatch, markup)

        sent = bot.send_message.await_args.kwargs
        assert sent["parse_mode"] == "HTML"

    def test_closing_tag_detected(self, monkeypatch):
        """A lone closing tag is enough to select HTML mode."""
        bot, _ = self._deliver(monkeypatch, "text </div> more")

        sent = bot.send_message.await_args.kwargs
        assert sent["parse_mode"] == "HTML"

    def test_angle_brackets_in_math_not_detected(self, monkeypatch):
        """Expressions like 'x < 5' or '3 > 2' should not trigger HTML mode."""
        bot, _ = self._deliver(monkeypatch, "if x < 5 then y > 2")

        sent = bot.send_message.await_args.kwargs
        assert sent["parse_mode"] == "MarkdownV2"

    def test_html_parse_failure_falls_back_to_plain(self, monkeypatch):
        """If the first send raises a parse error, a second send goes out with no parse mode."""
        bot = self._make_bot()
        parse_error = Exception("Bad Request: can't parse entities: unsupported html tag")
        fallback_reply = SimpleNamespace(message_id=2)  # plain fallback succeeds
        bot.send_message = AsyncMock(side_effect=[parse_error, fallback_reply])

        _, outcome = self._deliver(monkeypatch, "<invalid>broken html</invalid>", bot=bot)

        assert outcome["success"] is True
        assert bot.send_message.await_count == 2
        retry_kwargs = bot.send_message.await_args_list[1].kwargs
        assert retry_kwargs["parse_mode"] is None

View file

@ -1,8 +1,11 @@
"""Tests for Firecrawl client configuration and singleton behavior.
"""Tests for web backend client configuration and singleton behavior.
Coverage:
_get_firecrawl_client() configuration matrix, singleton caching,
constructor failure recovery, return value verification, edge cases.
_get_backend() backend selection logic with env var combinations.
_get_parallel_client() Parallel client configuration, singleton caching.
check_web_api_key() unified availability check.
"""
import os
@ -117,3 +120,212 @@ class TestFirecrawlClientConfig:
from tools.web_tools import _get_firecrawl_client
with pytest.raises(ValueError):
_get_firecrawl_client()
class TestBackendSelection:
    """Test suite for ``_get_backend()`` backend selection logic.

    The backend is configured via config.yaml (``web.backend``), set by
    ``hermes tools``.  When that setting is absent or unrecognised the
    selection falls back to key-based detection for legacy/manual setups.
    """

    _ENV_KEYS = ("PARALLEL_API_KEY", "FIRECRAWL_API_KEY", "FIRECRAWL_API_URL", "TAVILY_API_KEY")

    def setup_method(self):
        """Hide real backend credentials, remembering them for restoration."""
        self._saved_env = {}
        for key in self._ENV_KEYS:
            value = os.environ.pop(key, None)
            if value is not None:
                self._saved_env[key] = value

    def teardown_method(self):
        """Drop anything a test left behind, then restore the original values.

        Restoring matters: without it, running this suite would strip the
        developer's real API keys from the process for every later test.
        """
        for key in self._ENV_KEYS:
            os.environ.pop(key, None)
        os.environ.update(getattr(self, "_saved_env", {}))

    # ── Config-based selection (web.backend in config.yaml) ───────────

    def test_config_parallel(self):
        """web.backend=parallel in config → 'parallel' regardless of keys."""
        from tools.web_tools import _get_backend
        with patch("tools.web_tools._load_web_config", return_value={"backend": "parallel"}):
            assert _get_backend() == "parallel"

    def test_config_firecrawl(self):
        """web.backend=firecrawl in config → 'firecrawl' even if Parallel key set."""
        from tools.web_tools import _get_backend
        with patch("tools.web_tools._load_web_config", return_value={"backend": "firecrawl"}), \
             patch.dict(os.environ, {"PARALLEL_API_KEY": "test-key"}):
            assert _get_backend() == "firecrawl"

    def test_config_tavily(self):
        """web.backend=tavily in config → 'tavily' regardless of other keys."""
        from tools.web_tools import _get_backend
        with patch("tools.web_tools._load_web_config", return_value={"backend": "tavily"}):
            assert _get_backend() == "tavily"

    def test_config_tavily_overrides_env_keys(self):
        """web.backend=tavily in config → 'tavily' even if Firecrawl key set."""
        from tools.web_tools import _get_backend
        with patch("tools.web_tools._load_web_config", return_value={"backend": "tavily"}), \
             patch.dict(os.environ, {"FIRECRAWL_API_KEY": "fc-test"}):
            assert _get_backend() == "tavily"

    def test_config_case_insensitive(self):
        """web.backend=Parallel (mixed case) → 'parallel'."""
        from tools.web_tools import _get_backend
        with patch("tools.web_tools._load_web_config", return_value={"backend": "Parallel"}):
            assert _get_backend() == "parallel"

    def test_config_tavily_case_insensitive(self):
        """web.backend=Tavily (mixed case) → 'tavily'."""
        from tools.web_tools import _get_backend
        with patch("tools.web_tools._load_web_config", return_value={"backend": "Tavily"}):
            assert _get_backend() == "tavily"

    # ── Fallback (no web.backend in config) ───────────────────────────

    def test_fallback_parallel_only_key(self):
        """Only PARALLEL_API_KEY set → 'parallel'."""
        from tools.web_tools import _get_backend
        with patch("tools.web_tools._load_web_config", return_value={}), \
             patch.dict(os.environ, {"PARALLEL_API_KEY": "test-key"}):
            assert _get_backend() == "parallel"

    def test_fallback_tavily_only_key(self):
        """Only TAVILY_API_KEY set → 'tavily'."""
        from tools.web_tools import _get_backend
        with patch("tools.web_tools._load_web_config", return_value={}), \
             patch.dict(os.environ, {"TAVILY_API_KEY": "tvly-test"}):
            assert _get_backend() == "tavily"

    def test_fallback_tavily_with_firecrawl_prefers_firecrawl(self):
        """Tavily + Firecrawl keys, no config → 'firecrawl' (backward compat)."""
        from tools.web_tools import _get_backend
        with patch("tools.web_tools._load_web_config", return_value={}), \
             patch.dict(os.environ, {"TAVILY_API_KEY": "tvly-test", "FIRECRAWL_API_KEY": "fc-test"}):
            assert _get_backend() == "firecrawl"

    def test_fallback_tavily_with_parallel_prefers_parallel(self):
        """Tavily + Parallel keys, no config → 'parallel' (Parallel takes priority over Tavily)."""
        from tools.web_tools import _get_backend
        with patch("tools.web_tools._load_web_config", return_value={}), \
             patch.dict(os.environ, {"TAVILY_API_KEY": "tvly-test", "PARALLEL_API_KEY": "par-test"}):
            # Parallel + no Firecrawl → parallel
            assert _get_backend() == "parallel"

    def test_fallback_both_keys_defaults_to_firecrawl(self):
        """Both keys set, no config → 'firecrawl' (backward compat)."""
        from tools.web_tools import _get_backend
        with patch("tools.web_tools._load_web_config", return_value={}), \
             patch.dict(os.environ, {"PARALLEL_API_KEY": "test-key", "FIRECRAWL_API_KEY": "fc-test"}):
            assert _get_backend() == "firecrawl"

    def test_fallback_firecrawl_only_key(self):
        """Only FIRECRAWL_API_KEY set → 'firecrawl'."""
        from tools.web_tools import _get_backend
        with patch("tools.web_tools._load_web_config", return_value={}), \
             patch.dict(os.environ, {"FIRECRAWL_API_KEY": "fc-test"}):
            assert _get_backend() == "firecrawl"

    def test_fallback_no_keys_defaults_to_firecrawl(self):
        """No keys, no config → 'firecrawl' (will fail at client init)."""
        from tools.web_tools import _get_backend
        with patch("tools.web_tools._load_web_config", return_value={}):
            assert _get_backend() == "firecrawl"

    def test_invalid_config_falls_through_to_fallback(self):
        """web.backend=invalid → ignored, uses key-based fallback."""
        from tools.web_tools import _get_backend
        with patch("tools.web_tools._load_web_config", return_value={"backend": "nonexistent"}), \
             patch.dict(os.environ, {"PARALLEL_API_KEY": "test-key"}):
            assert _get_backend() == "parallel"
class TestParallelClientConfig:
    """Test suite for Parallel client initialization."""

    def setup_method(self):
        """Reset the cached client and hide any real PARALLEL_API_KEY."""
        import tools.web_tools
        tools.web_tools._parallel_client = None
        # Remember the caller's key so teardown can put it back.
        self._saved_parallel_key = os.environ.pop("PARALLEL_API_KEY", None)

    def teardown_method(self):
        """Reset the cached client and restore the original PARALLEL_API_KEY."""
        import tools.web_tools
        tools.web_tools._parallel_client = None
        os.environ.pop("PARALLEL_API_KEY", None)
        saved = getattr(self, "_saved_parallel_key", None)
        if saved is not None:
            os.environ["PARALLEL_API_KEY"] = saved

    def test_creates_client_with_key(self):
        """PARALLEL_API_KEY set → creates Parallel client."""
        with patch.dict(os.environ, {"PARALLEL_API_KEY": "test-key"}):
            from tools.web_tools import _get_parallel_client
            from parallel import Parallel
            client = _get_parallel_client()
            assert client is not None
            assert isinstance(client, Parallel)

    def test_no_key_raises_with_helpful_message(self):
        """No PARALLEL_API_KEY → ValueError with guidance."""
        from tools.web_tools import _get_parallel_client
        with pytest.raises(ValueError, match="PARALLEL_API_KEY"):
            _get_parallel_client()

    def test_singleton_returns_same_instance(self):
        """Second call returns cached client."""
        with patch.dict(os.environ, {"PARALLEL_API_KEY": "test-key"}):
            from tools.web_tools import _get_parallel_client
            client1 = _get_parallel_client()
            client2 = _get_parallel_client()
            assert client1 is client2
class TestCheckWebApiKey:
    """Test suite for check_web_api_key() unified availability check."""

    _ENV_KEYS = ("PARALLEL_API_KEY", "FIRECRAWL_API_KEY", "FIRECRAWL_API_URL", "TAVILY_API_KEY")

    def setup_method(self):
        """Hide real backend credentials, remembering them for restoration."""
        self._saved_env = {}
        for key in self._ENV_KEYS:
            value = os.environ.pop(key, None)
            if value is not None:
                self._saved_env[key] = value

    def teardown_method(self):
        """Drop anything a test left behind, then restore the original values.

        Restoring keeps this suite from permanently removing the caller's
        real API keys from the process environment.
        """
        for key in self._ENV_KEYS:
            os.environ.pop(key, None)
        os.environ.update(getattr(self, "_saved_env", {}))

    def test_parallel_key_only(self):
        """Only PARALLEL_API_KEY set → check passes."""
        with patch.dict(os.environ, {"PARALLEL_API_KEY": "test-key"}):
            from tools.web_tools import check_web_api_key
            assert check_web_api_key() is True

    def test_firecrawl_key_only(self):
        """Only FIRECRAWL_API_KEY set → check passes."""
        with patch.dict(os.environ, {"FIRECRAWL_API_KEY": "fc-test"}):
            from tools.web_tools import check_web_api_key
            assert check_web_api_key() is True

    def test_firecrawl_url_only(self):
        """Only FIRECRAWL_API_URL set → check passes."""
        with patch.dict(os.environ, {"FIRECRAWL_API_URL": "http://localhost:3002"}):
            from tools.web_tools import check_web_api_key
            assert check_web_api_key() is True

    def test_tavily_key_only(self):
        """Only TAVILY_API_KEY set → check passes."""
        with patch.dict(os.environ, {"TAVILY_API_KEY": "tvly-test"}):
            from tools.web_tools import check_web_api_key
            assert check_web_api_key() is True

    def test_no_keys_returns_false(self):
        """None of the backend variables set → check fails."""
        from tools.web_tools import check_web_api_key
        assert check_web_api_key() is False

    def test_both_keys_returns_true(self):
        """Parallel and Firecrawl keys together → check passes."""
        with patch.dict(os.environ, {
            "PARALLEL_API_KEY": "test-key",
            "FIRECRAWL_API_KEY": "fc-test",
        }):
            from tools.web_tools import check_web_api_key
            assert check_web_api_key() is True

    def test_all_three_keys_returns_true(self):
        """Parallel, Firecrawl and Tavily keys together → check passes."""
        with patch.dict(os.environ, {
            "PARALLEL_API_KEY": "test-key",
            "FIRECRAWL_API_KEY": "fc-test",
            "TAVILY_API_KEY": "tvly-test",
        }):
            from tools.web_tools import check_web_api_key
            assert check_web_api_key() is True

View file

@ -0,0 +1,255 @@
"""Tests for Tavily web backend integration.
Coverage:
_tavily_request() API key handling, endpoint construction, error propagation.
_normalize_tavily_search_results() search response normalization.
_normalize_tavily_documents() extract/crawl response normalization, failed_results.
web_search_tool / web_extract_tool / web_crawl_tool Tavily dispatch paths.
"""
import json
import os
import asyncio
import pytest
from unittest.mock import patch, MagicMock
# ─── _tavily_request ─────────────────────────────────────────────────────────
class TestTavilyRequest:
    """Checks for the ``_tavily_request`` helper."""

    def test_raises_without_api_key(self):
        """With TAVILY_API_KEY absent the helper raises a ValueError naming it."""
        # patch.dict snapshots the environment so the pop below is undone on exit.
        with patch.dict(os.environ, {}, clear=False):
            os.environ.pop("TAVILY_API_KEY", None)
            from tools.web_tools import _tavily_request

            with pytest.raises(ValueError, match="TAVILY_API_KEY"):
                _tavily_request("search", {"query": "test"})

    def test_posts_with_api_key_in_body(self):
        """The key from the environment is sent inside the JSON payload."""
        fake_response = MagicMock()
        fake_response.json.return_value = {"results": []}
        fake_response.raise_for_status = MagicMock()

        env = {"TAVILY_API_KEY": "tvly-test-key"}
        with patch.dict(os.environ, env):
            with patch("tools.web_tools.httpx.post", return_value=fake_response) as post_mock:
                from tools.web_tools import _tavily_request

                _tavily_request("search", {"query": "hello"})

                post_mock.assert_called_once()
                recorded = post_mock.call_args
                body = recorded.kwargs.get("json")
                assert body["api_key"] == "tvly-test-key"
                assert body["query"] == "hello"
                # The endpoint name is appended to the Tavily base URL.
                assert "api.tavily.com/search" in recorded.args[0]

    def test_raises_on_http_error(self):
        """An error raised by raise_for_status() reaches the caller unchanged."""
        import httpx as _httpx

        fake_response = MagicMock()
        failure = _httpx.HTTPStatusError(
            "401 Unauthorized", request=MagicMock(), response=fake_response
        )
        fake_response.raise_for_status.side_effect = failure

        with patch.dict(os.environ, {"TAVILY_API_KEY": "tvly-bad-key"}):
            with patch("tools.web_tools.httpx.post", return_value=fake_response):
                from tools.web_tools import _tavily_request

                with pytest.raises(_httpx.HTTPStatusError):
                    _tavily_request("search", {"query": "test"})
# ─── _normalize_tavily_search_results ─────────────────────────────────────────
class TestNormalizeTavilySearchResults:
    """Checks for ``_normalize_tavily_search_results``."""

    def test_basic_normalization(self):
        """Two raw hits become two web entries with 1-based positions."""
        from tools.web_tools import _normalize_tavily_search_results

        hits = [
            {"title": "Python Docs", "url": "https://docs.python.org", "content": "Official docs", "score": 0.9},
            {"title": "Tutorial", "url": "https://example.com", "content": "A tutorial", "score": 0.8},
        ]
        normalized = _normalize_tavily_search_results({"results": hits})

        assert normalized["success"] is True
        entries = normalized["data"]["web"]
        assert len(entries) == 2
        first, second = entries[0], entries[1]
        assert first["title"] == "Python Docs"
        assert first["url"] == "https://docs.python.org"
        # Tavily's "content" field is surfaced as "description".
        assert first["description"] == "Official docs"
        assert first["position"] == 1
        assert second["position"] == 2

    def test_empty_results(self):
        """An empty result list still reports success with no web entries."""
        from tools.web_tools import _normalize_tavily_search_results

        normalized = _normalize_tavily_search_results({"results": []})
        assert normalized["success"] is True
        assert normalized["data"]["web"] == []

    def test_missing_fields(self):
        """A hit with no fields normalizes to empty strings."""
        from tools.web_tools import _normalize_tavily_search_results

        normalized = _normalize_tavily_search_results({"results": [{}]})
        entry = normalized["data"]["web"][0]
        for field in ("title", "url", "description"):
            assert entry[field] == ""
# ─── _normalize_tavily_documents ──────────────────────────────────────────────
class TestNormalizeTavilyDocuments:
    """Checks for ``_normalize_tavily_documents`` (extract/crawl responses)."""

    def test_basic_document(self):
        """A full result maps url, title and raw_content onto the document."""
        from tools.web_tools import _normalize_tavily_documents

        page = {
            "url": "https://example.com",
            "title": "Example",
            "raw_content": "Full page content here",
        }
        documents = _normalize_tavily_documents({"results": [page]})

        assert len(documents) == 1
        doc = documents[0]
        assert doc["url"] == "https://example.com"
        assert doc["title"] == "Example"
        assert doc["content"] == "Full page content here"
        assert doc["raw_content"] == "Full page content here"
        assert doc["metadata"]["sourceURL"] == "https://example.com"

    def test_falls_back_to_content_when_no_raw_content(self):
        """Without raw_content the shorter "content" field is used."""
        from tools.web_tools import _normalize_tavily_documents

        payload = {"results": [{"url": "https://example.com", "content": "Snippet"}]}
        documents = _normalize_tavily_documents(payload)
        assert documents[0]["content"] == "Snippet"

    def test_failed_results_included(self):
        """Entries under failed_results become documents carrying their error."""
        from tools.web_tools import _normalize_tavily_documents

        payload = {
            "results": [],
            "failed_results": [
                {"url": "https://fail.com", "error": "timeout"},
            ],
        }
        documents = _normalize_tavily_documents(payload)

        assert len(documents) == 1
        failed = documents[0]
        assert failed["url"] == "https://fail.com"
        assert failed["error"] == "timeout"
        assert failed["content"] == ""

    def test_failed_urls_included(self):
        """Bare URLs under failed_urls get a generic extraction error."""
        from tools.web_tools import _normalize_tavily_documents

        payload = {
            "results": [],
            "failed_urls": ["https://bad.com"],
        }
        documents = _normalize_tavily_documents(payload)

        assert len(documents) == 1
        failed = documents[0]
        assert failed["url"] == "https://bad.com"
        assert failed["error"] == "extraction failed"

    def test_fallback_url(self):
        """A result lacking a url takes the supplied fallback_url."""
        from tools.web_tools import _normalize_tavily_documents

        payload = {"results": [{"content": "data"}]}
        documents = _normalize_tavily_documents(payload, fallback_url="https://fallback.com")
        assert documents[0]["url"] == "https://fallback.com"
# ─── web_search_tool (Tavily dispatch) ────────────────────────────────────────
class TestWebSearchTavily:
    """Checks that ``web_search_tool`` routes through Tavily when selected."""

    def test_search_dispatches_to_tavily(self):
        """With the backend forced to tavily, the mocked HTTP hit is returned."""
        hit = {"title": "Result", "url": "https://r.com", "content": "desc", "score": 0.9}
        fake_response = MagicMock()
        fake_response.json.return_value = {
            "results": [hit]
        }
        fake_response.raise_for_status = MagicMock()

        with patch("tools.web_tools._get_backend", return_value="tavily"), \
             patch.dict(os.environ, {"TAVILY_API_KEY": "tvly-test"}), \
             patch("tools.web_tools.httpx.post", return_value=fake_response), \
             patch("tools.interrupt.is_interrupted", return_value=False):
            from tools.web_tools import web_search_tool

            raw_json = web_search_tool("test query", limit=3)
            parsed = json.loads(raw_json)

            assert parsed["success"] is True
            web_entries = parsed["data"]["web"]
            assert len(web_entries) == 1
            assert web_entries[0]["title"] == "Result"
# ─── web_extract_tool (Tavily dispatch) ───────────────────────────────────────
class TestWebExtractTavily:
    """Test web_extract_tool dispatch to Tavily."""

    @staticmethod
    def _run(coro):
        """Run *coro* to completion on a private event loop and return its result.

        A dedicated loop is used instead of ``asyncio.get_event_loop()``:
        the implicit "current loop" is deprecated outside a running loop and
        may be closed or unset by earlier async tests.  The global loop is
        never touched, so other tests are unaffected.
        """
        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(coro)
        finally:
            loop.close()

    def test_extract_dispatches_to_tavily(self):
        """With the backend forced to tavily, the mocked page is returned."""
        mock_response = MagicMock()
        mock_response.json.return_value = {
            "results": [{"url": "https://example.com", "raw_content": "Extracted content", "title": "Page"}]
        }
        mock_response.raise_for_status = MagicMock()

        with patch("tools.web_tools._get_backend", return_value="tavily"), \
             patch.dict(os.environ, {"TAVILY_API_KEY": "tvly-test"}), \
             patch("tools.web_tools.httpx.post", return_value=mock_response), \
             patch("tools.web_tools.process_content_with_llm", return_value=None):
            from tools.web_tools import web_extract_tool

            result = json.loads(self._run(
                web_extract_tool(["https://example.com"], use_llm_processing=False)
            ))
            assert "results" in result
            assert len(result["results"]) == 1
            assert result["results"][0]["url"] == "https://example.com"
# ─── web_crawl_tool (Tavily dispatch) ─────────────────────────────────────────
class TestWebCrawlTavily:
    """Test web_crawl_tool dispatch to Tavily."""

    @staticmethod
    def _run(coro):
        """Run *coro* to completion on a private event loop and return its result.

        A dedicated loop is used instead of ``asyncio.get_event_loop()``:
        the implicit "current loop" is deprecated outside a running loop and
        may be closed or unset by earlier async tests.  The global loop is
        never touched, so other tests are unaffected.
        """
        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(coro)
        finally:
            loop.close()

    def test_crawl_dispatches_to_tavily(self):
        """With the backend forced to tavily, both mocked pages are returned."""
        mock_response = MagicMock()
        mock_response.json.return_value = {
            "results": [
                {"url": "https://example.com/page1", "raw_content": "Page 1 content", "title": "Page 1"},
                {"url": "https://example.com/page2", "raw_content": "Page 2 content", "title": "Page 2"},
            ]
        }
        mock_response.raise_for_status = MagicMock()

        with patch("tools.web_tools._get_backend", return_value="tavily"), \
             patch.dict(os.environ, {"TAVILY_API_KEY": "tvly-test"}), \
             patch("tools.web_tools.httpx.post", return_value=mock_response), \
             patch("tools.web_tools.check_website_access", return_value=None), \
             patch("tools.interrupt.is_interrupted", return_value=False):
            from tools.web_tools import web_crawl_tool

            result = json.loads(self._run(
                web_crawl_tool("https://example.com", use_llm_processing=False)
            ))
            assert "results" in result
            assert len(result["results"]) == 2
            assert result["results"][0]["title"] == "Page 1"

    def test_crawl_sends_instructions(self):
        """Instructions are included in the Tavily crawl payload."""
        mock_response = MagicMock()
        mock_response.json.return_value = {"results": []}
        mock_response.raise_for_status = MagicMock()

        with patch("tools.web_tools._get_backend", return_value="tavily"), \
             patch.dict(os.environ, {"TAVILY_API_KEY": "tvly-test"}), \
             patch("tools.web_tools.httpx.post", return_value=mock_response) as mock_post, \
             patch("tools.web_tools.check_website_access", return_value=None), \
             patch("tools.interrupt.is_interrupted", return_value=False):
            from tools.web_tools import web_crawl_tool

            self._run(
                web_crawl_tool("https://example.com", instructions="Find docs", use_llm_processing=False)
            )
            call_kwargs = mock_post.call_args
            payload = call_kwargs.kwargs.get("json") or call_kwargs[1].get("json")
            assert payload["instructions"] == "Find docs"
            assert payload["url"] == "https://example.com"

View file

@ -0,0 +1,495 @@
import json
from pathlib import Path
import pytest
import yaml
from tools.website_policy import WebsitePolicyError, check_website_access, load_website_blocklist
def test_load_website_blocklist_merges_config_and_shared_file(tmp_path):
    """Rules from inline ``domains`` and from a shared file end up in one policy."""
    shared = tmp_path / "community-blocklist.txt"
    # The expected rule set below contains no entry for the "# comment" line.
    shared.write_text("# comment\nexample.org\nsub.bad.net\n", encoding="utf-8")

    blocklist_section = {
        "enabled": True,
        "domains": ["example.com", "https://www.evil.test/path"],
        "shared_files": [str(shared)],
    }
    document = {"security": {"website_blocklist": blocklist_section}}
    config_path = tmp_path / "config.yaml"
    config_path.write_text(yaml.safe_dump(document, sort_keys=False), encoding="utf-8")

    policy = load_website_blocklist(config_path)

    assert policy["enabled"] is True
    patterns = {rule["pattern"] for rule in policy["rules"]}
    # The full URL entry is expected to be reduced to its bare domain.
    assert patterns == {
        "example.com",
        "evil.test",
        "example.org",
        "sub.bad.net",
    }
def test_check_website_access_matches_parent_domain_subdomains(tmp_path):
    """Blocking ``example.com`` also blocks ``docs.example.com``."""
    document = {
        "security": {
            "website_blocklist": {
                "enabled": True,
                "domains": ["example.com"],
            }
        }
    }
    config_path = tmp_path / "config.yaml"
    config_path.write_text(yaml.safe_dump(document, sort_keys=False), encoding="utf-8")

    verdict = check_website_access("https://docs.example.com/page", config_path=config_path)

    assert verdict is not None
    assert verdict["host"] == "docs.example.com"
    assert verdict["rule"] == "example.com"
def test_check_website_access_supports_wildcard_subdomains_only(tmp_path):
    """A ``*.`` rule blocks subdomains but leaves the bare domain reachable."""
    document = {
        "security": {
            "website_blocklist": {
                "enabled": True,
                "domains": ["*.tracking.example"],
            }
        }
    }
    config_path = tmp_path / "config.yaml"
    config_path.write_text(yaml.safe_dump(document, sort_keys=False), encoding="utf-8")

    for subdomain_url in ("https://a.tracking.example", "https://www.tracking.example"):
        assert check_website_access(subdomain_url, config_path=config_path) is not None
    # The apex domain itself is not covered by the wildcard.
    assert check_website_access("https://tracking.example", config_path=config_path) is None
def test_default_config_exposes_website_blocklist_shape():
    """DEFAULT_CONFIG ships a disabled, empty website blocklist section."""
    from hermes_cli.config import DEFAULT_CONFIG

    section = DEFAULT_CONFIG["security"]["website_blocklist"]

    assert section["enabled"] is False
    for list_key in ("domains", "shared_files"):
        assert section[list_key] == []
def test_load_website_blocklist_uses_enabled_default_when_section_missing(tmp_path):
    """A config with no ``security`` section yields a disabled, rule-less policy."""
    unrelated = {"display": {"tool_progress": "all"}}
    config_path = tmp_path / "config.yaml"
    config_path.write_text(yaml.safe_dump(unrelated, sort_keys=False), encoding="utf-8")

    loaded = load_website_blocklist(config_path)

    assert loaded == {"enabled": False, "rules": []}
def test_load_website_blocklist_raises_clean_error_for_invalid_domains_type(tmp_path):
    """``domains`` given as a string instead of a list is rejected."""
    document = {
        "security": {
            "website_blocklist": {
                "enabled": True,
                "domains": "example.com",
            }
        }
    }
    config_path = tmp_path / "config.yaml"
    config_path.write_text(yaml.safe_dump(document, sort_keys=False), encoding="utf-8")

    expected_message = "security.website_blocklist.domains must be a list"
    with pytest.raises(WebsitePolicyError, match=expected_message):
        load_website_blocklist(config_path)
def test_load_website_blocklist_raises_clean_error_for_invalid_shared_files_type(tmp_path):
    """``shared_files`` given as a string instead of a list is rejected."""
    document = {
        "security": {
            "website_blocklist": {
                "enabled": True,
                "shared_files": "community-blocklist.txt",
            }
        }
    }
    config_path = tmp_path / "config.yaml"
    config_path.write_text(yaml.safe_dump(document, sort_keys=False), encoding="utf-8")

    expected_message = "security.website_blocklist.shared_files must be a list"
    with pytest.raises(WebsitePolicyError, match=expected_message):
        load_website_blocklist(config_path)
def test_load_website_blocklist_raises_clean_error_for_invalid_top_level_config_type(tmp_path):
    """A YAML document whose root is a list is rejected."""
    root_as_list = ["not", "a", "mapping"]
    config_path = tmp_path / "config.yaml"
    config_path.write_text(yaml.safe_dump(root_as_list, sort_keys=False), encoding="utf-8")

    expected_message = "config root must be a mapping"
    with pytest.raises(WebsitePolicyError, match=expected_message):
        load_website_blocklist(config_path)
def test_load_website_blocklist_raises_clean_error_for_invalid_security_type(tmp_path):
    """``security`` given as a list instead of a mapping is rejected."""
    document = {"security": []}
    config_path = tmp_path / "config.yaml"
    config_path.write_text(yaml.safe_dump(document, sort_keys=False), encoding="utf-8")

    expected_message = "security must be a mapping"
    with pytest.raises(WebsitePolicyError, match=expected_message):
        load_website_blocklist(config_path)
def test_load_website_blocklist_raises_clean_error_for_invalid_website_blocklist_type(tmp_path):
    """``website_blocklist`` given as a string instead of a mapping is rejected."""
    document = {
        "security": {
            "website_blocklist": "block everything",
        }
    }
    config_path = tmp_path / "config.yaml"
    config_path.write_text(yaml.safe_dump(document, sort_keys=False), encoding="utf-8")

    expected_message = "security.website_blocklist must be a mapping"
    with pytest.raises(WebsitePolicyError, match=expected_message):
        load_website_blocklist(config_path)
def test_load_website_blocklist_raises_clean_error_for_invalid_enabled_type(tmp_path):
    """``enabled`` given as the string "false" instead of a boolean is rejected."""
    document = {
        "security": {
            "website_blocklist": {
                "enabled": "false",
            }
        }
    }
    config_path = tmp_path / "config.yaml"
    config_path.write_text(yaml.safe_dump(document, sort_keys=False), encoding="utf-8")

    expected_message = "security.website_blocklist.enabled must be a boolean"
    with pytest.raises(WebsitePolicyError, match=expected_message):
        load_website_blocklist(config_path)
def test_load_website_blocklist_raises_clean_error_for_malformed_yaml(tmp_path):
    """Unparseable YAML is reported as a WebsitePolicyError."""
    broken_yaml = "security: [oops\n"  # unterminated flow sequence
    config_path = tmp_path / "config.yaml"
    config_path.write_text(broken_yaml, encoding="utf-8")

    with pytest.raises(WebsitePolicyError, match="Invalid config YAML"):
        load_website_blocklist(config_path)
def test_load_website_blocklist_wraps_shared_file_read_errors(tmp_path, monkeypatch):
    """An unreadable shared file is skipped: the policy loads with no rules."""
    shared = tmp_path / "community-blocklist.txt"
    shared.write_text("example.org\n", encoding="utf-8")

    document = {
        "security": {
            "website_blocklist": {
                "enabled": True,
                "shared_files": [str(shared)],
            }
        }
    }
    config_path = tmp_path / "config.yaml"
    config_path.write_text(yaml.safe_dump(document, sort_keys=False), encoding="utf-8")

    def deny_read(self, *args, **kwargs):
        raise PermissionError("no permission")

    # Both files are written above, before every Path.read_text call starts failing.
    monkeypatch.setattr(Path, "read_text", deny_read)

    # Unreadable shared files are now warned and skipped (not raised),
    # so the blocklist loads successfully but without those rules.
    policy = load_website_blocklist(config_path)

    assert policy["enabled"] is True
    assert policy["rules"] == []  # shared file rules skipped
def test_check_website_access_uses_dynamic_hermes_home(monkeypatch, tmp_path):
    """With no explicit config_path, the config under $HERMES_HOME is used."""
    hermes_home = tmp_path / "hermes-home"
    hermes_home.mkdir()

    document = {
        "security": {
            "website_blocklist": {
                "enabled": True,
                "domains": ["dynamic.example"],
            }
        }
    }
    home_config = hermes_home / "config.yaml"
    home_config.write_text(yaml.safe_dump(document, sort_keys=False), encoding="utf-8")

    # Set after the file exists so the lookup sees the blocklist.
    monkeypatch.setenv("HERMES_HOME", str(hermes_home))

    verdict = check_website_access("https://dynamic.example/path")

    assert verdict is not None
    assert verdict["rule"] == "dynamic.example"
def test_check_website_access_blocks_scheme_less_urls(tmp_path):
    """A URL without ``http(s)://`` is still matched against the blocklist."""
    document = {
        "security": {
            "website_blocklist": {
                "enabled": True,
                "domains": ["blocked.test"],
            }
        }
    }
    config_path = tmp_path / "config.yaml"
    config_path.write_text(yaml.safe_dump(document, sort_keys=False), encoding="utf-8")

    verdict = check_website_access("www.blocked.test/path", config_path=config_path)

    assert verdict is not None
    assert verdict["host"] == "www.blocked.test"
    assert verdict["rule"] == "blocked.test"
def test_browser_navigate_returns_policy_block(monkeypatch):
    """A blocked URL makes browser_navigate report failure without running a command."""
    from tools import browser_tool

    block_info = {
        "host": "blocked.test",
        "rule": "blocked.test",
        "source": "config",
        "message": "Blocked by website policy",
    }

    def always_blocked(url):
        # Hand out a fresh dict per call, as a real policy lookup would.
        return dict(block_info)

    def must_not_run(*args, **kwargs):
        return pytest.fail("browser command should not run for blocked URL")

    monkeypatch.setattr(browser_tool, "check_website_access", always_blocked)
    monkeypatch.setattr(browser_tool, "_run_browser_command", must_not_run)

    response = json.loads(browser_tool.browser_navigate("https://blocked.test"))

    assert response["success"] is False
    assert response["blocked_by_policy"]["rule"] == "blocked.test"
def test_browser_navigate_allows_when_shared_file_missing(monkeypatch, tmp_path):
    """Missing shared blocklist files are warned and skipped, not fatal.

    NOTE(review): despite the name, this test exercises
    ``check_website_access`` directly; ``browser_tool`` is imported but not
    otherwise used, and the ``monkeypatch`` fixture is unused.
    """
    from tools import browser_tool
    config_path = tmp_path / "config.yaml"
    # The blocklist is enabled with no inline domains; its only rule source
    # is a shared file that is never created.
    config_path.write_text(
        yaml.safe_dump(
            {
                "security": {
                    "website_blocklist": {
                        "enabled": True,
                        "shared_files": ["missing-blocklist.txt"],
                    }
                }
            },
            sort_keys=False,
        ),
        encoding="utf-8",
    )
    # check_website_access should return None (allow) — missing file is skipped
    result = check_website_access("https://allowed.test", config_path=config_path)
    assert result is None
@pytest.mark.asyncio
async def test_web_extract_short_circuits_blocked_url(monkeypatch):
    """A blocked URL is reported as an error without creating a Firecrawl client."""
    from tools import web_tools

    block_info = {
        "host": "blocked.test",
        "rule": "blocked.test",
        "source": "config",
        "message": "Blocked by website policy",
    }

    def always_blocked(url):
        # Hand out a fresh dict per call, as a real policy lookup would.
        return dict(block_info)

    def client_must_not_be_built():
        return pytest.fail("firecrawl should not run for blocked URL")

    monkeypatch.setattr(web_tools, "check_website_access", always_blocked)
    monkeypatch.setattr(web_tools, "_get_firecrawl_client", client_must_not_be_built)
    monkeypatch.setattr("tools.interrupt.is_interrupted", lambda: False)

    raw_json = await web_tools.web_extract_tool(["https://blocked.test"], use_llm_processing=False)
    first = json.loads(raw_json)["results"][0]

    assert first["url"] == "https://blocked.test"
    assert "Blocked by website policy" in first["error"]
def test_check_website_access_fails_open_on_malformed_config(tmp_path, monkeypatch):
    """Malformed config with default path should fail open (return None), not crash.

    The same broken file is checked twice: once via an explicit
    ``config_path`` (the error must propagate) and once via the default
    lookup under ``HERMES_HOME`` (the URL must be allowed).
    """
    config_path = tmp_path / "config.yaml"
    # Unterminated flow sequence: not parseable as YAML.
    config_path.write_text("security: [oops\n", encoding="utf-8")
    # With explicit config_path (test mode), errors propagate
    with pytest.raises(WebsitePolicyError):
        check_website_access("https://example.com", config_path=config_path)
    # Simulate default path by pointing HERMES_HOME to tmp_path
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    from tools import website_policy
    # Presumably clears a cached policy so the next call re-reads the file
    # under the new HERMES_HOME; confirm against tools.website_policy.
    website_policy.invalidate_cache()
    # With default path, errors are caught and fail open
    result = check_website_access("https://example.com")
    assert result is None # allowed, not crashed
@pytest.mark.asyncio
async def test_web_extract_blocks_redirected_final_url(monkeypatch):
    """Content is withheld when the scraped page's final URL is blocked."""
    from tools import web_tools

    verdicts = {
        "https://allowed.test": None,
        "https://blocked.test/final": {
            "host": "blocked.test",
            "rule": "blocked.test",
            "source": "config",
            "message": "Blocked by website policy",
        },
    }

    def fake_check(url):
        # Only the requested URL and the redirect target may be looked up.
        if url not in verdicts:
            pytest.fail(f"unexpected URL checked: {url}")
        verdict = verdicts[url]
        return None if verdict is None else dict(verdict)

    class FakeFirecrawlClient:
        """Stand-in client whose scrape result reports a different sourceURL."""

        def scrape(self, url, formats):
            metadata = {
                "title": "Redirected",
                "sourceURL": "https://blocked.test/final",
            }
            return {"markdown": "secret content", "metadata": metadata}

    monkeypatch.setattr(web_tools, "check_website_access", fake_check)
    monkeypatch.setattr(web_tools, "_get_firecrawl_client", lambda: FakeFirecrawlClient())
    monkeypatch.setattr("tools.interrupt.is_interrupted", lambda: False)

    raw_json = await web_tools.web_extract_tool(["https://allowed.test"], use_llm_processing=False)
    first = json.loads(raw_json)["results"][0]

    assert first["url"] == "https://blocked.test/final"
    assert first["content"] == ""
    assert first["blocked_by_policy"]["rule"] == "blocked.test"
@pytest.mark.asyncio
async def test_web_crawl_short_circuits_blocked_url(monkeypatch):
    """A blocked start URL is reported without creating a Firecrawl client."""
    from tools import web_tools

    # web_crawl_tool checks for Firecrawl env before website policy
    monkeypatch.setenv("FIRECRAWL_API_KEY", "fake-key")

    block_info = {
        "host": "blocked.test",
        "rule": "blocked.test",
        "source": "config",
        "message": "Blocked by website policy",
    }

    def always_blocked(url):
        # Hand out a fresh dict per call, as a real policy lookup would.
        return dict(block_info)

    def client_must_not_be_built():
        return pytest.fail("firecrawl should not run for blocked crawl URL")

    monkeypatch.setattr(web_tools, "check_website_access", always_blocked)
    monkeypatch.setattr(web_tools, "_get_firecrawl_client", client_must_not_be_built)
    monkeypatch.setattr("tools.interrupt.is_interrupted", lambda: False)

    raw_json = await web_tools.web_crawl_tool("https://blocked.test", use_llm_processing=False)
    first = json.loads(raw_json)["results"][0]

    assert first["url"] == "https://blocked.test"
    assert first["blocked_by_policy"]["rule"] == "blocked.test"
@pytest.mark.asyncio
async def test_web_crawl_blocks_redirected_final_url(monkeypatch):
    """Crawled pages whose final URL is blocked are returned with empty content."""
    from tools import web_tools

    # web_crawl_tool checks for Firecrawl env before website policy
    monkeypatch.setenv("FIRECRAWL_API_KEY", "fake-key")

    verdicts = {
        "https://allowed.test": None,
        "https://blocked.test/final": {
            "host": "blocked.test",
            "rule": "blocked.test",
            "source": "config",
            "message": "Blocked by website policy",
        },
    }

    def fake_check(url):
        # Only the start URL and the redirect target may be looked up.
        if url not in verdicts:
            pytest.fail(f"unexpected URL checked: {url}")
        verdict = verdicts[url]
        return None if verdict is None else dict(verdict)

    class FakeCrawlClient:
        """Stand-in client returning one page whose sourceURL is blocked."""

        def crawl(self, url, **kwargs):
            page = {
                "markdown": "secret crawl content",
                "metadata": {
                    "title": "Redirected crawl page",
                    "sourceURL": "https://blocked.test/final",
                },
            }
            return {"data": [page]}

    monkeypatch.setattr(web_tools, "check_website_access", fake_check)
    monkeypatch.setattr(web_tools, "_get_firecrawl_client", lambda: FakeCrawlClient())
    monkeypatch.setattr("tools.interrupt.is_interrupted", lambda: False)

    raw_json = await web_tools.web_crawl_tool("https://allowed.test", use_llm_processing=False)
    first = json.loads(raw_json)["results"][0]

    assert first["content"] == ""
    assert first["error"] == "Blocked by website policy"
    assert first["blocked_by_policy"]["rule"] == "blocked.test"