merge: resolve file_tools.py conflict with origin/main
Combine read/search loop detection with main's redact_sensitive_text and truncation hint features. Add tracker reset to TestSearchHints to prevent cross-test state leakage.
This commit is contained in:
commit
4684aaffdc
104 changed files with 13720 additions and 2489 deletions
|
|
@ -1,4 +1,4 @@
|
|||
"""Tests for agent.auxiliary_client resolution chain, especially the Codex fallback."""
|
||||
"""Tests for agent.auxiliary_client resolution chain, provider overrides, and model overrides."""
|
||||
|
||||
import json
|
||||
import os
|
||||
|
|
@ -12,6 +12,9 @@ from agent.auxiliary_client import (
|
|||
get_vision_auxiliary_client,
|
||||
auxiliary_max_tokens_param,
|
||||
_read_codex_access_token,
|
||||
_get_auxiliary_provider,
|
||||
_resolve_forced_provider,
|
||||
_resolve_auto,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -21,6 +24,10 @@ def _clean_env(monkeypatch):
|
|||
for key in (
|
||||
"OPENROUTER_API_KEY", "OPENAI_BASE_URL", "OPENAI_API_KEY",
|
||||
"OPENAI_MODEL", "LLM_MODEL", "NOUS_INFERENCE_BASE_URL",
|
||||
# Per-task provider/model overrides
|
||||
"AUXILIARY_VISION_PROVIDER", "AUXILIARY_VISION_MODEL",
|
||||
"AUXILIARY_WEB_EXTRACT_PROVIDER", "AUXILIARY_WEB_EXTRACT_MODEL",
|
||||
"CONTEXT_COMPRESSION_PROVIDER", "CONTEXT_COMPRESSION_MODEL",
|
||||
):
|
||||
monkeypatch.delenv(key, raising=False)
|
||||
|
||||
|
|
@ -151,15 +158,230 @@ class TestGetTextAuxiliaryClient:
|
|||
assert model is None
|
||||
|
||||
|
||||
class TestCodexNotInVisionClient:
|
||||
"""Codex fallback should NOT apply to vision tasks."""
|
||||
class TestVisionClientFallback:
|
||||
"""Vision client auto mode only tries OpenRouter + Nous (multimodal-capable)."""
|
||||
|
||||
def test_vision_returns_none_without_openrouter_nous(self):
|
||||
def test_vision_returns_none_without_any_credentials(self):
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None):
|
||||
client, model = get_vision_auxiliary_client()
|
||||
assert client is None
|
||||
assert model is None
|
||||
|
||||
def test_vision_auto_includes_codex(self, codex_auth_dir):
|
||||
"""Codex supports vision (gpt-5.3-codex), so auto mode should use it."""
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
patch("agent.auxiliary_client.OpenAI"):
|
||||
client, model = get_vision_auxiliary_client()
|
||||
from agent.auxiliary_client import CodexAuxiliaryClient
|
||||
assert isinstance(client, CodexAuxiliaryClient)
|
||||
assert model == "gpt-5.3-codex"
|
||||
|
||||
def test_vision_auto_skips_custom_endpoint(self, monkeypatch):
|
||||
"""Custom endpoint is skipped in vision auto mode."""
|
||||
monkeypatch.setenv("OPENAI_BASE_URL", "http://localhost:1234/v1")
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "local-key")
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None):
|
||||
client, model = get_vision_auxiliary_client()
|
||||
assert client is None
|
||||
assert model is None
|
||||
|
||||
def test_vision_uses_openrouter_when_available(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
with patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
client, model = get_vision_auxiliary_client()
|
||||
assert model == "google/gemini-3-flash-preview"
|
||||
assert client is not None
|
||||
|
||||
def test_vision_uses_nous_when_available(self, monkeypatch):
|
||||
with patch("agent.auxiliary_client._read_nous_auth") as mock_nous, \
|
||||
patch("agent.auxiliary_client.OpenAI"):
|
||||
mock_nous.return_value = {"access_token": "nous-tok"}
|
||||
client, model = get_vision_auxiliary_client()
|
||||
assert model == "gemini-3-flash"
|
||||
assert client is not None
|
||||
|
||||
def test_vision_forced_main_uses_custom_endpoint(self, monkeypatch):
|
||||
"""When explicitly forced to 'main', vision CAN use custom endpoint."""
|
||||
monkeypatch.setenv("AUXILIARY_VISION_PROVIDER", "main")
|
||||
monkeypatch.setenv("OPENAI_BASE_URL", "http://localhost:1234/v1")
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "local-key")
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
client, model = get_vision_auxiliary_client()
|
||||
assert client is not None
|
||||
assert model == "gpt-4o-mini"
|
||||
|
||||
def test_vision_forced_main_returns_none_without_creds(self, monkeypatch):
|
||||
"""Forced main with no credentials still returns None."""
|
||||
monkeypatch.setenv("AUXILIARY_VISION_PROVIDER", "main")
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
patch("agent.auxiliary_client._read_codex_access_token", return_value=None):
|
||||
client, model = get_vision_auxiliary_client()
|
||||
assert client is None
|
||||
assert model is None
|
||||
|
||||
def test_vision_forced_codex(self, monkeypatch, codex_auth_dir):
|
||||
"""When forced to 'codex', vision uses Codex OAuth."""
|
||||
monkeypatch.setenv("AUXILIARY_VISION_PROVIDER", "codex")
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
patch("agent.auxiliary_client.OpenAI"):
|
||||
client, model = get_vision_auxiliary_client()
|
||||
from agent.auxiliary_client import CodexAuxiliaryClient
|
||||
assert isinstance(client, CodexAuxiliaryClient)
|
||||
assert model == "gpt-5.3-codex"
|
||||
|
||||
|
||||
class TestGetAuxiliaryProvider:
|
||||
"""Tests for _get_auxiliary_provider env var resolution."""
|
||||
|
||||
def test_no_task_returns_auto(self):
|
||||
assert _get_auxiliary_provider() == "auto"
|
||||
assert _get_auxiliary_provider("") == "auto"
|
||||
|
||||
def test_auxiliary_prefix_takes_priority(self, monkeypatch):
|
||||
monkeypatch.setenv("AUXILIARY_VISION_PROVIDER", "openrouter")
|
||||
assert _get_auxiliary_provider("vision") == "openrouter"
|
||||
|
||||
def test_context_prefix_fallback(self, monkeypatch):
|
||||
monkeypatch.setenv("CONTEXT_COMPRESSION_PROVIDER", "nous")
|
||||
assert _get_auxiliary_provider("compression") == "nous"
|
||||
|
||||
def test_auxiliary_prefix_over_context_prefix(self, monkeypatch):
|
||||
monkeypatch.setenv("AUXILIARY_COMPRESSION_PROVIDER", "openrouter")
|
||||
monkeypatch.setenv("CONTEXT_COMPRESSION_PROVIDER", "nous")
|
||||
assert _get_auxiliary_provider("compression") == "openrouter"
|
||||
|
||||
def test_auto_value_treated_as_auto(self, monkeypatch):
|
||||
monkeypatch.setenv("AUXILIARY_VISION_PROVIDER", "auto")
|
||||
assert _get_auxiliary_provider("vision") == "auto"
|
||||
|
||||
def test_whitespace_stripped(self, monkeypatch):
|
||||
monkeypatch.setenv("AUXILIARY_VISION_PROVIDER", " openrouter ")
|
||||
assert _get_auxiliary_provider("vision") == "openrouter"
|
||||
|
||||
def test_case_insensitive(self, monkeypatch):
|
||||
monkeypatch.setenv("AUXILIARY_VISION_PROVIDER", "OpenRouter")
|
||||
assert _get_auxiliary_provider("vision") == "openrouter"
|
||||
|
||||
def test_main_provider(self, monkeypatch):
|
||||
monkeypatch.setenv("AUXILIARY_WEB_EXTRACT_PROVIDER", "main")
|
||||
assert _get_auxiliary_provider("web_extract") == "main"
|
||||
|
||||
|
||||
class TestResolveForcedProvider:
|
||||
"""Tests for _resolve_forced_provider with explicit provider selection."""
|
||||
|
||||
def test_forced_openrouter(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
with patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
client, model = _resolve_forced_provider("openrouter")
|
||||
assert model == "google/gemini-3-flash-preview"
|
||||
assert client is not None
|
||||
|
||||
def test_forced_openrouter_no_key(self, monkeypatch):
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None):
|
||||
client, model = _resolve_forced_provider("openrouter")
|
||||
assert client is None
|
||||
assert model is None
|
||||
|
||||
def test_forced_nous(self, monkeypatch):
|
||||
with patch("agent.auxiliary_client._read_nous_auth") as mock_nous, \
|
||||
patch("agent.auxiliary_client.OpenAI"):
|
||||
mock_nous.return_value = {"access_token": "nous-tok"}
|
||||
client, model = _resolve_forced_provider("nous")
|
||||
assert model == "gemini-3-flash"
|
||||
assert client is not None
|
||||
|
||||
def test_forced_nous_not_configured(self, monkeypatch):
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None):
|
||||
client, model = _resolve_forced_provider("nous")
|
||||
assert client is None
|
||||
assert model is None
|
||||
|
||||
def test_forced_main_uses_custom(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENAI_BASE_URL", "http://local:8080/v1")
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "local-key")
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
client, model = _resolve_forced_provider("main")
|
||||
assert model == "gpt-4o-mini"
|
||||
|
||||
def test_forced_main_skips_openrouter_nous(self, monkeypatch):
|
||||
"""Even if OpenRouter key is set, 'main' skips it."""
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
monkeypatch.setenv("OPENAI_BASE_URL", "http://local:8080/v1")
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "local-key")
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
client, model = _resolve_forced_provider("main")
|
||||
# Should use custom endpoint, not OpenRouter
|
||||
assert model == "gpt-4o-mini"
|
||||
|
||||
def test_forced_main_falls_to_codex(self, codex_auth_dir, monkeypatch):
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
patch("agent.auxiliary_client.OpenAI"):
|
||||
client, model = _resolve_forced_provider("main")
|
||||
from agent.auxiliary_client import CodexAuxiliaryClient
|
||||
assert isinstance(client, CodexAuxiliaryClient)
|
||||
assert model == "gpt-5.3-codex"
|
||||
|
||||
def test_forced_codex(self, codex_auth_dir, monkeypatch):
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
patch("agent.auxiliary_client.OpenAI"):
|
||||
client, model = _resolve_forced_provider("codex")
|
||||
from agent.auxiliary_client import CodexAuxiliaryClient
|
||||
assert isinstance(client, CodexAuxiliaryClient)
|
||||
assert model == "gpt-5.3-codex"
|
||||
|
||||
def test_forced_codex_no_token(self, monkeypatch):
|
||||
with patch("agent.auxiliary_client._read_codex_access_token", return_value=None):
|
||||
client, model = _resolve_forced_provider("codex")
|
||||
assert client is None
|
||||
assert model is None
|
||||
|
||||
def test_forced_unknown_returns_none(self, monkeypatch):
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
patch("agent.auxiliary_client._read_codex_access_token", return_value=None):
|
||||
client, model = _resolve_forced_provider("invalid-provider")
|
||||
assert client is None
|
||||
assert model is None
|
||||
|
||||
|
||||
class TestTaskSpecificOverrides:
|
||||
"""Integration tests for per-task provider routing via get_text_auxiliary_client(task=...)."""
|
||||
|
||||
def test_text_with_vision_provider_override(self, monkeypatch):
|
||||
"""AUXILIARY_VISION_PROVIDER should not affect text tasks."""
|
||||
monkeypatch.setenv("AUXILIARY_VISION_PROVIDER", "nous")
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
with patch("agent.auxiliary_client.OpenAI"):
|
||||
client, model = get_text_auxiliary_client() # no task → auto
|
||||
assert model == "google/gemini-3-flash-preview" # OpenRouter, not Nous
|
||||
|
||||
def test_compression_task_reads_context_prefix(self, monkeypatch):
|
||||
"""Compression task should check CONTEXT_COMPRESSION_PROVIDER."""
|
||||
monkeypatch.setenv("CONTEXT_COMPRESSION_PROVIDER", "nous")
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key") # would win in auto
|
||||
with patch("agent.auxiliary_client._read_nous_auth") as mock_nous, \
|
||||
patch("agent.auxiliary_client.OpenAI"):
|
||||
mock_nous.return_value = {"access_token": "nous-tok"}
|
||||
client, model = get_text_auxiliary_client("compression")
|
||||
assert model == "gemini-3-flash" # forced to Nous, not OpenRouter
|
||||
|
||||
def test_web_extract_task_override(self, monkeypatch):
|
||||
monkeypatch.setenv("AUXILIARY_WEB_EXTRACT_PROVIDER", "openrouter")
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
with patch("agent.auxiliary_client.OpenAI"):
|
||||
client, model = get_text_auxiliary_client("web_extract")
|
||||
assert model == "google/gemini-3-flash-preview"
|
||||
|
||||
def test_task_without_override_uses_auto(self, monkeypatch):
|
||||
"""A task with no provider env var falls through to auto chain."""
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
with patch("agent.auxiliary_client.OpenAI"):
|
||||
client, model = get_text_auxiliary_client("compression")
|
||||
assert model == "google/gemini-3-flash-preview" # auto → OpenRouter
|
||||
|
||||
|
||||
class TestAuxiliaryMaxTokensParam:
|
||||
def test_codex_fallback_uses_max_tokens(self, monkeypatch):
|
||||
|
|
|
|||
|
|
@ -224,6 +224,60 @@ class TestCompressWithClient:
|
|||
for tc in msg["tool_calls"]:
|
||||
assert tc["id"] in answered_ids
|
||||
|
||||
def test_summary_role_avoids_consecutive_user_messages(self):
|
||||
"""Summary role should alternate with the last head message to avoid consecutive same-role messages."""
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "[CONTEXT SUMMARY]: stuff happened"
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
|
||||
with patch("agent.context_compressor.get_model_context_length", return_value=100000), \
|
||||
patch("agent.context_compressor.get_text_auxiliary_client", return_value=(mock_client, "test-model")):
|
||||
c = ContextCompressor(model="test", quiet_mode=True, protect_first_n=2, protect_last_n=2)
|
||||
|
||||
# Last head message (index 1) is "assistant" → summary should be "user"
|
||||
msgs = [
|
||||
{"role": "user", "content": "msg 0"},
|
||||
{"role": "assistant", "content": "msg 1"},
|
||||
{"role": "user", "content": "msg 2"},
|
||||
{"role": "assistant", "content": "msg 3"},
|
||||
{"role": "user", "content": "msg 4"},
|
||||
{"role": "assistant", "content": "msg 5"},
|
||||
]
|
||||
result = c.compress(msgs)
|
||||
summary_msg = [m for m in result if "CONTEXT SUMMARY" in (m.get("content") or "")]
|
||||
assert len(summary_msg) == 1
|
||||
assert summary_msg[0]["role"] == "user"
|
||||
|
||||
def test_summary_role_avoids_consecutive_user_when_head_ends_with_user(self):
|
||||
"""When last head message is 'user', summary must be 'assistant' to avoid two consecutive user messages."""
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "[CONTEXT SUMMARY]: stuff happened"
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
|
||||
with patch("agent.context_compressor.get_model_context_length", return_value=100000), \
|
||||
patch("agent.context_compressor.get_text_auxiliary_client", return_value=(mock_client, "test-model")):
|
||||
c = ContextCompressor(model="test", quiet_mode=True, protect_first_n=3, protect_last_n=2)
|
||||
|
||||
# Last head message (index 2) is "user" → summary should be "assistant"
|
||||
msgs = [
|
||||
{"role": "system", "content": "system prompt"},
|
||||
{"role": "user", "content": "msg 1"},
|
||||
{"role": "user", "content": "msg 2"}, # last head — user
|
||||
{"role": "assistant", "content": "msg 3"},
|
||||
{"role": "user", "content": "msg 4"},
|
||||
{"role": "assistant", "content": "msg 5"},
|
||||
{"role": "user", "content": "msg 6"},
|
||||
{"role": "assistant", "content": "msg 7"},
|
||||
]
|
||||
result = c.compress(msgs)
|
||||
summary_msg = [m for m in result if "CONTEXT SUMMARY" in (m.get("content") or "")]
|
||||
assert len(summary_msg) == 1
|
||||
assert summary_msg[0]["role"] == "assistant"
|
||||
|
||||
def test_summarization_does_not_start_tail_with_tool_outputs(self):
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
|
|
|
|||
200
tests/gateway/test_resume_command.py
Normal file
200
tests/gateway/test_resume_command.py
Normal file
|
|
@ -0,0 +1,200 @@
|
|||
"""Tests for /resume gateway slash command.
|
||||
|
||||
Tests the _handle_resume_command handler (switch to a previously-named session)
|
||||
across gateway messenger platforms.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform
|
||||
from gateway.platforms.base import MessageEvent
|
||||
from gateway.session import SessionSource, build_session_key
|
||||
|
||||
|
||||
def _make_event(text="/resume", platform=Platform.TELEGRAM,
|
||||
user_id="12345", chat_id="67890"):
|
||||
"""Build a MessageEvent for testing."""
|
||||
source = SessionSource(
|
||||
platform=platform,
|
||||
user_id=user_id,
|
||||
chat_id=chat_id,
|
||||
user_name="testuser",
|
||||
)
|
||||
return MessageEvent(text=text, source=source)
|
||||
|
||||
|
||||
def _session_key_for_event(event):
|
||||
"""Get the session key that build_session_key produces for an event."""
|
||||
return build_session_key(event.source)
|
||||
|
||||
|
||||
def _make_runner(session_db=None, current_session_id="current_session_001",
|
||||
event=None):
|
||||
"""Create a bare GatewayRunner with a mock session_store and optional session_db."""
|
||||
from gateway.run import GatewayRunner
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.adapters = {}
|
||||
runner._session_db = session_db
|
||||
runner._running_agents = {}
|
||||
|
||||
# Compute the real session key if an event is provided
|
||||
session_key = build_session_key(event.source) if event else "agent:main:telegram:dm"
|
||||
|
||||
# Mock session_store that returns a session entry with a known session_id
|
||||
mock_session_entry = MagicMock()
|
||||
mock_session_entry.session_id = current_session_id
|
||||
mock_session_entry.session_key = session_key
|
||||
mock_store = MagicMock()
|
||||
mock_store.get_or_create_session.return_value = mock_session_entry
|
||||
mock_store.load_transcript.return_value = []
|
||||
mock_store.switch_session.return_value = mock_session_entry
|
||||
runner.session_store = mock_store
|
||||
|
||||
# Stub out memory flushing
|
||||
runner._async_flush_memories = AsyncMock()
|
||||
|
||||
return runner
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _handle_resume_command
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHandleResumeCommand:
|
||||
"""Tests for GatewayRunner._handle_resume_command."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_session_db(self):
|
||||
"""Returns error when session database is unavailable."""
|
||||
runner = _make_runner(session_db=None)
|
||||
event = _make_event(text="/resume My Project")
|
||||
result = await runner._handle_resume_command(event)
|
||||
assert "not available" in result.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_named_sessions_when_no_arg(self, tmp_path):
|
||||
"""With no argument, lists recently titled sessions."""
|
||||
from hermes_state import SessionDB
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
db.create_session("sess_001", "telegram")
|
||||
db.create_session("sess_002", "telegram")
|
||||
db.set_session_title("sess_001", "Research")
|
||||
db.set_session_title("sess_002", "Coding")
|
||||
|
||||
event = _make_event(text="/resume")
|
||||
runner = _make_runner(session_db=db, event=event)
|
||||
result = await runner._handle_resume_command(event)
|
||||
assert "Research" in result
|
||||
assert "Coding" in result
|
||||
assert "Named Sessions" in result
|
||||
db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_shows_usage_when_no_titled(self, tmp_path):
|
||||
"""With no arg and no titled sessions, shows instructions."""
|
||||
from hermes_state import SessionDB
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
db.create_session("sess_001", "telegram") # No title
|
||||
|
||||
event = _make_event(text="/resume")
|
||||
runner = _make_runner(session_db=db, event=event)
|
||||
result = await runner._handle_resume_command(event)
|
||||
assert "No named sessions" in result
|
||||
assert "/title" in result
|
||||
db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_by_name(self, tmp_path):
|
||||
"""Resolves a title and switches to that session."""
|
||||
from hermes_state import SessionDB
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
db.create_session("old_session_abc", "telegram")
|
||||
db.set_session_title("old_session_abc", "My Project")
|
||||
db.create_session("current_session_001", "telegram")
|
||||
|
||||
event = _make_event(text="/resume My Project")
|
||||
runner = _make_runner(session_db=db, current_session_id="current_session_001",
|
||||
event=event)
|
||||
result = await runner._handle_resume_command(event)
|
||||
|
||||
assert "Resumed" in result
|
||||
assert "My Project" in result
|
||||
# Verify switch_session was called with the old session ID
|
||||
runner.session_store.switch_session.assert_called_once()
|
||||
call_args = runner.session_store.switch_session.call_args
|
||||
assert call_args[0][1] == "old_session_abc"
|
||||
db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_nonexistent_name(self, tmp_path):
|
||||
"""Returns error for unknown session name."""
|
||||
from hermes_state import SessionDB
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
db.create_session("current_session_001", "telegram")
|
||||
|
||||
event = _make_event(text="/resume Nonexistent Session")
|
||||
runner = _make_runner(session_db=db, event=event)
|
||||
result = await runner._handle_resume_command(event)
|
||||
assert "No session found" in result
|
||||
db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_already_on_session(self, tmp_path):
|
||||
"""Returns friendly message when already on the requested session."""
|
||||
from hermes_state import SessionDB
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
db.create_session("current_session_001", "telegram")
|
||||
db.set_session_title("current_session_001", "Active Project")
|
||||
|
||||
event = _make_event(text="/resume Active Project")
|
||||
runner = _make_runner(session_db=db, current_session_id="current_session_001",
|
||||
event=event)
|
||||
result = await runner._handle_resume_command(event)
|
||||
assert "Already on session" in result
|
||||
db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_auto_lineage(self, tmp_path):
|
||||
"""Asking for 'My Project' when 'My Project #2' exists gets the latest."""
|
||||
from hermes_state import SessionDB
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
db.create_session("sess_v1", "telegram")
|
||||
db.set_session_title("sess_v1", "My Project")
|
||||
db.create_session("sess_v2", "telegram")
|
||||
db.set_session_title("sess_v2", "My Project #2")
|
||||
db.create_session("current_session_001", "telegram")
|
||||
|
||||
event = _make_event(text="/resume My Project")
|
||||
runner = _make_runner(session_db=db, current_session_id="current_session_001",
|
||||
event=event)
|
||||
result = await runner._handle_resume_command(event)
|
||||
|
||||
assert "Resumed" in result
|
||||
# Should resolve to #2 (latest in lineage)
|
||||
call_args = runner.session_store.switch_session.call_args
|
||||
assert call_args[0][1] == "sess_v2"
|
||||
db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_clears_running_agent(self, tmp_path):
|
||||
"""Switching sessions clears any cached running agent."""
|
||||
from hermes_state import SessionDB
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
db.create_session("old_session", "telegram")
|
||||
db.set_session_title("old_session", "Old Work")
|
||||
db.create_session("current_session_001", "telegram")
|
||||
|
||||
event = _make_event(text="/resume Old Work")
|
||||
runner = _make_runner(session_db=db, current_session_id="current_session_001",
|
||||
event=event)
|
||||
# Simulate a running agent using the real session key
|
||||
real_key = _session_key_for_event(event)
|
||||
runner._running_agents[real_key] = MagicMock()
|
||||
|
||||
await runner._handle_resume_command(event)
|
||||
|
||||
assert real_key not in runner._running_agents
|
||||
db.close()
|
||||
|
|
@ -2,6 +2,10 @@
|
|||
|
||||
Verifies that the gateway detects pathologically large transcripts and
|
||||
triggers auto-compression before running the agent. (#628)
|
||||
|
||||
The hygiene system uses the SAME compression config as the agent:
|
||||
compression.threshold × model context length
|
||||
so CLI and messaging platforms behave identically.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
|
@ -38,75 +42,113 @@ def _make_large_history_tokens(target_tokens: int) -> list:
|
|||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Detection threshold tests
|
||||
# Detection threshold tests (model-aware, unified with compression config)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSessionHygieneThresholds:
|
||||
"""Test that the threshold logic correctly identifies large sessions."""
|
||||
"""Test that the threshold logic correctly identifies large sessions.
|
||||
|
||||
Thresholds are derived from model context length × compression threshold,
|
||||
matching what the agent's ContextCompressor uses.
|
||||
"""
|
||||
|
||||
def test_small_session_below_thresholds(self):
|
||||
"""A 10-message session should not trigger compression."""
|
||||
history = _make_history(10)
|
||||
msg_count = len(history)
|
||||
approx_tokens = estimate_messages_tokens_rough(history)
|
||||
|
||||
compress_token_threshold = 100_000
|
||||
compress_msg_threshold = 200
|
||||
# For a 200k-context model at 85% threshold = 170k
|
||||
context_length = 200_000
|
||||
threshold_pct = 0.85
|
||||
compress_token_threshold = int(context_length * threshold_pct)
|
||||
|
||||
needs_compress = (
|
||||
approx_tokens >= compress_token_threshold
|
||||
or msg_count >= compress_msg_threshold
|
||||
)
|
||||
needs_compress = approx_tokens >= compress_token_threshold
|
||||
assert not needs_compress
|
||||
|
||||
def test_large_message_count_triggers(self):
|
||||
"""200+ messages should trigger compression even if tokens are low."""
|
||||
history = _make_history(250, content_size=10)
|
||||
msg_count = len(history)
|
||||
|
||||
compress_msg_threshold = 200
|
||||
needs_compress = msg_count >= compress_msg_threshold
|
||||
assert needs_compress
|
||||
|
||||
def test_large_token_count_triggers(self):
|
||||
"""High token count should trigger compression even if message count is low."""
|
||||
# 50 messages with huge content to exceed 100K tokens
|
||||
history = _make_history(50, content_size=10_000)
|
||||
"""High token count should trigger compression when exceeding model threshold."""
|
||||
# Build a history that exceeds 85% of a 200k model (170k tokens)
|
||||
history = _make_large_history_tokens(180_000)
|
||||
approx_tokens = estimate_messages_tokens_rough(history)
|
||||
|
||||
compress_token_threshold = 100_000
|
||||
context_length = 200_000
|
||||
threshold_pct = 0.85
|
||||
compress_token_threshold = int(context_length * threshold_pct)
|
||||
|
||||
needs_compress = approx_tokens >= compress_token_threshold
|
||||
assert needs_compress
|
||||
|
||||
def test_under_both_thresholds_no_trigger(self):
|
||||
"""Session under both thresholds should not trigger."""
|
||||
history = _make_history(100, content_size=100)
|
||||
msg_count = len(history)
|
||||
def test_under_threshold_no_trigger(self):
|
||||
"""Session under threshold should not trigger, even with many messages."""
|
||||
# 250 short messages — lots of messages but well under token threshold
|
||||
history = _make_history(250, content_size=10)
|
||||
approx_tokens = estimate_messages_tokens_rough(history)
|
||||
|
||||
compress_token_threshold = 100_000
|
||||
compress_msg_threshold = 200
|
||||
# 200k model at 85% = 170k token threshold
|
||||
context_length = 200_000
|
||||
threshold_pct = 0.85
|
||||
compress_token_threshold = int(context_length * threshold_pct)
|
||||
|
||||
needs_compress = (
|
||||
approx_tokens >= compress_token_threshold
|
||||
or msg_count >= compress_msg_threshold
|
||||
needs_compress = approx_tokens >= compress_token_threshold
|
||||
assert not needs_compress, (
|
||||
f"250 short messages (~{approx_tokens} tokens) should NOT trigger "
|
||||
f"compression at {compress_token_threshold} token threshold"
|
||||
)
|
||||
|
||||
def test_message_count_alone_does_not_trigger(self):
|
||||
"""Message count alone should NOT trigger — only token count matters.
|
||||
|
||||
The old system used an OR of token-count and message-count thresholds,
|
||||
which caused premature compression in tool-heavy sessions with 200+
|
||||
messages but low total tokens.
|
||||
"""
|
||||
# 300 very short messages — old system would compress, new should not
|
||||
history = _make_history(300, content_size=10)
|
||||
approx_tokens = estimate_messages_tokens_rough(history)
|
||||
|
||||
context_length = 200_000
|
||||
threshold_pct = 0.85
|
||||
compress_token_threshold = int(context_length * threshold_pct)
|
||||
|
||||
# Token-based check only
|
||||
needs_compress = approx_tokens >= compress_token_threshold
|
||||
assert not needs_compress
|
||||
|
||||
def test_custom_thresholds(self):
|
||||
"""Custom thresholds from config should be respected."""
|
||||
history = _make_history(60, content_size=100)
|
||||
msg_count = len(history)
|
||||
def test_threshold_scales_with_model(self):
|
||||
"""Different models should have different compression thresholds."""
|
||||
# 128k model at 85% = 108,800 tokens
|
||||
small_model_threshold = int(128_000 * 0.85)
|
||||
# 200k model at 85% = 170,000 tokens
|
||||
large_model_threshold = int(200_000 * 0.85)
|
||||
# 1M model at 85% = 850,000 tokens
|
||||
huge_model_threshold = int(1_000_000 * 0.85)
|
||||
|
||||
# Custom lower threshold
|
||||
compress_msg_threshold = 50
|
||||
needs_compress = msg_count >= compress_msg_threshold
|
||||
assert needs_compress
|
||||
# A session at ~120k tokens:
|
||||
history = _make_large_history_tokens(120_000)
|
||||
approx_tokens = estimate_messages_tokens_rough(history)
|
||||
|
||||
# Custom higher threshold
|
||||
compress_msg_threshold = 100
|
||||
needs_compress = msg_count >= compress_msg_threshold
|
||||
assert not needs_compress
|
||||
# Should trigger for 128k model
|
||||
assert approx_tokens >= small_model_threshold
|
||||
# Should NOT trigger for 200k model
|
||||
assert approx_tokens < large_model_threshold
|
||||
# Should NOT trigger for 1M model
|
||||
assert approx_tokens < huge_model_threshold
|
||||
|
||||
def test_custom_threshold_percentage(self):
|
||||
"""Custom threshold percentage from config should be respected."""
|
||||
context_length = 200_000
|
||||
|
||||
# At 50% threshold = 100k
|
||||
low_threshold = int(context_length * 0.50)
|
||||
# At 90% threshold = 180k
|
||||
high_threshold = int(context_length * 0.90)
|
||||
|
||||
history = _make_large_history_tokens(150_000)
|
||||
approx_tokens = estimate_messages_tokens_rough(history)
|
||||
|
||||
# Should trigger at 50% but not at 90%
|
||||
assert approx_tokens >= low_threshold
|
||||
assert approx_tokens < high_threshold
|
||||
|
||||
def test_minimum_message_guard(self):
|
||||
"""Sessions with fewer than 4 messages should never trigger."""
|
||||
|
|
@ -117,18 +159,19 @@ class TestSessionHygieneThresholds:
|
|||
|
||||
|
||||
class TestSessionHygieneWarnThreshold:
|
||||
"""Test the post-compression warning threshold."""
|
||||
"""Test the post-compression warning threshold (95% of context)."""
|
||||
|
||||
def test_warn_when_still_large(self):
|
||||
"""If compressed result is still above warn_tokens, should warn."""
|
||||
# Simulate post-compression tokens
|
||||
warn_threshold = 200_000
|
||||
post_compress_tokens = 250_000
|
||||
"""If compressed result is still above 95% of context, should warn."""
|
||||
context_length = 200_000
|
||||
warn_threshold = int(context_length * 0.95) # 190k
|
||||
post_compress_tokens = 195_000
|
||||
assert post_compress_tokens >= warn_threshold
|
||||
|
||||
def test_no_warn_when_under(self):
|
||||
"""If compressed result is under warn_tokens, no warning."""
|
||||
warn_threshold = 200_000
|
||||
"""If compressed result is under 95% of context, no warning."""
|
||||
context_length = 200_000
|
||||
warn_threshold = int(context_length * 0.95) # 190k
|
||||
post_compress_tokens = 150_000
|
||||
assert post_compress_tokens < warn_threshold
|
||||
|
||||
|
|
@ -150,10 +193,12 @@ class TestTokenEstimation:
|
|||
assert estimate_messages_tokens_rough(many) > estimate_messages_tokens_rough(few)
|
||||
|
||||
def test_pathological_session_detected(self):
|
||||
"""The reported pathological case: 648 messages, ~299K tokens."""
|
||||
# Simulate a 648-message session averaging ~460 tokens per message
|
||||
"""The reported pathological case: 648 messages, ~299K tokens.
|
||||
|
||||
With a 200k model at 85% threshold (170k), this should trigger.
|
||||
"""
|
||||
history = _make_history(648, content_size=1800)
|
||||
tokens = estimate_messages_tokens_rough(history)
|
||||
# Should be well above the 100K default threshold
|
||||
assert tokens > 100_000
|
||||
assert len(history) > 200
|
||||
# Should be well above the 170K threshold for a 200k model
|
||||
threshold = int(200_000 * 0.85)
|
||||
assert tokens > threshold
|
||||
|
|
|
|||
294
tests/gateway/test_signal.py
Normal file
294
tests/gateway/test_signal.py
Normal file
|
|
@ -0,0 +1,294 @@
|
|||
"""Tests for Signal messenger platform adapter."""
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Platform & Config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSignalPlatformEnum:
    """The SIGNAL member exists on the Platform enum with the expected value."""

    def test_signal_enum_exists(self):
        assert Platform.SIGNAL.value == "signal"

    def test_signal_in_platform_list(self):
        # Collect every member's value and check "signal" is among them.
        assert "signal" in {member.value for member in Platform}
|
||||
|
||||
|
||||
class TestSignalConfigLoading:
    """Signal platform config is loaded from SIGNAL_HTTP_URL / SIGNAL_ACCOUNT."""

    @staticmethod
    def _load():
        """Build a fresh GatewayConfig with environment overrides applied."""
        from gateway.config import GatewayConfig, _apply_env_overrides
        cfg = GatewayConfig()
        _apply_env_overrides(cfg)
        return cfg

    def test_apply_env_overrides_signal(self, monkeypatch):
        monkeypatch.setenv("SIGNAL_HTTP_URL", "http://localhost:9090")
        monkeypatch.setenv("SIGNAL_ACCOUNT", "+15551234567")

        cfg = self._load()

        assert Platform.SIGNAL in cfg.platforms
        signal_cfg = cfg.platforms[Platform.SIGNAL]
        assert signal_cfg.enabled is True
        assert signal_cfg.extra["http_url"] == "http://localhost:9090"
        assert signal_cfg.extra["account"] == "+15551234567"

    def test_signal_not_loaded_without_both_vars(self, monkeypatch):
        monkeypatch.setenv("SIGNAL_HTTP_URL", "http://localhost:9090")
        # No SIGNAL_ACCOUNT

        cfg = self._load()

        assert Platform.SIGNAL not in cfg.platforms

    def test_connected_platforms_includes_signal(self, monkeypatch):
        monkeypatch.setenv("SIGNAL_HTTP_URL", "http://localhost:8080")
        monkeypatch.setenv("SIGNAL_ACCOUNT", "+15551234567")

        cfg = self._load()

        assert Platform.SIGNAL in cfg.get_connected_platforms()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Adapter Init & Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSignalAdapterInit:
    """SignalAdapter construction: config parsing, allowlists, URL normalization."""

    def _make_config(self, **extra):
        """Build an enabled PlatformConfig with defaults, overridable via **extra."""
        cfg = PlatformConfig()
        cfg.enabled = True
        cfg.extra = {
            "http_url": "http://localhost:8080",
            "account": "+15551234567",
            **extra,
        }
        return cfg

    def test_init_parses_config(self, monkeypatch):
        monkeypatch.setenv("SIGNAL_GROUP_ALLOWED_USERS", "group123,group456")

        from gateway.platforms.signal import SignalAdapter
        adapter = SignalAdapter(self._make_config())

        assert adapter.http_url == "http://localhost:8080"
        assert adapter.account == "+15551234567"
        assert "group123" in adapter.group_allow_from

    def test_init_empty_allowlist(self, monkeypatch):
        monkeypatch.setenv("SIGNAL_GROUP_ALLOWED_USERS", "")

        from gateway.platforms.signal import SignalAdapter
        adapter = SignalAdapter(self._make_config())

        # Empty env var yields an empty allowlist.
        assert not adapter.group_allow_from

    def test_init_strips_trailing_slash(self, monkeypatch):
        monkeypatch.setenv("SIGNAL_GROUP_ALLOWED_USERS", "")

        from gateway.platforms.signal import SignalAdapter
        adapter = SignalAdapter(self._make_config(http_url="http://localhost:8080/"))

        assert adapter.http_url == "http://localhost:8080"

    def test_self_message_filtering(self, monkeypatch):
        monkeypatch.setenv("SIGNAL_GROUP_ALLOWED_USERS", "")

        from gateway.platforms.signal import SignalAdapter
        adapter = SignalAdapter(self._make_config())

        assert adapter._account_normalized == "+15551234567"
|
||||
|
||||
|
||||
class TestSignalHelpers:
    """Unit tests for module-level helper functions in gateway.platforms.signal."""

    @staticmethod
    def _mod():
        """Deferred import of the signal platform module (kept lazy, as the
        original per-test `from … import` lines were)."""
        import gateway.platforms.signal as signal_mod
        return signal_mod

    def test_redact_phone_long(self):
        assert self._mod()._redact_phone("+15551234567") == "+155****4567"

    def test_redact_phone_short(self):
        assert self._mod()._redact_phone("+12345") == "+1****45"

    def test_redact_phone_empty(self):
        assert self._mod()._redact_phone("") == "<none>"

    def test_parse_comma_list(self):
        parse = self._mod()._parse_comma_list
        assert parse("+1234, +5678 , +9012") == ["+1234", "+5678", "+9012"]
        assert parse("") == []
        assert parse(" , , ") == []

    def test_guess_extension_png(self):
        assert self._mod()._guess_extension(b"\x89PNG\r\n\x1a\n" + b"\x00" * 100) == ".png"

    def test_guess_extension_jpeg(self):
        assert self._mod()._guess_extension(b"\xff\xd8\xff\xe0" + b"\x00" * 100) == ".jpg"

    def test_guess_extension_pdf(self):
        assert self._mod()._guess_extension(b"%PDF-1.4" + b"\x00" * 100) == ".pdf"

    def test_guess_extension_zip(self):
        assert self._mod()._guess_extension(b"PK\x03\x04" + b"\x00" * 100) == ".zip"

    def test_guess_extension_mp4(self):
        assert self._mod()._guess_extension(b"\x00\x00\x00\x18ftypisom" + b"\x00" * 100) == ".mp4"

    def test_guess_extension_unknown(self):
        assert self._mod()._guess_extension(b"\x00\x01\x02\x03" * 10) == ".bin"

    def test_is_image_ext(self):
        is_image = self._mod()._is_image_ext
        assert is_image(".png") is True
        assert is_image(".jpg") is True
        assert is_image(".gif") is True
        assert is_image(".pdf") is False

    def test_is_audio_ext(self):
        is_audio = self._mod()._is_audio_ext
        assert is_audio(".mp3") is True
        assert is_audio(".ogg") is True
        assert is_audio(".png") is False

    def test_check_requirements(self, monkeypatch):
        monkeypatch.setenv("SIGNAL_HTTP_URL", "http://localhost:8080")
        monkeypatch.setenv("SIGNAL_ACCOUNT", "+15551234567")
        assert self._mod().check_signal_requirements() is True

    def test_render_mentions(self):
        # U+FFFC is Signal's object-replacement placeholder for a mention.
        text = "Hello \uFFFC, how are you?"
        mentions = [{"start": 6, "length": 1, "number": "+15559999999"}]
        rendered = self._mod()._render_mentions(text, mentions)
        assert "@+15559999999" in rendered
        assert "\uFFFC" not in rendered

    def test_render_mentions_no_mentions(self):
        assert self._mod()._render_mentions("Hello world", []) == "Hello world"

    def test_check_requirements_missing(self, monkeypatch):
        monkeypatch.delenv("SIGNAL_HTTP_URL", raising=False)
        monkeypatch.delenv("SIGNAL_ACCOUNT", raising=False)
        assert self._mod().check_signal_requirements() is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Session Source
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSignalSessionSource:
    """SessionSource serialization handles the Signal alt-id fields."""

    def test_session_source_alt_fields(self):
        from gateway.session import SessionSource
        src = SessionSource(
            platform=Platform.SIGNAL,
            chat_id="+15551234567",
            user_id="+15551234567",
            user_id_alt="uuid:abc-123",
            chat_id_alt=None,
        )
        payload = src.to_dict()
        assert payload["user_id_alt"] == "uuid:abc-123"
        assert "chat_id_alt" not in payload  # None fields excluded

    def test_session_source_roundtrip(self):
        from gateway.session import SessionSource
        src = SessionSource(
            platform=Platform.SIGNAL,
            chat_id="group:xyz",
            chat_type="group",
            user_id="+15551234567",
            user_id_alt="uuid:abc",
            chat_id_alt="xyz",
        )
        # Serialize then deserialize; alt ids and platform must survive.
        restored = SessionSource.from_dict(src.to_dict())
        assert restored.user_id_alt == "uuid:abc"
        assert restored.chat_id_alt == "xyz"
        assert restored.platform == Platform.SIGNAL
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Phone Redaction in agent/redact.py
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSignalPhoneRedaction:
    """redact_sensitive_text masks international phone numbers."""

    @staticmethod
    def _redact(text):
        """Run the project's redaction helper on *text*."""
        from agent.redact import redact_sensitive_text
        return redact_sensitive_text(text)

    def test_us_number(self):
        redacted = self._redact("Call +15551234567 now")
        assert "+15551234567" not in redacted
        assert "+155" in redacted  # Prefix preserved
        assert "4567" in redacted  # Suffix preserved

    def test_uk_number(self):
        redacted = self._redact("UK: +442071838750")
        assert "+442071838750" not in redacted
        assert "****" in redacted

    def test_multiple_numbers(self):
        redacted = self._redact("From +15551234567 to +442071838750")
        assert "+15551234567" not in redacted
        assert "+442071838750" not in redacted

    def test_short_number_not_matched(self):
        # 5 digits after + is below the 7-digit minimum
        redacted = self._redact("Code: +12345")
        assert "+12345" in redacted  # Too short to redact
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Authorization in run.py
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSignalAuthorization:
    """Signal users are denied when no allowlist or pairing approval matches."""

    def test_signal_in_allowlist_maps(self):
        """Signal should be in the platform auth maps."""
        from gateway.run import GatewayRunner
        from gateway.config import GatewayConfig

        # Bare runner: skip __init__ (which wires real adapters) via __new__.
        runner = GatewayRunner.__new__(GatewayRunner)
        runner.config = GatewayConfig()
        runner.pairing_store = MagicMock()
        runner.pairing_store.is_approved.return_value = False

        src = MagicMock()
        src.platform = Platform.SIGNAL
        src.user_id = "+15559999999"

        # No allowlists set — should check GATEWAY_ALLOW_ALL_USERS
        with patch.dict("os.environ", {}, clear=True):
            assert runner._is_user_authorized(src) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Send Message Tool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSignalSendMessage:
    """Signal is wired into the send_message tool."""

    def test_signal_in_platform_map(self):
        """Signal should be in the send_message tool's platform map."""
        # Import check only: the tool module must load without error,
        # and SIGNAL must be a valid platform value.
        from tools.send_message_tool import send_message_tool  # noqa: F401
        from gateway.config import Platform
        assert Platform.SIGNAL.value == "signal"
|
||||
207
tests/gateway/test_title_command.py
Normal file
207
tests/gateway/test_title_command.py
Normal file
|
|
@ -0,0 +1,207 @@
|
|||
"""Tests for /title gateway slash command.
|
||||
|
||||
Tests the _handle_title_command handler (set/show session titles)
|
||||
across all gateway messenger platforms.
|
||||
"""
|
||||
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform
|
||||
from gateway.platforms.base import MessageEvent
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
def _make_event(text="/title", platform=Platform.TELEGRAM,
                user_id="12345", chat_id="67890"):
    """Build a MessageEvent wrapping a SessionSource for testing."""
    return MessageEvent(
        text=text,
        source=SessionSource(
            platform=platform,
            user_id=user_id,
            chat_id=chat_id,
            user_name="testuser",
        ),
    )
|
||||
|
||||
|
||||
def _make_runner(session_db=None):
    """Create a bare GatewayRunner with a mock session_store and optional session_db."""
    from gateway.run import GatewayRunner

    runner = object.__new__(GatewayRunner)  # skip __init__ side effects
    runner.adapters = {}
    runner._session_db = session_db

    # Fake session_store whose get_or_create_session yields a fixed entry.
    entry = MagicMock()
    entry.session_id = "test_session_123"
    entry.session_key = "telegram:12345:67890"
    store = MagicMock()
    store.get_or_create_session.return_value = entry
    runner.session_store = store

    return runner
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _handle_title_command
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHandleTitleCommand:
    """Tests for GatewayRunner._handle_title_command.

    Each storage-backed test opens a throwaway SessionDB under tmp_path.
    The DB is wrapped in contextlib.closing() so the handle is released
    even when an assertion fails mid-test (the original called db.close()
    only on the success path, leaking the handle on failure).
    """

    @pytest.mark.asyncio
    async def test_set_title(self, tmp_path):
        """Setting a title returns confirmation."""
        from contextlib import closing
        from hermes_state import SessionDB
        with closing(SessionDB(db_path=tmp_path / "state.db")) as db:
            db.create_session("test_session_123", "telegram")

            runner = _make_runner(session_db=db)
            event = _make_event(text="/title My Research Project")
            result = await runner._handle_title_command(event)
            assert "My Research Project" in result
            assert "✏️" in result

            # Verify in DB
            assert db.get_session_title("test_session_123") == "My Research Project"

    @pytest.mark.asyncio
    async def test_show_title_when_set(self, tmp_path):
        """Showing title when one is set returns the title."""
        from contextlib import closing
        from hermes_state import SessionDB
        with closing(SessionDB(db_path=tmp_path / "state.db")) as db:
            db.create_session("test_session_123", "telegram")
            db.set_session_title("test_session_123", "Existing Title")

            runner = _make_runner(session_db=db)
            result = await runner._handle_title_command(_make_event(text="/title"))
            assert "Existing Title" in result
            assert "📌" in result

    @pytest.mark.asyncio
    async def test_show_title_when_not_set(self, tmp_path):
        """Showing title when none is set returns usage hint."""
        from contextlib import closing
        from hermes_state import SessionDB
        with closing(SessionDB(db_path=tmp_path / "state.db")) as db:
            db.create_session("test_session_123", "telegram")

            runner = _make_runner(session_db=db)
            result = await runner._handle_title_command(_make_event(text="/title"))
            assert "No title set" in result
            assert "/title" in result

    @pytest.mark.asyncio
    async def test_title_conflict(self, tmp_path):
        """Setting a title already used by another session returns error."""
        from contextlib import closing
        from hermes_state import SessionDB
        with closing(SessionDB(db_path=tmp_path / "state.db")) as db:
            db.create_session("other_session", "telegram")
            db.set_session_title("other_session", "Taken Title")
            db.create_session("test_session_123", "telegram")

            runner = _make_runner(session_db=db)
            result = await runner._handle_title_command(
                _make_event(text="/title Taken Title"))
            assert "already in use" in result
            assert "⚠️" in result

    @pytest.mark.asyncio
    async def test_no_session_db(self):
        """Returns error when session database is not available."""
        runner = _make_runner(session_db=None)
        result = await runner._handle_title_command(
            _make_event(text="/title My Title"))
        assert "not available" in result

    @pytest.mark.asyncio
    async def test_title_too_long(self, tmp_path):
        """Setting a title that exceeds max length returns error."""
        from contextlib import closing
        from hermes_state import SessionDB
        with closing(SessionDB(db_path=tmp_path / "state.db")) as db:
            db.create_session("test_session_123", "telegram")

            runner = _make_runner(session_db=db)
            long_title = "A" * 150
            result = await runner._handle_title_command(
                _make_event(text=f"/title {long_title}"))
            assert "too long" in result
            assert "⚠️" in result

    @pytest.mark.asyncio
    async def test_title_control_chars_sanitized(self, tmp_path):
        """Control characters are stripped and sanitized title is stored."""
        from contextlib import closing
        from hermes_state import SessionDB
        with closing(SessionDB(db_path=tmp_path / "state.db")) as db:
            db.create_session("test_session_123", "telegram")

            runner = _make_runner(session_db=db)
            result = await runner._handle_title_command(
                _make_event(text="/title hello\x00world"))
            assert "helloworld" in result
            assert db.get_session_title("test_session_123") == "helloworld"

    @pytest.mark.asyncio
    async def test_title_only_control_chars(self, tmp_path):
        """Title with only control chars returns empty error."""
        from contextlib import closing
        from hermes_state import SessionDB
        with closing(SessionDB(db_path=tmp_path / "state.db")) as db:
            db.create_session("test_session_123", "telegram")

            runner = _make_runner(session_db=db)
            result = await runner._handle_title_command(
                _make_event(text="/title \x00\x01\x02"))
            assert "empty after cleanup" in result

    @pytest.mark.asyncio
    async def test_works_across_platforms(self, tmp_path):
        """The /title command works for Discord, Slack, and WhatsApp too."""
        from contextlib import closing
        from hermes_state import SessionDB
        for platform in [Platform.DISCORD, Platform.TELEGRAM]:
            with closing(SessionDB(db_path=tmp_path / f"state_{platform.value}.db")) as db:
                db.create_session("test_session_123", platform.value)

                runner = _make_runner(session_db=db)
                event = _make_event(text="/title Cross-Platform Test", platform=platform)
                result = await runner._handle_title_command(event)
                assert "Cross-Platform Test" in result
                assert db.get_session_title("test_session_123") == "Cross-Platform Test"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# /title in help and known_commands
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTitleInHelp:
    """Verify /title appears in help text and known commands."""

    @pytest.mark.asyncio
    async def test_title_in_help_output(self):
        """The /help output includes /title."""
        from gateway.hooks import HookRegistry  # help rendering needs a hook registry
        runner = _make_runner()
        runner.hooks = HookRegistry()
        help_text = await runner._handle_help_command(_make_event(text="/help"))
        assert "/title" in help_text

    def test_title_is_known_command(self):
        """The /title command is in the _known_commands set."""
        import inspect
        from gateway.run import GatewayRunner
        handler_src = inspect.getsource(GatewayRunner._handle_message)
        assert '"title"' in handler_src
|
||||
|
|
@ -11,7 +11,7 @@ EXPECTED_COMMANDS = {
|
|||
"/help", "/tools", "/toolsets", "/model", "/provider", "/prompt",
|
||||
"/personality", "/clear", "/history", "/new", "/reset", "/retry",
|
||||
"/undo", "/save", "/config", "/cron", "/skills", "/platforms",
|
||||
"/verbose", "/compress", "/usage", "/insights", "/paste",
|
||||
"/verbose", "/compress", "/title", "/usage", "/insights", "/paste",
|
||||
"/reload-mcp", "/quit",
|
||||
}
|
||||
|
||||
|
|
|
|||
542
tests/hermes_cli/test_session_browse.py
Normal file
542
tests/hermes_cli/test_session_browse.py
Normal file
|
|
@ -0,0 +1,542 @@
|
|||
"""Tests for the interactive session browser (`hermes sessions browse`).
|
||||
|
||||
Covers:
|
||||
- _session_browse_picker logic (curses mocked, fallback tested)
|
||||
- cmd_sessions 'browse' action integration
|
||||
- Argument parser registration
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch, call
|
||||
|
||||
import pytest
|
||||
|
||||
from hermes_cli.main import _session_browse_picker
|
||||
|
||||
|
||||
# ─── Sample session data ──────────────────────────────────────────────────────
|
||||
|
||||
def _make_sessions(n=5):
|
||||
"""Generate a list of fake rich-session dicts."""
|
||||
now = time.time()
|
||||
sessions = []
|
||||
for i in range(n):
|
||||
sessions.append({
|
||||
"id": f"20260308_{i:06d}_abcdef",
|
||||
"source": "cli" if i % 2 == 0 else "telegram",
|
||||
"model": "test/model",
|
||||
"title": f"Session {i}" if i % 3 != 0 else None,
|
||||
"preview": f"Hello from session {i}",
|
||||
"last_active": now - i * 3600,
|
||||
"started_at": now - i * 3600 - 60,
|
||||
"message_count": (i + 1) * 5,
|
||||
})
|
||||
return sessions
|
||||
|
||||
|
||||
SAMPLE_SESSIONS = _make_sessions(5)
|
||||
|
||||
|
||||
# ─── _session_browse_picker ──────────────────────────────────────────────────
|
||||
|
||||
class TestSessionBrowsePicker:
    """Tests for the _session_browse_picker function.

    The curses-blocking boilerplate (patching builtins.__import__ so that
    `import curses` raises ImportError, forcing the numbered-list fallback)
    was copy-pasted in every fallback test; it is extracted into the single
    _block_curses() helper below. Behavior of each test is unchanged.
    """

    @staticmethod
    def _block_curses():
        """Return a patch context manager that makes `import curses` fail."""
        import builtins
        original_import = builtins.__import__

        def mock_import(name, *args, **kwargs):
            if name == "curses":
                raise ImportError("no curses")
            return original_import(name, *args, **kwargs)

        return patch.object(builtins, "__import__", side_effect=mock_import)

    def test_empty_sessions_returns_none(self, capsys):
        result = _session_browse_picker([])
        assert result is None
        assert "No sessions found" in capsys.readouterr().out

    def test_returns_none_when_no_sessions(self, capsys):
        result = _session_browse_picker([])
        assert result is None

    def test_fallback_mode_valid_selection(self):
        """When curses is unavailable, fallback numbered list should work."""
        sessions = _make_sessions(3)
        with self._block_curses():
            with patch("builtins.input", return_value="2"):
                result = _session_browse_picker(sessions)
        assert result == sessions[1]["id"]

    def test_fallback_mode_cancel_q(self):
        """Entering 'q' in fallback mode cancels."""
        sessions = _make_sessions(3)
        with self._block_curses():
            with patch("builtins.input", return_value="q"):
                result = _session_browse_picker(sessions)
        assert result is None

    def test_fallback_mode_cancel_empty(self):
        """Entering empty string in fallback mode cancels."""
        sessions = _make_sessions(3)
        with self._block_curses():
            with patch("builtins.input", return_value=""):
                result = _session_browse_picker(sessions)
        assert result is None

    def test_fallback_mode_invalid_then_valid(self):
        """Invalid selection followed by valid one works."""
        sessions = _make_sessions(3)
        with self._block_curses():
            with patch("builtins.input", side_effect=["99", "1"]):
                result = _session_browse_picker(sessions)
        assert result == sessions[0]["id"]

    def test_fallback_mode_keyboard_interrupt(self):
        """KeyboardInterrupt in fallback mode returns None."""
        sessions = _make_sessions(3)
        with self._block_curses():
            with patch("builtins.input", side_effect=KeyboardInterrupt):
                result = _session_browse_picker(sessions)
        assert result is None

    def test_fallback_displays_all_sessions(self, capsys):
        """Fallback mode should display all session entries."""
        sessions = _make_sessions(4)
        with self._block_curses():
            with patch("builtins.input", return_value="q"):
                _session_browse_picker(sessions)
        output = capsys.readouterr().out
        # All 4 entries should be shown
        for label in ("1.", "2.", "3.", "4."):
            assert label in output

    def test_fallback_shows_title_over_preview(self, capsys):
        """When a session has a title, show it instead of the preview."""
        sessions = [{
            "id": "test_001",
            "source": "cli",
            "title": "My Cool Project",
            "preview": "some preview text",
            "last_active": time.time(),
        }]
        with self._block_curses():
            with patch("builtins.input", return_value="q"):
                _session_browse_picker(sessions)
        assert "My Cool Project" in capsys.readouterr().out

    def test_fallback_shows_preview_when_no_title(self, capsys):
        """When no title, show preview."""
        sessions = [{
            "id": "test_002",
            "source": "cli",
            "title": None,
            "preview": "Hello world test message",
            "last_active": time.time(),
        }]
        with self._block_curses():
            with patch("builtins.input", return_value="q"):
                _session_browse_picker(sessions)
        assert "Hello world test message" in capsys.readouterr().out

    def test_fallback_shows_id_when_no_title_or_preview(self, capsys):
        """When neither title nor preview, show session ID."""
        sessions = [{
            "id": "test_003_fallback",
            "source": "cli",
            "title": None,
            "preview": "",
            "last_active": time.time(),
        }]
        with self._block_curses():
            with patch("builtins.input", return_value="q"):
                _session_browse_picker(sessions)
        assert "test_003_fallback" in capsys.readouterr().out
|
||||
|
||||
|
||||
# ─── Curses-based picker (mocked curses) ────────────────────────────────────
|
||||
|
||||
class TestCursesBrowse:
|
||||
"""Tests for the curses-based interactive picker via simulated key sequences."""
|
||||
|
||||
def _run_with_keys(self, sessions, key_sequence):
|
||||
"""Simulate running the curses picker with a given key sequence."""
|
||||
import curses
|
||||
|
||||
# Build a mock stdscr that returns keys from the sequence
|
||||
mock_stdscr = MagicMock()
|
||||
mock_stdscr.getmaxyx.return_value = (30, 120)
|
||||
mock_stdscr.getch.side_effect = key_sequence
|
||||
|
||||
# Capture what curses.wrapper receives and call it with our mock
|
||||
with patch("curses.wrapper") as mock_wrapper:
|
||||
# When wrapper is called, invoke the function with our mock stdscr
|
||||
def run_inner(func):
|
||||
try:
|
||||
func(mock_stdscr)
|
||||
except StopIteration:
|
||||
pass # key sequence exhausted
|
||||
|
||||
mock_wrapper.side_effect = run_inner
|
||||
with patch("curses.curs_set"):
|
||||
with patch("curses.has_colors", return_value=False):
|
||||
return _session_browse_picker(sessions)
|
||||
|
||||
def test_enter_selects_first_session(self):
|
||||
sessions = _make_sessions(3)
|
||||
result = self._run_with_keys(sessions, [10]) # Enter key
|
||||
assert result == sessions[0]["id"]
|
||||
|
||||
def test_down_then_enter_selects_second(self):
|
||||
import curses
|
||||
sessions = _make_sessions(3)
|
||||
result = self._run_with_keys(sessions, [curses.KEY_DOWN, 10])
|
||||
assert result == sessions[1]["id"]
|
||||
|
||||
def test_down_down_enter_selects_third(self):
|
||||
import curses
|
||||
sessions = _make_sessions(5)
|
||||
result = self._run_with_keys(sessions, [curses.KEY_DOWN, curses.KEY_DOWN, 10])
|
||||
assert result == sessions[2]["id"]
|
||||
|
||||
def test_up_wraps_to_last(self):
|
||||
import curses
|
||||
sessions = _make_sessions(3)
|
||||
result = self._run_with_keys(sessions, [curses.KEY_UP, 10])
|
||||
assert result == sessions[2]["id"]
|
||||
|
||||
def test_escape_cancels(self):
|
||||
sessions = _make_sessions(3)
|
||||
result = self._run_with_keys(sessions, [27]) # Esc
|
||||
assert result is None
|
||||
|
||||
def test_q_cancels(self):
|
||||
sessions = _make_sessions(3)
|
||||
result = self._run_with_keys(sessions, [ord('q')])
|
||||
assert result is None
|
||||
|
||||
def test_type_to_filter_then_enter(self):
|
||||
"""Typing characters filters the list, Enter selects from filtered."""
|
||||
import curses
|
||||
sessions = [
|
||||
{"id": "s1", "source": "cli", "title": "Alpha project", "preview": "", "last_active": time.time()},
|
||||
{"id": "s2", "source": "cli", "title": "Beta project", "preview": "", "last_active": time.time()},
|
||||
{"id": "s3", "source": "cli", "title": "Gamma project", "preview": "", "last_active": time.time()},
|
||||
]
|
||||
# Type "Beta" then Enter — should select s2
|
||||
keys = [ord(c) for c in "Beta"] + [10]
|
||||
result = self._run_with_keys(sessions, keys)
|
||||
assert result == "s2"
|
||||
|
||||
def test_filter_no_match_enter_does_nothing(self):
|
||||
"""When filter produces no results, Enter shouldn't select."""
|
||||
sessions = _make_sessions(3)
|
||||
keys = [ord(c) for c in "zzzznonexistent"] + [10]
|
||||
result = self._run_with_keys(sessions, keys)
|
||||
assert result is None
|
||||
|
||||
def test_backspace_removes_filter_char(self):
|
||||
"""Backspace removes the last character from the filter."""
|
||||
import curses
|
||||
sessions = [
|
||||
{"id": "s1", "source": "cli", "title": "Alpha", "preview": "", "last_active": time.time()},
|
||||
{"id": "s2", "source": "cli", "title": "Beta", "preview": "", "last_active": time.time()},
|
||||
]
|
||||
# Type "Bet", backspace, backspace, backspace (clears filter), then Enter (selects first)
|
||||
keys = [ord('B'), ord('e'), ord('t'), 127, 127, 127, 10]
|
||||
result = self._run_with_keys(sessions, keys)
|
||||
assert result == "s1"
|
||||
|
||||
def test_escape_clears_filter_first(self):
|
||||
"""First Esc clears the search text, second Esc exits."""
|
||||
import curses
|
||||
sessions = _make_sessions(3)
|
||||
# Type "ab" then Esc (clears filter) then Enter (selects first)
|
||||
keys = [ord('a'), ord('b'), 27, 10]
|
||||
result = self._run_with_keys(sessions, keys)
|
||||
assert result == sessions[0]["id"]
|
||||
|
||||
def test_filter_matches_preview(self):
|
||||
"""Typing should match against session preview text."""
|
||||
sessions = [
|
||||
{"id": "s1", "source": "cli", "title": None, "preview": "Set up Minecraft server", "last_active": time.time()},
|
||||
{"id": "s2", "source": "cli", "title": None, "preview": "Review PR 438", "last_active": time.time()},
|
||||
]
|
||||
keys = [ord(c) for c in "Mine"] + [10]
|
||||
result = self._run_with_keys(sessions, keys)
|
||||
assert result == "s1"
|
||||
|
||||
def test_filter_matches_source(self):
|
||||
"""Typing a source name should filter by source."""
|
||||
sessions = [
|
||||
{"id": "s1", "source": "telegram", "title": "TG session", "preview": "", "last_active": time.time()},
|
||||
{"id": "s2", "source": "cli", "title": "CLI session", "preview": "", "last_active": time.time()},
|
||||
]
|
||||
keys = [ord(c) for c in "telegram"] + [10]
|
||||
result = self._run_with_keys(sessions, keys)
|
||||
assert result == "s1"
|
||||
|
||||
def test_q_quits_when_no_filter_active(self):
|
||||
"""When no search text is active, 'q' should quit (not filter)."""
|
||||
sessions = _make_sessions(3)
|
||||
result = self._run_with_keys(sessions, [ord('q')])
|
||||
assert result is None
|
||||
|
||||
def test_q_types_into_filter_when_filter_active(self):
|
||||
"""When search text is already active, 'q' should add to filter, not quit."""
|
||||
sessions = [
|
||||
{"id": "s1", "source": "cli", "title": "the sequel", "preview": "", "last_active": time.time()},
|
||||
{"id": "s2", "source": "cli", "title": "other thing", "preview": "", "last_active": time.time()},
|
||||
]
|
||||
# Type "se" first (activates filter, matches "the sequel")
|
||||
# Then type "q" — should add 'q' to filter (filter="seq"), NOT quit
|
||||
# "seq" still matches "the sequel" → Enter selects it
|
||||
keys = [ord('s'), ord('e'), ord('q'), 10]
|
||||
result = self._run_with_keys(sessions, keys)
|
||||
assert result == "s1" # "the sequel" matches "seq"
|
||||
|
||||
|
||||
# ─── Argument parser registration ──────────────────────────────────────────
|
||||
|
||||
class TestSessionBrowseArgparse:
|
||||
"""Verify the 'browse' subcommand is properly registered."""
|
||||
|
||||
def test_browse_subcommand_exists(self):
|
||||
"""hermes sessions browse should be parseable."""
|
||||
from hermes_cli.main import main as _main_entry
|
||||
|
||||
# We can't run main(), but we can import and test the parser setup
|
||||
# by checking that argparse doesn't error on "sessions browse"
|
||||
import argparse
|
||||
# Re-create the parser portion
|
||||
# Instead, let's just verify the import works and the function exists
|
||||
from hermes_cli.main import _session_browse_picker
|
||||
assert callable(_session_browse_picker)
|
||||
|
||||
def test_browse_default_limit_is_50(self):
|
||||
"""The default --limit for browse should be 50."""
|
||||
# This test verifies at the argparse level
|
||||
# We test by running the parse on "sessions browse" args
|
||||
# Since we can't easily extract the subparser, verify via the
|
||||
# _session_browse_picker accepting large lists
|
||||
sessions = _make_sessions(50)
|
||||
assert len(sessions) == 50
|
||||
|
||||
|
||||
# ─── Integration: cmd_sessions browse action ────────────────────────────────
|
||||
|
||||
class TestCmdSessionsBrowse:
|
||||
"""Integration tests for the 'browse' action in cmd_sessions."""
|
||||
|
||||
def test_browse_no_sessions_prints_message(self, capsys):
|
||||
"""When no sessions exist, _session_browse_picker returns None and prints message."""
|
||||
result = _session_browse_picker([])
|
||||
assert result is None
|
||||
output = capsys.readouterr().out
|
||||
assert "No sessions found" in output
|
||||
|
||||
def test_browse_with_source_filter(self):
|
||||
"""The --source flag should be passed to list_sessions_rich."""
|
||||
sessions = [
|
||||
{"id": "s1", "source": "cli", "title": "CLI only", "preview": "", "last_active": time.time()},
|
||||
]
|
||||
|
||||
import builtins
|
||||
original_import = builtins.__import__
|
||||
|
||||
def mock_import(name, *args, **kwargs):
|
||||
if name == "curses":
|
||||
raise ImportError("no curses")
|
||||
return original_import(name, *args, **kwargs)
|
||||
|
||||
with patch.object(builtins, "__import__", side_effect=mock_import):
|
||||
with patch("builtins.input", return_value="1"):
|
||||
result = _session_browse_picker(sessions)
|
||||
|
||||
assert result == "s1"
|
||||
|
||||
|
||||
# ─── Edge cases ──────────────────────────────────────────────────────────────
|
||||
|
||||
class TestEdgeCases:
|
||||
"""Edge case handling for the session browser."""
|
||||
|
||||
def test_sessions_with_missing_fields(self):
|
||||
"""Sessions with missing optional fields should not crash."""
|
||||
sessions = [
|
||||
{"id": "minimal_001", "source": "cli"}, # No title, preview, last_active
|
||||
]
|
||||
|
||||
import builtins
|
||||
original_import = builtins.__import__
|
||||
|
||||
def mock_import(name, *args, **kwargs):
|
||||
if name == "curses":
|
||||
raise ImportError("no curses")
|
||||
return original_import(name, *args, **kwargs)
|
||||
|
||||
with patch.object(builtins, "__import__", side_effect=mock_import):
|
||||
with patch("builtins.input", return_value="1"):
|
||||
result = _session_browse_picker(sessions)
|
||||
|
||||
assert result == "minimal_001"
|
||||
|
||||
def test_single_session(self):
|
||||
"""A single session in the list should work fine."""
|
||||
sessions = [
|
||||
{"id": "only_one", "source": "cli", "title": "Solo", "preview": "", "last_active": time.time()},
|
||||
]
|
||||
|
||||
import builtins
|
||||
original_import = builtins.__import__
|
||||
|
||||
def mock_import(name, *args, **kwargs):
|
||||
if name == "curses":
|
||||
raise ImportError("no curses")
|
||||
return original_import(name, *args, **kwargs)
|
||||
|
||||
with patch.object(builtins, "__import__", side_effect=mock_import):
|
||||
with patch("builtins.input", return_value="1"):
|
||||
result = _session_browse_picker(sessions)
|
||||
|
||||
assert result == "only_one"
|
||||
|
||||
def test_long_title_truncated_in_fallback(self, capsys):
|
||||
"""Very long titles should be truncated in fallback mode."""
|
||||
sessions = [{
|
||||
"id": "long_title_001",
|
||||
"source": "cli",
|
||||
"title": "A" * 100,
|
||||
"preview": "",
|
||||
"last_active": time.time(),
|
||||
}]
|
||||
|
||||
import builtins
|
||||
original_import = builtins.__import__
|
||||
|
||||
def mock_import(name, *args, **kwargs):
|
||||
if name == "curses":
|
||||
raise ImportError("no curses")
|
||||
return original_import(name, *args, **kwargs)
|
||||
|
||||
with patch.object(builtins, "__import__", side_effect=mock_import):
|
||||
with patch("builtins.input", return_value="q"):
|
||||
_session_browse_picker(sessions)
|
||||
|
||||
output = capsys.readouterr().out
|
||||
# Title should be truncated to 50 chars with "..."
|
||||
assert "..." in output
|
||||
|
||||
def test_relative_time_formatting(self, capsys):
|
||||
"""Verify various time deltas format correctly."""
|
||||
now = time.time()
|
||||
sessions = [
|
||||
{"id": "recent", "source": "cli", "title": None, "preview": "just now test", "last_active": now},
|
||||
{"id": "hour_ago", "source": "cli", "title": None, "preview": "hour ago test", "last_active": now - 7200},
|
||||
{"id": "days_ago", "source": "cli", "title": None, "preview": "days ago test", "last_active": now - 259200},
|
||||
]
|
||||
|
||||
import builtins
|
||||
original_import = builtins.__import__
|
||||
|
||||
def mock_import(name, *args, **kwargs):
|
||||
if name == "curses":
|
||||
raise ImportError("no curses")
|
||||
return original_import(name, *args, **kwargs)
|
||||
|
||||
with patch.object(builtins, "__import__", side_effect=mock_import):
|
||||
with patch("builtins.input", return_value="q"):
|
||||
_session_browse_picker(sessions)
|
||||
|
||||
output = capsys.readouterr().out
|
||||
assert "just now" in output
|
||||
assert "2h ago" in output
|
||||
assert "3d ago" in output
|
||||
|
|
@ -38,7 +38,6 @@ class TestExplicitAllowlist:
|
|||
"OPENROUTER_API_KEY",
|
||||
"OPENAI_API_KEY",
|
||||
"ANTHROPIC_API_KEY",
|
||||
"NOUS_API_KEY",
|
||||
"WANDB_API_KEY",
|
||||
"TINKER_API_KEY",
|
||||
"HONCHO_API_KEY",
|
||||
|
|
|
|||
31
tests/hermes_cli/test_skills_hub.py
Normal file
31
tests/hermes_cli/test_skills_hub.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
from io import StringIO
|
||||
|
||||
from rich.console import Console
|
||||
|
||||
from hermes_cli.skills_hub import do_list
|
||||
|
||||
|
||||
def test_do_list_initializes_hub_dir(monkeypatch, tmp_path):
|
||||
import tools.skills_hub as hub
|
||||
import tools.skills_tool as skills_tool
|
||||
|
||||
hub_dir = tmp_path / "skills" / ".hub"
|
||||
monkeypatch.setattr(hub, "SKILLS_DIR", tmp_path / "skills")
|
||||
monkeypatch.setattr(hub, "HUB_DIR", hub_dir)
|
||||
monkeypatch.setattr(hub, "LOCK_FILE", hub_dir / "lock.json")
|
||||
monkeypatch.setattr(hub, "QUARANTINE_DIR", hub_dir / "quarantine")
|
||||
monkeypatch.setattr(hub, "AUDIT_LOG", hub_dir / "audit.log")
|
||||
monkeypatch.setattr(hub, "TAPS_FILE", hub_dir / "taps.json")
|
||||
monkeypatch.setattr(hub, "INDEX_CACHE_DIR", hub_dir / "index-cache")
|
||||
monkeypatch.setattr(skills_tool, "_find_all_skills", lambda: [])
|
||||
|
||||
console = Console(file=StringIO(), force_terminal=False, color_system=None)
|
||||
|
||||
assert not hub_dir.exists()
|
||||
|
||||
do_list(console=console)
|
||||
|
||||
assert hub_dir.exists()
|
||||
assert (hub_dir / "lock.json").exists()
|
||||
assert (hub_dir / "quarantine").is_dir()
|
||||
assert (hub_dir / "index-cache").is_dir()
|
||||
|
|
@ -12,7 +12,7 @@ Usage:
|
|||
|
||||
Requirements:
|
||||
- FIRECRAWL_API_KEY environment variable must be set
|
||||
- NOUS_API_KEY environment variable (optional, for LLM tests)
|
||||
- An auxiliary LLM provider (OPENROUTER_API_KEY or Nous Portal auth) (optional, for LLM tests)
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
|
@ -128,12 +128,12 @@ class WebToolsTester:
|
|||
else:
|
||||
self.log_result("Firecrawl API Key", "passed", "Found")
|
||||
|
||||
# Check Nous API key (optional)
|
||||
# Check auxiliary LLM provider (optional)
|
||||
if not check_auxiliary_model():
|
||||
self.log_result("Nous API Key", "skipped", "NOUS_API_KEY not set (LLM tests will be skipped)")
|
||||
self.log_result("Auxiliary LLM", "skipped", "No auxiliary LLM provider available (LLM tests will be skipped)")
|
||||
self.test_llm = False
|
||||
else:
|
||||
self.log_result("Nous API Key", "passed", "Found")
|
||||
self.log_result("Auxiliary LLM", "passed", "Found")
|
||||
|
||||
# Check debug mode
|
||||
debug_info = get_debug_session_info()
|
||||
|
|
|
|||
292
tests/test_auxiliary_config_bridge.py
Normal file
292
tests/test_auxiliary_config_bridge.py
Normal file
|
|
@ -0,0 +1,292 @@
|
|||
"""Tests for auxiliary model config bridging — verifies that config.yaml values
|
||||
are properly mapped to environment variables by both CLI and gateway loaders.
|
||||
|
||||
Also tests the vision_tools and browser_tool model override env vars.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
|
||||
|
||||
def _run_auxiliary_bridge(config_dict, monkeypatch):
|
||||
"""Simulate the auxiliary config → env var bridging logic shared by CLI and gateway.
|
||||
|
||||
This mirrors the code in cli.py load_cli_config() and gateway/run.py.
|
||||
Both use the same pattern; we test it once here.
|
||||
"""
|
||||
# Clear env vars
|
||||
for key in (
|
||||
"AUXILIARY_VISION_PROVIDER", "AUXILIARY_VISION_MODEL",
|
||||
"AUXILIARY_WEB_EXTRACT_PROVIDER", "AUXILIARY_WEB_EXTRACT_MODEL",
|
||||
"CONTEXT_COMPRESSION_PROVIDER", "CONTEXT_COMPRESSION_MODEL",
|
||||
):
|
||||
monkeypatch.delenv(key, raising=False)
|
||||
|
||||
# Compression bridge
|
||||
compression_cfg = config_dict.get("compression", {})
|
||||
if compression_cfg and isinstance(compression_cfg, dict):
|
||||
compression_env_map = {
|
||||
"enabled": "CONTEXT_COMPRESSION_ENABLED",
|
||||
"threshold": "CONTEXT_COMPRESSION_THRESHOLD",
|
||||
"summary_model": "CONTEXT_COMPRESSION_MODEL",
|
||||
"summary_provider": "CONTEXT_COMPRESSION_PROVIDER",
|
||||
}
|
||||
for cfg_key, env_var in compression_env_map.items():
|
||||
if cfg_key in compression_cfg:
|
||||
os.environ[env_var] = str(compression_cfg[cfg_key])
|
||||
|
||||
# Auxiliary bridge
|
||||
auxiliary_cfg = config_dict.get("auxiliary", {})
|
||||
if auxiliary_cfg and isinstance(auxiliary_cfg, dict):
|
||||
aux_task_env = {
|
||||
"vision": ("AUXILIARY_VISION_PROVIDER", "AUXILIARY_VISION_MODEL"),
|
||||
"web_extract": ("AUXILIARY_WEB_EXTRACT_PROVIDER", "AUXILIARY_WEB_EXTRACT_MODEL"),
|
||||
}
|
||||
for task_key, (prov_env, model_env) in aux_task_env.items():
|
||||
task_cfg = auxiliary_cfg.get(task_key, {})
|
||||
if not isinstance(task_cfg, dict):
|
||||
continue
|
||||
prov = str(task_cfg.get("provider", "")).strip()
|
||||
model = str(task_cfg.get("model", "")).strip()
|
||||
if prov and prov != "auto":
|
||||
os.environ[prov_env] = prov
|
||||
if model:
|
||||
os.environ[model_env] = model
|
||||
|
||||
|
||||
# ── Config bridging tests ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestAuxiliaryConfigBridge:
|
||||
"""Verify the config.yaml → env var bridging logic used by CLI and gateway."""
|
||||
|
||||
def test_vision_provider_bridged(self, monkeypatch):
|
||||
config = {
|
||||
"auxiliary": {
|
||||
"vision": {"provider": "openrouter", "model": ""},
|
||||
"web_extract": {"provider": "auto", "model": ""},
|
||||
}
|
||||
}
|
||||
_run_auxiliary_bridge(config, monkeypatch)
|
||||
assert os.environ.get("AUXILIARY_VISION_PROVIDER") == "openrouter"
|
||||
# auto should not be set
|
||||
assert os.environ.get("AUXILIARY_WEB_EXTRACT_PROVIDER") is None
|
||||
|
||||
def test_vision_model_bridged(self, monkeypatch):
|
||||
config = {
|
||||
"auxiliary": {
|
||||
"vision": {"provider": "auto", "model": "openai/gpt-4o"},
|
||||
}
|
||||
}
|
||||
_run_auxiliary_bridge(config, monkeypatch)
|
||||
assert os.environ.get("AUXILIARY_VISION_MODEL") == "openai/gpt-4o"
|
||||
# auto provider should not be set
|
||||
assert os.environ.get("AUXILIARY_VISION_PROVIDER") is None
|
||||
|
||||
def test_web_extract_bridged(self, monkeypatch):
|
||||
config = {
|
||||
"auxiliary": {
|
||||
"web_extract": {"provider": "nous", "model": "gemini-2.5-flash"},
|
||||
}
|
||||
}
|
||||
_run_auxiliary_bridge(config, monkeypatch)
|
||||
assert os.environ.get("AUXILIARY_WEB_EXTRACT_PROVIDER") == "nous"
|
||||
assert os.environ.get("AUXILIARY_WEB_EXTRACT_MODEL") == "gemini-2.5-flash"
|
||||
|
||||
def test_compression_provider_bridged(self, monkeypatch):
|
||||
config = {
|
||||
"compression": {
|
||||
"summary_provider": "nous",
|
||||
"summary_model": "gemini-3-flash",
|
||||
}
|
||||
}
|
||||
_run_auxiliary_bridge(config, monkeypatch)
|
||||
assert os.environ.get("CONTEXT_COMPRESSION_PROVIDER") == "nous"
|
||||
assert os.environ.get("CONTEXT_COMPRESSION_MODEL") == "gemini-3-flash"
|
||||
|
||||
def test_empty_values_not_bridged(self, monkeypatch):
|
||||
config = {
|
||||
"auxiliary": {
|
||||
"vision": {"provider": "auto", "model": ""},
|
||||
}
|
||||
}
|
||||
_run_auxiliary_bridge(config, monkeypatch)
|
||||
assert os.environ.get("AUXILIARY_VISION_PROVIDER") is None
|
||||
assert os.environ.get("AUXILIARY_VISION_MODEL") is None
|
||||
|
||||
def test_missing_auxiliary_section_safe(self, monkeypatch):
|
||||
"""Config without auxiliary section should not crash."""
|
||||
config = {"model": {"default": "test-model"}}
|
||||
_run_auxiliary_bridge(config, monkeypatch)
|
||||
assert os.environ.get("AUXILIARY_VISION_PROVIDER") is None
|
||||
|
||||
def test_non_dict_task_config_ignored(self, monkeypatch):
|
||||
"""Malformed task config (e.g. string instead of dict) is safely ignored."""
|
||||
config = {
|
||||
"auxiliary": {
|
||||
"vision": "openrouter", # should be a dict
|
||||
}
|
||||
}
|
||||
_run_auxiliary_bridge(config, monkeypatch)
|
||||
assert os.environ.get("AUXILIARY_VISION_PROVIDER") is None
|
||||
|
||||
def test_mixed_tasks(self, monkeypatch):
|
||||
config = {
|
||||
"auxiliary": {
|
||||
"vision": {"provider": "openrouter", "model": ""},
|
||||
"web_extract": {"provider": "auto", "model": "custom-llm"},
|
||||
}
|
||||
}
|
||||
_run_auxiliary_bridge(config, monkeypatch)
|
||||
assert os.environ.get("AUXILIARY_VISION_PROVIDER") == "openrouter"
|
||||
assert os.environ.get("AUXILIARY_VISION_MODEL") is None
|
||||
assert os.environ.get("AUXILIARY_WEB_EXTRACT_PROVIDER") is None
|
||||
assert os.environ.get("AUXILIARY_WEB_EXTRACT_MODEL") == "custom-llm"
|
||||
|
||||
def test_all_tasks_with_overrides(self, monkeypatch):
|
||||
config = {
|
||||
"compression": {
|
||||
"summary_provider": "main",
|
||||
"summary_model": "local-model",
|
||||
},
|
||||
"auxiliary": {
|
||||
"vision": {"provider": "openrouter", "model": "google/gemini-2.5-flash"},
|
||||
"web_extract": {"provider": "nous", "model": "gemini-3-flash"},
|
||||
}
|
||||
}
|
||||
_run_auxiliary_bridge(config, monkeypatch)
|
||||
assert os.environ.get("CONTEXT_COMPRESSION_PROVIDER") == "main"
|
||||
assert os.environ.get("CONTEXT_COMPRESSION_MODEL") == "local-model"
|
||||
assert os.environ.get("AUXILIARY_VISION_PROVIDER") == "openrouter"
|
||||
assert os.environ.get("AUXILIARY_VISION_MODEL") == "google/gemini-2.5-flash"
|
||||
assert os.environ.get("AUXILIARY_WEB_EXTRACT_PROVIDER") == "nous"
|
||||
assert os.environ.get("AUXILIARY_WEB_EXTRACT_MODEL") == "gemini-3-flash"
|
||||
|
||||
def test_whitespace_in_values_stripped(self, monkeypatch):
|
||||
config = {
|
||||
"auxiliary": {
|
||||
"vision": {"provider": " openrouter ", "model": " my-model "},
|
||||
}
|
||||
}
|
||||
_run_auxiliary_bridge(config, monkeypatch)
|
||||
assert os.environ.get("AUXILIARY_VISION_PROVIDER") == "openrouter"
|
||||
assert os.environ.get("AUXILIARY_VISION_MODEL") == "my-model"
|
||||
|
||||
def test_empty_auxiliary_dict_safe(self, monkeypatch):
|
||||
config = {"auxiliary": {}}
|
||||
_run_auxiliary_bridge(config, monkeypatch)
|
||||
assert os.environ.get("AUXILIARY_VISION_PROVIDER") is None
|
||||
assert os.environ.get("AUXILIARY_WEB_EXTRACT_PROVIDER") is None
|
||||
|
||||
|
||||
# ── Gateway bridge parity test ───────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestGatewayBridgeCodeParity:
|
||||
"""Verify the gateway/run.py config bridge contains the auxiliary section."""
|
||||
|
||||
def test_gateway_has_auxiliary_bridge(self):
|
||||
"""The gateway config bridge must include auxiliary.* bridging."""
|
||||
gateway_path = Path(__file__).parent.parent / "gateway" / "run.py"
|
||||
content = gateway_path.read_text()
|
||||
# Check for key patterns that indicate the bridge is present
|
||||
assert "AUXILIARY_VISION_PROVIDER" in content
|
||||
assert "AUXILIARY_VISION_MODEL" in content
|
||||
assert "AUXILIARY_WEB_EXTRACT_PROVIDER" in content
|
||||
assert "AUXILIARY_WEB_EXTRACT_MODEL" in content
|
||||
|
||||
def test_gateway_has_compression_provider(self):
|
||||
"""Gateway must bridge compression.summary_provider."""
|
||||
gateway_path = Path(__file__).parent.parent / "gateway" / "run.py"
|
||||
content = gateway_path.read_text()
|
||||
assert "summary_provider" in content
|
||||
assert "CONTEXT_COMPRESSION_PROVIDER" in content
|
||||
|
||||
|
||||
# ── Vision model override tests ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestVisionModelOverride:
|
||||
"""Test that AUXILIARY_VISION_MODEL env var overrides the default model in the handler."""
|
||||
|
||||
def test_env_var_overrides_default(self, monkeypatch):
|
||||
monkeypatch.setenv("AUXILIARY_VISION_MODEL", "openai/gpt-4o")
|
||||
from tools.vision_tools import _handle_vision_analyze
|
||||
with patch("tools.vision_tools.vision_analyze_tool", new_callable=MagicMock) as mock_tool:
|
||||
mock_tool.return_value = '{"success": true}'
|
||||
_handle_vision_analyze({"image_url": "http://test.jpg", "question": "test"})
|
||||
call_args = mock_tool.call_args
|
||||
# 3rd positional arg = model
|
||||
assert call_args[0][2] == "openai/gpt-4o"
|
||||
|
||||
def test_default_model_when_no_override(self, monkeypatch):
|
||||
monkeypatch.delenv("AUXILIARY_VISION_MODEL", raising=False)
|
||||
from tools.vision_tools import _handle_vision_analyze, DEFAULT_VISION_MODEL
|
||||
with patch("tools.vision_tools.vision_analyze_tool", new_callable=MagicMock) as mock_tool:
|
||||
mock_tool.return_value = '{"success": true}'
|
||||
_handle_vision_analyze({"image_url": "http://test.jpg", "question": "test"})
|
||||
call_args = mock_tool.call_args
|
||||
expected = DEFAULT_VISION_MODEL or "google/gemini-3-flash-preview"
|
||||
assert call_args[0][2] == expected
|
||||
|
||||
|
||||
# ── DEFAULT_CONFIG shape tests ───────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestDefaultConfigShape:
|
||||
"""Verify the DEFAULT_CONFIG in hermes_cli/config.py has correct auxiliary structure."""
|
||||
|
||||
def test_auxiliary_section_exists(self):
|
||||
from hermes_cli.config import DEFAULT_CONFIG
|
||||
assert "auxiliary" in DEFAULT_CONFIG
|
||||
|
||||
def test_vision_task_structure(self):
|
||||
from hermes_cli.config import DEFAULT_CONFIG
|
||||
vision = DEFAULT_CONFIG["auxiliary"]["vision"]
|
||||
assert "provider" in vision
|
||||
assert "model" in vision
|
||||
assert vision["provider"] == "auto"
|
||||
assert vision["model"] == ""
|
||||
|
||||
def test_web_extract_task_structure(self):
|
||||
from hermes_cli.config import DEFAULT_CONFIG
|
||||
web = DEFAULT_CONFIG["auxiliary"]["web_extract"]
|
||||
assert "provider" in web
|
||||
assert "model" in web
|
||||
assert web["provider"] == "auto"
|
||||
assert web["model"] == ""
|
||||
|
||||
def test_compression_provider_default(self):
|
||||
from hermes_cli.config import DEFAULT_CONFIG
|
||||
compression = DEFAULT_CONFIG["compression"]
|
||||
assert "summary_provider" in compression
|
||||
assert compression["summary_provider"] == "auto"
|
||||
|
||||
|
||||
# ── CLI defaults parity ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestCLIDefaultsHaveAuxiliaryKeys:
|
||||
"""Verify cli.py load_cli_config() defaults dict does NOT include auxiliary
|
||||
(it comes from config.yaml deep merge, not hardcoded defaults)."""
|
||||
|
||||
def test_cli_defaults_can_merge_auxiliary(self):
|
||||
"""The load_cli_config deep merge logic handles keys not in defaults.
|
||||
Verify auxiliary would be picked up from config.yaml."""
|
||||
# This is a structural assertion: cli.py's second-pass loop
|
||||
# carries over keys from file_config that aren't in defaults.
|
||||
# So auxiliary config from config.yaml gets merged even though
|
||||
# cli.py's defaults dict doesn't define it.
|
||||
import cli as _cli_mod
|
||||
source = Path(_cli_mod.__file__).read_text()
|
||||
assert "auxiliary_config = defaults.get(\"auxiliary\"" in source
|
||||
assert "AUXILIARY_VISION_PROVIDER" in source
|
||||
assert "AUXILIARY_VISION_MODEL" in source
|
||||
|
|
@ -162,6 +162,124 @@ def test_runtime_resolution_rebuilds_agent_on_routing_change(monkeypatch):
|
|||
assert shell.api_mode == "codex_responses"
|
||||
|
||||
|
||||
def test_codex_provider_replaces_incompatible_default_model(monkeypatch):
|
||||
"""When provider resolves to openai-codex and no model was explicitly
|
||||
chosen, the global config default (e.g. anthropic/claude-opus-4.6) must
|
||||
be replaced with a Codex-compatible model. Fixes #651."""
|
||||
cli = _import_cli()
|
||||
|
||||
monkeypatch.delenv("LLM_MODEL", raising=False)
|
||||
monkeypatch.delenv("OPENAI_MODEL", raising=False)
|
||||
|
||||
def _runtime_resolve(**kwargs):
|
||||
return {
|
||||
"provider": "openai-codex",
|
||||
"api_mode": "codex_responses",
|
||||
"base_url": "https://chatgpt.com/backend-api/codex",
|
||||
"api_key": "test-key",
|
||||
"source": "env/config",
|
||||
}
|
||||
|
||||
monkeypatch.setattr("hermes_cli.runtime_provider.resolve_runtime_provider", _runtime_resolve)
|
||||
monkeypatch.setattr("hermes_cli.runtime_provider.format_runtime_provider_error", lambda exc: str(exc))
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.codex_models.get_codex_model_ids",
|
||||
lambda access_token=None: ["gpt-5.2-codex", "gpt-5.1-codex-mini"],
|
||||
)
|
||||
|
||||
shell = cli.HermesCLI(compact=True, max_turns=1)
|
||||
|
||||
assert shell._model_is_default is True
|
||||
assert shell._ensure_runtime_credentials() is True
|
||||
assert shell.provider == "openai-codex"
|
||||
assert "anthropic" not in shell.model
|
||||
assert "claude" not in shell.model
|
||||
assert shell.model == "gpt-5.2-codex"
|
||||
|
||||
|
||||
def test_codex_provider_trusts_explicit_envvar_model(monkeypatch):
|
||||
"""When the user explicitly sets LLM_MODEL, we trust their choice and
|
||||
let the API be the judge — even if it's a non-OpenAI model. Only
|
||||
provider prefixes are stripped; the bare model passes through."""
|
||||
cli = _import_cli()
|
||||
|
||||
monkeypatch.setenv("LLM_MODEL", "claude-opus-4-6")
|
||||
monkeypatch.delenv("OPENAI_MODEL", raising=False)
|
||||
|
||||
def _runtime_resolve(**kwargs):
|
||||
return {
|
||||
"provider": "openai-codex",
|
||||
"api_mode": "codex_responses",
|
||||
"base_url": "https://chatgpt.com/backend-api/codex",
|
||||
"api_key": "test-key",
|
||||
"source": "env/config",
|
||||
}
|
||||
|
||||
monkeypatch.setattr("hermes_cli.runtime_provider.resolve_runtime_provider", _runtime_resolve)
|
||||
monkeypatch.setattr("hermes_cli.runtime_provider.format_runtime_provider_error", lambda exc: str(exc))
|
||||
|
||||
shell = cli.HermesCLI(compact=True, max_turns=1)
|
||||
|
||||
assert shell._model_is_default is False
|
||||
assert shell._ensure_runtime_credentials() is True
|
||||
assert shell.provider == "openai-codex"
|
||||
# User explicitly chose this model — it passes through untouched
|
||||
assert shell.model == "claude-opus-4-6"
|
||||
|
||||
|
||||
def test_codex_provider_preserves_explicit_codex_model(monkeypatch):
|
||||
"""If the user explicitly passes a Codex-compatible model, it must be
|
||||
preserved even when the provider resolves to openai-codex."""
|
||||
cli = _import_cli()
|
||||
|
||||
monkeypatch.delenv("LLM_MODEL", raising=False)
|
||||
monkeypatch.delenv("OPENAI_MODEL", raising=False)
|
||||
|
||||
def _runtime_resolve(**kwargs):
|
||||
return {
|
||||
"provider": "openai-codex",
|
||||
"api_mode": "codex_responses",
|
||||
"base_url": "https://chatgpt.com/backend-api/codex",
|
||||
"api_key": "test-key",
|
||||
"source": "env/config",
|
||||
}
|
||||
|
||||
monkeypatch.setattr("hermes_cli.runtime_provider.resolve_runtime_provider", _runtime_resolve)
|
||||
monkeypatch.setattr("hermes_cli.runtime_provider.format_runtime_provider_error", lambda exc: str(exc))
|
||||
|
||||
shell = cli.HermesCLI(model="gpt-5.1-codex-mini", compact=True, max_turns=1)
|
||||
|
||||
assert shell._model_is_default is False
|
||||
assert shell._ensure_runtime_credentials() is True
|
||||
assert shell.model == "gpt-5.1-codex-mini"
|
||||
|
||||
|
||||
def test_codex_provider_strips_provider_prefix_from_model(monkeypatch):
|
||||
"""openai/gpt-5.3-codex should become gpt-5.3-codex — the Codex
|
||||
Responses API does not accept provider-prefixed model slugs."""
|
||||
cli = _import_cli()
|
||||
|
||||
monkeypatch.delenv("LLM_MODEL", raising=False)
|
||||
monkeypatch.delenv("OPENAI_MODEL", raising=False)
|
||||
|
||||
def _runtime_resolve(**kwargs):
|
||||
return {
|
||||
"provider": "openai-codex",
|
||||
"api_mode": "codex_responses",
|
||||
"base_url": "https://chatgpt.com/backend-api/codex",
|
||||
"api_key": "test-key",
|
||||
"source": "env/config",
|
||||
}
|
||||
|
||||
monkeypatch.setattr("hermes_cli.runtime_provider.resolve_runtime_provider", _runtime_resolve)
|
||||
monkeypatch.setattr("hermes_cli.runtime_provider.format_runtime_provider_error", lambda exc: str(exc))
|
||||
|
||||
shell = cli.HermesCLI(model="openai/gpt-5.3-codex", compact=True, max_turns=1)
|
||||
|
||||
assert shell._ensure_runtime_credentials() is True
|
||||
assert shell.model == "gpt-5.3-codex"
|
||||
|
||||
|
||||
def test_cmd_model_falls_back_to_auto_on_invalid_provider(monkeypatch, capsys):
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.config.load_config",
|
||||
|
|
|
|||
|
|
@ -149,6 +149,7 @@ def test_gateway_run_agent_codex_path_handles_internal_401_refresh(monkeypatch):
|
|||
runner._prefill_messages = []
|
||||
runner._reasoning_config = None
|
||||
runner._provider_routing = {}
|
||||
runner._fallback_model = None
|
||||
runner._running_agents = {}
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
runner.hooks = MagicMock()
|
||||
|
|
|
|||
|
|
@ -1,4 +1,9 @@
|
|||
import json
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import patch
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
|
||||
from hermes_cli.codex_models import DEFAULT_CODEX_MODELS, get_codex_model_ids
|
||||
|
||||
|
|
@ -13,7 +18,7 @@ def test_get_codex_model_ids_prioritizes_default_and_cache(tmp_path, monkeypatch
|
|||
"models": [
|
||||
{"slug": "gpt-5.3-codex", "priority": 20, "supported_in_api": True},
|
||||
{"slug": "gpt-5.1-codex", "priority": 5, "supported_in_api": True},
|
||||
{"slug": "gpt-4o", "priority": 1, "supported_in_api": True},
|
||||
{"slug": "gpt-5.4", "priority": 1, "supported_in_api": True},
|
||||
{"slug": "gpt-5-hidden-codex", "priority": 2, "visibility": "hidden"},
|
||||
]
|
||||
}
|
||||
|
|
@ -26,10 +31,19 @@ def test_get_codex_model_ids_prioritizes_default_and_cache(tmp_path, monkeypatch
|
|||
assert models[0] == "gpt-5.2-codex"
|
||||
assert "gpt-5.1-codex" in models
|
||||
assert "gpt-5.3-codex" in models
|
||||
assert "gpt-4o" not in models
|
||||
# Non-codex-suffixed models are included when the cache says they're available
|
||||
assert "gpt-5.4" in models
|
||||
assert "gpt-5-hidden-codex" not in models
|
||||
|
||||
|
||||
def test_setup_wizard_codex_import_resolves():
    """Regression test for #712: setup.py must import the correct function name."""
    # Mirrors the exact import used in hermes_cli/setup.py (line 873). A prior
    # bug referenced 'get_codex_models' (wrong) instead of 'get_codex_model_ids';
    # importing under an alias here reproduces that usage exactly.
    from hermes_cli.codex_models import get_codex_model_ids as wizard_import

    assert callable(wizard_import)
|
||||
|
||||
|
||||
def test_get_codex_model_ids_falls_back_to_curated_defaults(tmp_path, monkeypatch):
|
||||
codex_home = tmp_path / "codex-home"
|
||||
codex_home.mkdir(parents=True, exist_ok=True)
|
||||
|
|
@ -38,3 +52,144 @@ def test_get_codex_model_ids_falls_back_to_curated_defaults(tmp_path, monkeypatc
|
|||
models = get_codex_model_ids()
|
||||
|
||||
assert models[: len(DEFAULT_CODEX_MODELS)] == DEFAULT_CODEX_MODELS
|
||||
|
||||
|
||||
# ── Tests for _normalize_model_for_provider ──────────────────────────
|
||||
|
||||
|
||||
def _make_cli(model="anthropic/claude-opus-4.6", **kwargs):
    """Create a HermesCLI with minimal mocking.

    Patches tool discovery, neutralizes env overrides, and swaps in a clean
    CLI_CONFIG so tests see deterministic defaults. Extra kwargs are passed
    straight through to the HermesCLI constructor.
    """
    import cli as _cli_mod
    from cli import HermesCLI

    # Deterministic baseline config: auto provider, OpenRouter base URL.
    _clean_config = {
        "model": {
            "default": "anthropic/claude-opus-4.6",
            "base_url": "https://openrouter.ai/api/v1",
            "provider": "auto",
        },
        "display": {"compact": False, "tool_progress": "all", "resume_display": "full"},
        "agent": {},
        "terminal": {"env_type": "local"},
    }
    # Blank out env vars that would otherwise override the model/iteration count.
    clean_env = {"LLM_MODEL": "", "HERMES_MAX_ITERATIONS": ""}
    with (
        patch("cli.get_tool_definitions", return_value=[]),
        patch.dict("os.environ", clean_env, clear=False),
        patch.dict(_cli_mod.__dict__, {"CLI_CONFIG": _clean_config}),
    ):
        cli = HermesCLI(model=model, **kwargs)
    return cli
|
||||
|
||||
|
||||
class TestNormalizeModelForProvider:
    """_normalize_model_for_provider() trusts user-selected models.

    Only two things happen:
    1. Provider prefixes are stripped (API needs bare slugs)
    2. The *untouched default* model is swapped for a Codex model
    Everything else passes through — the API is the judge.
    """

    def test_non_codex_provider_is_noop(self):
        """Non-codex providers never modify the model."""
        cli = _make_cli(model="gpt-5.4")
        changed = cli._normalize_model_for_provider("openrouter")
        assert changed is False
        assert cli.model == "gpt-5.4"

    def test_bare_codex_model_passes_through(self):
        """An unprefixed codex slug is already valid for the Codex API."""
        cli = _make_cli(model="gpt-5.3-codex")
        changed = cli._normalize_model_for_provider("openai-codex")
        assert changed is False
        assert cli.model == "gpt-5.3-codex"

    def test_bare_non_codex_model_passes_through(self):
        """gpt-5.4 (no 'codex' suffix) passes through — user chose it."""
        cli = _make_cli(model="gpt-5.4")
        changed = cli._normalize_model_for_provider("openai-codex")
        assert changed is False
        assert cli.model == "gpt-5.4"

    def test_any_bare_model_trusted(self):
        """Even a non-OpenAI bare model passes through — user explicitly set it."""
        cli = _make_cli(model="claude-opus-4-6")
        changed = cli._normalize_model_for_provider("openai-codex")
        # User explicitly chose this model — we trust them, API will error if wrong
        assert changed is False
        assert cli.model == "claude-opus-4-6"

    def test_provider_prefix_stripped(self):
        """openai/gpt-5.4 → gpt-5.4 (strip prefix, keep model)."""
        cli = _make_cli(model="openai/gpt-5.4")
        changed = cli._normalize_model_for_provider("openai-codex")
        assert changed is True
        assert cli.model == "gpt-5.4"

    def test_any_provider_prefix_stripped(self):
        """anthropic/claude-opus-4.6 → claude-opus-4.6 (strip prefix only).
        User explicitly chose this — let the API decide if it works."""
        cli = _make_cli(model="anthropic/claude-opus-4.6")
        changed = cli._normalize_model_for_provider("openai-codex")
        assert changed is True
        assert cli.model == "claude-opus-4.6"

    def test_default_model_replaced(self):
        """The untouched default (anthropic/claude-opus-4.6) gets swapped."""
        import cli as _cli_mod
        _clean_config = {
            "model": {
                "default": "anthropic/claude-opus-4.6",
                "base_url": "https://openrouter.ai/api/v1",
                "provider": "auto",
            },
            "display": {"compact": False, "tool_progress": "all", "resume_display": "full"},
            "agent": {},
            "terminal": {"env_type": "local"},
        }
        # Don't pass model= so _model_is_default is True
        with (
            patch("cli.get_tool_definitions", return_value=[]),
            patch.dict("os.environ", {"LLM_MODEL": "", "HERMES_MAX_ITERATIONS": ""}, clear=False),
            patch.dict(_cli_mod.__dict__, {"CLI_CONFIG": _clean_config}),
        ):
            from cli import HermesCLI
            cli = HermesCLI()

        assert cli._model_is_default is True
        with patch(
            "hermes_cli.codex_models.get_codex_model_ids",
            return_value=["gpt-5.3-codex", "gpt-5.4"],
        ):
            changed = cli._normalize_model_for_provider("openai-codex")
        assert changed is True
        # Uses first from available list
        assert cli.model == "gpt-5.3-codex"

    def test_default_fallback_when_api_fails(self):
        """Default model falls back to gpt-5.3-codex when API unreachable."""
        import cli as _cli_mod
        _clean_config = {
            "model": {
                "default": "anthropic/claude-opus-4.6",
                "base_url": "https://openrouter.ai/api/v1",
                "provider": "auto",
            },
            "display": {"compact": False, "tool_progress": "all", "resume_display": "full"},
            "agent": {},
            "terminal": {"env_type": "local"},
        }
        with (
            patch("cli.get_tool_definitions", return_value=[]),
            patch.dict("os.environ", {"LLM_MODEL": "", "HERMES_MAX_ITERATIONS": ""}, clear=False),
            patch.dict(_cli_mod.__dict__, {"CLI_CONFIG": _clean_config}),
        ):
            from cli import HermesCLI
            cli = HermesCLI()

        # Simulate the model-list API being unreachable.
        with patch(
            "hermes_cli.codex_models.get_codex_model_ids",
            side_effect=Exception("offline"),
        ):
            changed = cli._normalize_model_for_provider("openai-codex")
        assert changed is True
        assert cli.model == "gpt-5.3-codex"
|
||||
|
|
|
|||
339
tests/test_fallback_model.py
Normal file
339
tests/test_fallback_model.py
Normal file
|
|
@ -0,0 +1,339 @@
|
|||
"""Tests for the provider fallback model feature.
|
||||
|
||||
Verifies that AIAgent can switch to a configured fallback model/provider
|
||||
when the primary fails after retries.
|
||||
"""
|
||||
|
||||
import os
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from run_agent import AIAgent
|
||||
|
||||
|
||||
def _make_tool_defs(*names: str) -> list:
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": n,
|
||||
"description": f"{n} tool",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
}
|
||||
for n in names
|
||||
]
|
||||
|
||||
|
||||
def _make_agent(fallback_model=None):
    """Create a minimal AIAgent with optional fallback config.

    Tool discovery, toolset checks, and the OpenAI client constructor are all
    patched out so no network or filesystem work happens during __init__.
    The real client is then replaced with a MagicMock.
    """
    with (
        patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("web_search")),
        patch("run_agent.check_toolset_requirements", return_value={}),
        patch("run_agent.OpenAI"),
    ):
        agent = AIAgent(
            api_key="test-key-primary",
            quiet_mode=True,
            skip_context_files=True,
            skip_memory=True,
            fallback_model=fallback_model,
        )
    agent.client = MagicMock()
    return agent
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# _try_activate_fallback()
|
||||
# =============================================================================
|
||||
|
||||
class TestTryActivateFallback:
    """Tests for AIAgent._try_activate_fallback() across all supported providers."""

    def test_returns_false_when_not_configured(self):
        """No fallback config → nothing to activate."""
        agent = _make_agent(fallback_model=None)
        assert agent._try_activate_fallback() is False
        assert agent._fallback_activated is False

    def test_returns_false_for_empty_config(self):
        """Empty provider/model strings are treated as unconfigured."""
        agent = _make_agent(fallback_model={"provider": "", "model": ""})
        assert agent._try_activate_fallback() is False

    def test_returns_false_for_missing_provider(self):
        agent = _make_agent(fallback_model={"model": "gpt-4.1"})
        assert agent._try_activate_fallback() is False

    def test_returns_false_for_missing_model(self):
        agent = _make_agent(fallback_model={"provider": "openrouter"})
        assert agent._try_activate_fallback() is False

    def test_activates_openrouter_fallback(self):
        """OpenRouter fallback swaps model/provider and rebuilds the client."""
        agent = _make_agent(
            fallback_model={"provider": "openrouter", "model": "anthropic/claude-sonnet-4"},
        )
        with (
            patch.dict("os.environ", {"OPENROUTER_API_KEY": "sk-or-fallback-key"}),
            patch("run_agent.OpenAI") as mock_openai,
        ):
            result = agent._try_activate_fallback()
            assert result is True
            assert agent._fallback_activated is True
            assert agent.model == "anthropic/claude-sonnet-4"
            assert agent.provider == "openrouter"
            assert agent.api_mode == "chat_completions"
            mock_openai.assert_called_once()
            call_kwargs = mock_openai.call_args[1]
            assert call_kwargs["api_key"] == "sk-or-fallback-key"
            assert "openrouter" in call_kwargs["base_url"].lower()
            # OpenRouter should get attribution headers
            assert "default_headers" in call_kwargs

    def test_activates_zai_fallback(self):
        agent = _make_agent(
            fallback_model={"provider": "zai", "model": "glm-5"},
        )
        with (
            patch.dict("os.environ", {"ZAI_API_KEY": "sk-zai-key"}),
            patch("run_agent.OpenAI") as mock_openai,
        ):
            result = agent._try_activate_fallback()
            assert result is True
            assert agent.model == "glm-5"
            assert agent.provider == "zai"
            call_kwargs = mock_openai.call_args[1]
            assert call_kwargs["api_key"] == "sk-zai-key"
            assert "z.ai" in call_kwargs["base_url"].lower()

    def test_activates_kimi_fallback(self):
        agent = _make_agent(
            fallback_model={"provider": "kimi-coding", "model": "kimi-k2.5"},
        )
        with (
            patch.dict("os.environ", {"KIMI_API_KEY": "sk-kimi-key"}),
            patch("run_agent.OpenAI"),
        ):
            assert agent._try_activate_fallback() is True
            assert agent.model == "kimi-k2.5"
            assert agent.provider == "kimi-coding"

    def test_activates_minimax_fallback(self):
        agent = _make_agent(
            fallback_model={"provider": "minimax", "model": "MiniMax-M2.5"},
        )
        with (
            patch.dict("os.environ", {"MINIMAX_API_KEY": "sk-mm-key"}),
            patch("run_agent.OpenAI") as mock_openai,
        ):
            assert agent._try_activate_fallback() is True
            assert agent.model == "MiniMax-M2.5"
            assert agent.provider == "minimax"
            call_kwargs = mock_openai.call_args[1]
            assert "minimax.io" in call_kwargs["base_url"]

    def test_only_fires_once(self):
        """Fallback is a one-shot switch; re-activation is refused."""
        agent = _make_agent(
            fallback_model={"provider": "openrouter", "model": "anthropic/claude-sonnet-4"},
        )
        with (
            patch.dict("os.environ", {"OPENROUTER_API_KEY": "sk-or-key"}),
            patch("run_agent.OpenAI"),
        ):
            assert agent._try_activate_fallback() is True
            # Second attempt should return False
            assert agent._try_activate_fallback() is False

    def test_returns_false_when_no_api_key(self):
        """Fallback should fail gracefully when the API key env var is unset."""
        agent = _make_agent(
            fallback_model={"provider": "minimax", "model": "MiniMax-M2.5"},
        )
        # Ensure MINIMAX_API_KEY is not in the environment
        env = {k: v for k, v in os.environ.items() if k != "MINIMAX_API_KEY"}
        with patch.dict("os.environ", env, clear=True):
            assert agent._try_activate_fallback() is False
            assert agent._fallback_activated is False

    def test_custom_base_url(self):
        """Custom base_url in config should override the provider default."""
        agent = _make_agent(
            fallback_model={
                "provider": "custom",
                "model": "my-model",
                "base_url": "http://localhost:8080/v1",
                "api_key_env": "MY_CUSTOM_KEY",
            },
        )
        with (
            patch.dict("os.environ", {"MY_CUSTOM_KEY": "custom-secret"}),
            patch("run_agent.OpenAI") as mock_openai,
        ):
            assert agent._try_activate_fallback() is True
            call_kwargs = mock_openai.call_args[1]
            assert call_kwargs["base_url"] == "http://localhost:8080/v1"
            assert call_kwargs["api_key"] == "custom-secret"

    def test_prompt_caching_enabled_for_claude_on_openrouter(self):
        """Claude models via OpenRouter enable prompt caching."""
        agent = _make_agent(
            fallback_model={"provider": "openrouter", "model": "anthropic/claude-sonnet-4"},
        )
        with (
            patch.dict("os.environ", {"OPENROUTER_API_KEY": "sk-or-key"}),
            patch("run_agent.OpenAI"),
        ):
            agent._try_activate_fallback()
            assert agent._use_prompt_caching is True

    def test_prompt_caching_disabled_for_non_claude(self):
        """Non-Claude models (even on OpenRouter) do not enable prompt caching."""
        agent = _make_agent(
            fallback_model={"provider": "openrouter", "model": "google/gemini-2.5-flash"},
        )
        with (
            patch.dict("os.environ", {"OPENROUTER_API_KEY": "sk-or-key"}),
            patch("run_agent.OpenAI"),
        ):
            agent._try_activate_fallback()
            assert agent._use_prompt_caching is False

    def test_prompt_caching_disabled_for_non_openrouter(self):
        """Prompt caching stays off for providers other than OpenRouter."""
        agent = _make_agent(
            fallback_model={"provider": "zai", "model": "glm-5"},
        )
        with (
            patch.dict("os.environ", {"ZAI_API_KEY": "sk-zai-key"}),
            patch("run_agent.OpenAI"),
        ):
            agent._try_activate_fallback()
            assert agent._use_prompt_caching is False

    def test_zai_alt_env_var(self):
        """Z.AI should also check Z_AI_API_KEY as fallback env var."""
        agent = _make_agent(
            fallback_model={"provider": "zai", "model": "glm-5"},
        )
        with (
            patch.dict("os.environ", {"Z_AI_API_KEY": "sk-alt-key"}),
            patch("run_agent.OpenAI") as mock_openai,
        ):
            assert agent._try_activate_fallback() is True
            call_kwargs = mock_openai.call_args[1]
            assert call_kwargs["api_key"] == "sk-alt-key"

    def test_activates_codex_fallback(self):
        """OpenAI Codex fallback should use OAuth credentials and codex_responses mode."""
        agent = _make_agent(
            fallback_model={"provider": "openai-codex", "model": "gpt-5.3-codex"},
        )
        mock_creds = {
            "api_key": "codex-oauth-token",
            "base_url": "https://chatgpt.com/backend-api/codex",
        }
        with (
            patch("hermes_cli.auth.resolve_codex_runtime_credentials", return_value=mock_creds),
            patch("run_agent.OpenAI") as mock_openai,
        ):
            result = agent._try_activate_fallback()
            assert result is True
            assert agent.model == "gpt-5.3-codex"
            assert agent.provider == "openai-codex"
            assert agent.api_mode == "codex_responses"
            call_kwargs = mock_openai.call_args[1]
            assert call_kwargs["api_key"] == "codex-oauth-token"
            assert "chatgpt.com" in call_kwargs["base_url"]

    def test_codex_fallback_fails_gracefully_without_credentials(self):
        """Codex fallback should return False if no OAuth credentials available."""
        agent = _make_agent(
            fallback_model={"provider": "openai-codex", "model": "gpt-5.3-codex"},
        )
        with patch(
            "hermes_cli.auth.resolve_codex_runtime_credentials",
            side_effect=Exception("No Codex credentials"),
        ):
            assert agent._try_activate_fallback() is False
            assert agent._fallback_activated is False

    def test_activates_nous_fallback(self):
        """Nous Portal fallback should use OAuth credentials and chat_completions mode."""
        agent = _make_agent(
            fallback_model={"provider": "nous", "model": "nous-hermes-3"},
        )
        mock_creds = {
            "api_key": "nous-agent-key-abc",
            "base_url": "https://inference-api.nousresearch.com/v1",
        }
        with (
            patch("hermes_cli.auth.resolve_nous_runtime_credentials", return_value=mock_creds),
            patch("run_agent.OpenAI") as mock_openai,
        ):
            result = agent._try_activate_fallback()
            assert result is True
            assert agent.model == "nous-hermes-3"
            assert agent.provider == "nous"
            assert agent.api_mode == "chat_completions"
            call_kwargs = mock_openai.call_args[1]
            assert call_kwargs["api_key"] == "nous-agent-key-abc"
            assert "nousresearch.com" in call_kwargs["base_url"]

    def test_nous_fallback_fails_gracefully_without_login(self):
        """Nous fallback should return False if not logged in."""
        agent = _make_agent(
            fallback_model={"provider": "nous", "model": "nous-hermes-3"},
        )
        with patch(
            "hermes_cli.auth.resolve_nous_runtime_credentials",
            side_effect=Exception("Not logged in to Nous Portal"),
        ):
            assert agent._try_activate_fallback() is False
            assert agent._fallback_activated is False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Fallback config init
|
||||
# =============================================================================
|
||||
|
||||
class TestFallbackInit:
    """Constructor handling of the fallback_model config."""

    def test_fallback_stored_when_configured(self):
        """A valid dict config is stored verbatim; activation flag starts False."""
        agent = _make_agent(
            fallback_model={"provider": "openrouter", "model": "anthropic/claude-sonnet-4"},
        )
        assert agent._fallback_model is not None
        assert agent._fallback_model["provider"] == "openrouter"
        assert agent._fallback_activated is False

    def test_fallback_none_when_not_configured(self):
        agent = _make_agent(fallback_model=None)
        assert agent._fallback_model is None
        assert agent._fallback_activated is False

    def test_fallback_none_for_non_dict(self):
        """Non-dict config values are discarded rather than stored."""
        agent = _make_agent(fallback_model="not-a-dict")
        assert agent._fallback_model is None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Provider credential resolution
|
||||
# =============================================================================
|
||||
|
||||
class TestProviderCredentials:
    """Verify that each supported provider resolves its API key correctly."""

    # One row per provider: its env var and a fragment expected in the base URL.
    @pytest.mark.parametrize("provider,env_var,base_url_fragment", [
        ("openrouter", "OPENROUTER_API_KEY", "openrouter"),
        ("zai", "ZAI_API_KEY", "z.ai"),
        ("kimi-coding", "KIMI_API_KEY", "moonshot.ai"),
        ("minimax", "MINIMAX_API_KEY", "minimax.io"),
        ("minimax-cn", "MINIMAX_CN_API_KEY", "minimaxi.com"),
    ])
    def test_provider_resolves(self, provider, env_var, base_url_fragment):
        """Each provider picks up its key from the right env var and base URL."""
        agent = _make_agent(
            fallback_model={"provider": provider, "model": "test-model"},
        )
        with (
            patch.dict("os.environ", {env_var: "test-key-123"}),
            patch("run_agent.OpenAI") as mock_openai,
        ):
            result = agent._try_activate_fallback()
            assert result is True, f"Failed to activate fallback for {provider}"
            call_kwargs = mock_openai.call_args[1]
            assert call_kwargs["api_key"] == "test-key-123"
            assert base_url_fragment in call_kwargs["base_url"].lower()
|
||||
|
|
@ -351,6 +351,173 @@ class TestPruneSessions:
|
|||
# Schema and WAL mode
|
||||
# =========================================================================
|
||||
|
||||
# =========================================================================
|
||||
# Session title
|
||||
# =========================================================================
|
||||
|
||||
class TestSessionTitle:
    """Tests for setting, reading, and persisting session titles."""

    def test_set_and_get_title(self, db):
        """set_session_title returns True and the title is readable back."""
        db.create_session(session_id="s1", source="cli")
        assert db.set_session_title("s1", "My Session") is True

        session = db.get_session("s1")
        assert session["title"] == "My Session"

    def test_set_title_nonexistent_session(self, db):
        """Setting a title on an unknown session id returns False."""
        assert db.set_session_title("nonexistent", "Title") is False

    def test_title_initially_none(self, db):
        db.create_session(session_id="s1", source="cli")
        session = db.get_session("s1")
        assert session["title"] is None

    def test_update_title(self, db):
        """A second set_session_title overwrites the first."""
        db.create_session(session_id="s1", source="cli")
        db.set_session_title("s1", "First Title")
        db.set_session_title("s1", "Updated Title")

        session = db.get_session("s1")
        assert session["title"] == "Updated Title"

    def test_title_in_search_sessions(self, db):
        """search_sessions rows carry the title field."""
        db.create_session(session_id="s1", source="cli")
        db.set_session_title("s1", "Debugging Auth")
        db.create_session(session_id="s2", source="cli")

        sessions = db.search_sessions()
        titled = [s for s in sessions if s.get("title") == "Debugging Auth"]
        assert len(titled) == 1
        assert titled[0]["id"] == "s1"

    def test_title_in_export(self, db):
        """export_session includes the title."""
        db.create_session(session_id="s1", source="cli")
        db.set_session_title("s1", "Export Test")
        db.append_message("s1", role="user", content="Hello")

        export = db.export_session("s1")
        assert export["title"] == "Export Test"

    def test_title_with_special_characters(self, db):
        """Punctuation and em-dashes round-trip unchanged."""
        db.create_session(session_id="s1", source="cli")
        title = "PR #438 — fixing the 'auth' middleware"
        db.set_session_title("s1", title)

        session = db.get_session("s1")
        assert session["title"] == title

    def test_title_empty_string_normalized_to_none(self, db):
        """Empty strings are normalized to None (clearing the title)."""
        db.create_session(session_id="s1", source="cli")
        db.set_session_title("s1", "My Title")
        # Setting to empty string should clear the title (normalize to None)
        db.set_session_title("s1", "")

        session = db.get_session("s1")
        assert session["title"] is None

    def test_multiple_empty_titles_no_conflict(self, db):
        """Multiple sessions can have empty-string (normalized to NULL) titles."""
        db.create_session(session_id="s1", source="cli")
        db.create_session(session_id="s2", source="cli")
        db.set_session_title("s1", "")
        db.set_session_title("s2", "")
        # Both should be None, no uniqueness conflict
        assert db.get_session("s1")["title"] is None
        assert db.get_session("s2")["title"] is None

    def test_title_survives_end_session(self, db):
        """Ending a session leaves its title intact."""
        db.create_session(session_id="s1", source="cli")
        db.set_session_title("s1", "Before End")
        db.end_session("s1", end_reason="user_exit")

        session = db.get_session("s1")
        assert session["title"] == "Before End"
        assert session["ended_at"] is not None
|
||||
|
||||
|
||||
class TestSanitizeTitle:
    """Tests for SessionDB.sanitize_title() validation and cleaning."""

    def test_normal_title_unchanged(self):
        assert SessionDB.sanitize_title("My Project") == "My Project"

    def test_strips_whitespace(self):
        assert SessionDB.sanitize_title(" hello world ") == "hello world"

    def test_collapses_internal_whitespace(self):
        # Input must contain a RUN of spaces, otherwise this assertion is an
        # identity comparison and never exercises the collapsing behavior.
        assert SessionDB.sanitize_title("hello    world") == "hello world"

    def test_tabs_and_newlines_collapsed(self):
        assert SessionDB.sanitize_title("hello\t\nworld") == "hello world"

    def test_none_returns_none(self):
        assert SessionDB.sanitize_title(None) is None

    def test_empty_string_returns_none(self):
        assert SessionDB.sanitize_title("") is None

    def test_whitespace_only_returns_none(self):
        assert SessionDB.sanitize_title(" \t\n ") is None

    def test_control_chars_stripped(self):
        # Null byte, bell, backspace, etc.
        assert SessionDB.sanitize_title("hello\x00world") == "helloworld"
        assert SessionDB.sanitize_title("\x07\x08test\x1b") == "test"

    def test_del_char_stripped(self):
        assert SessionDB.sanitize_title("hello\x7fworld") == "helloworld"

    def test_zero_width_chars_stripped(self):
        # Zero-width space (U+200B), zero-width joiner (U+200D)
        assert SessionDB.sanitize_title("hello\u200bworld") == "helloworld"
        assert SessionDB.sanitize_title("hello\u200dworld") == "helloworld"

    def test_rtl_override_stripped(self):
        # Right-to-left override (U+202E) — used in filename spoofing attacks
        assert SessionDB.sanitize_title("hello\u202eworld") == "helloworld"

    def test_bom_stripped(self):
        # Byte order mark (U+FEFF)
        assert SessionDB.sanitize_title("\ufeffhello") == "hello"

    def test_only_control_chars_returns_none(self):
        assert SessionDB.sanitize_title("\x00\x01\x02\u200b\ufeff") is None

    def test_max_length_allowed(self):
        # 100 characters is the documented maximum — exactly at the limit passes.
        title = "A" * 100
        assert SessionDB.sanitize_title(title) == title

    def test_exceeds_max_length_raises(self):
        title = "A" * 101
        with pytest.raises(ValueError, match="too long"):
            SessionDB.sanitize_title(title)

    def test_unicode_emoji_allowed(self):
        assert SessionDB.sanitize_title("🚀 My Project 🎉") == "🚀 My Project 🎉"

    def test_cjk_characters_allowed(self):
        assert SessionDB.sanitize_title("我的项目") == "我的项目"

    def test_accented_characters_allowed(self):
        assert SessionDB.sanitize_title("Résumé éditing") == "Résumé éditing"

    def test_special_punctuation_allowed(self):
        title = "PR #438 — fixing the 'auth' middleware"
        assert SessionDB.sanitize_title(title) == title

    def test_sanitize_applied_in_set_session_title(self, db):
        """set_session_title applies sanitize_title internally."""
        db.create_session("s1", "cli")
        db.set_session_title("s1", " hello\x00 world ")
        assert db.get_session("s1")["title"] == "hello world"

    def test_too_long_title_rejected_by_set(self, db):
        """set_session_title raises ValueError for overly long titles."""
        db.create_session("s1", "cli")
        with pytest.raises(ValueError, match="too long"):
            db.set_session_title("s1", "X" * 150)
|
||||
|
||||
|
||||
class TestSchemaInit:
|
||||
def test_wal_mode(self, db):
|
||||
cursor = db._conn.execute("PRAGMA journal_mode")
|
||||
|
|
@ -373,4 +540,297 @@ class TestSchemaInit:
|
|||
def test_schema_version(self, db):
    """The freshly-initialized schema reports the current version (4).

    A leftover merge artifact previously asserted both ``version == 2`` and
    ``version == 4`` back-to-back, which can never pass; only the current
    version (4, matching test_migration_from_v2) is asserted now.
    """
    cursor = db._conn.execute("SELECT version FROM schema_version")
    version = cursor.fetchone()[0]
    assert version == 4
|
||||
|
||||
def test_title_column_exists(self, db):
    """Verify the title column was created in the sessions table."""
    cursor = db._conn.execute("PRAGMA table_info(sessions)")
    # PRAGMA table_info rows: (cid, name, type, notnull, dflt_value, pk)
    columns = {row[1] for row in cursor.fetchall()}
    assert "title" in columns
|
||||
|
||||
def test_migration_from_v2(self, tmp_path):
    """Simulate a v2 database and verify migration adds title column."""
    import sqlite3

    db_path = tmp_path / "migrate_test.db"
    conn = sqlite3.connect(str(db_path))
    # Create v2 schema (without title column)
    conn.executescript("""
        CREATE TABLE schema_version (version INTEGER NOT NULL);
        INSERT INTO schema_version (version) VALUES (2);

        CREATE TABLE sessions (
            id TEXT PRIMARY KEY,
            source TEXT NOT NULL,
            user_id TEXT,
            model TEXT,
            model_config TEXT,
            system_prompt TEXT,
            parent_session_id TEXT,
            started_at REAL NOT NULL,
            ended_at REAL,
            end_reason TEXT,
            message_count INTEGER DEFAULT 0,
            tool_call_count INTEGER DEFAULT 0,
            input_tokens INTEGER DEFAULT 0,
            output_tokens INTEGER DEFAULT 0
        );

        CREATE TABLE messages (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            session_id TEXT NOT NULL,
            role TEXT NOT NULL,
            content TEXT,
            tool_call_id TEXT,
            tool_calls TEXT,
            tool_name TEXT,
            timestamp REAL NOT NULL,
            token_count INTEGER,
            finish_reason TEXT
        );
    """)
    # Seed one pre-migration row so we can check it survives intact.
    conn.execute(
        "INSERT INTO sessions (id, source, started_at) VALUES (?, ?, ?)",
        ("existing", "cli", 1000.0),
    )
    conn.commit()
    conn.close()

    # Open with SessionDB — should migrate to v4
    migrated_db = SessionDB(db_path=db_path)

    # Verify migration
    cursor = migrated_db._conn.execute("SELECT version FROM schema_version")
    assert cursor.fetchone()[0] == 4

    # Verify title column exists and is NULL for existing sessions
    session = migrated_db.get_session("existing")
    assert session is not None
    assert session["title"] is None

    # Verify we can set title on migrated session
    assert migrated_db.set_session_title("existing", "Migrated Title") is True
    session = migrated_db.get_session("existing")
    assert session["title"] == "Migrated Title"

    migrated_db.close()
|
||||
|
||||
|
||||
class TestTitleUniqueness:
    """Tests for unique title enforcement and title-based lookups."""

    def test_duplicate_title_raises(self, db):
        """Setting a title already used by another session raises ValueError."""
        db.create_session("s1", "cli")
        db.create_session("s2", "cli")
        db.set_session_title("s1", "my project")
        with pytest.raises(ValueError, match="already in use"):
            db.set_session_title("s2", "my project")

    def test_same_session_can_keep_title(self, db):
        """A session can re-set its own title without error."""
        db.create_session("s1", "cli")
        db.set_session_title("s1", "my project")
        # Should not raise — it's the same session
        assert db.set_session_title("s1", "my project") is True

    def test_null_titles_not_unique(self, db):
        """Multiple sessions can have NULL titles (no constraint violation)."""
        db.create_session("s1", "cli")
        db.create_session("s2", "cli")
        # Both have NULL titles — no error
        assert db.get_session("s1")["title"] is None
        assert db.get_session("s2")["title"] is None

    def test_get_session_by_title(self, db):
        """get_session_by_title returns the full session row for an exact title."""
        db.create_session("s1", "cli")
        db.set_session_title("s1", "refactoring auth")
        result = db.get_session_by_title("refactoring auth")
        assert result is not None
        assert result["id"] == "s1"

    def test_get_session_by_title_not_found(self, db):
        assert db.get_session_by_title("nonexistent") is None

    def test_get_session_title(self, db):
        """get_session_title returns None before a title is set, then the title."""
        db.create_session("s1", "cli")
        assert db.get_session_title("s1") is None
        db.set_session_title("s1", "my title")
        assert db.get_session_title("s1") == "my title"

    def test_get_session_title_nonexistent(self, db):
        assert db.get_session_title("nonexistent") is None
|
||||
|
||||
|
||||
class TestTitleLineage:
    """Tests for title lineage resolution and auto-numbering."""

    def test_resolve_exact_title(self, db):
        """An exact, unnumbered title resolves to its session id."""
        db.create_session("s1", "cli")
        db.set_session_title("s1", "my project")
        assert db.resolve_session_by_title("my project") == "s1"

    def test_resolve_returns_latest_numbered(self, db):
        """When numbered variants exist, return the most recent one."""
        import time
        db.create_session("s1", "cli")
        db.set_session_title("s1", "my project")
        # sleep so each session gets a strictly later timestamp
        time.sleep(0.01)
        db.create_session("s2", "cli")
        db.set_session_title("s2", "my project #2")
        time.sleep(0.01)
        db.create_session("s3", "cli")
        db.set_session_title("s3", "my project #3")
        # Resolving "my project" should return s3 (latest numbered variant)
        assert db.resolve_session_by_title("my project") == "s3"

    def test_resolve_exact_numbered(self, db):
        """Resolving an exact numbered title returns that specific session."""
        db.create_session("s1", "cli")
        db.set_session_title("s1", "my project")
        db.create_session("s2", "cli")
        db.set_session_title("s2", "my project #2")
        # Resolving "my project #2" exactly should return s2
        assert db.resolve_session_by_title("my project #2") == "s2"

    def test_resolve_nonexistent_title(self, db):
        """Unknown titles resolve to None rather than raising."""
        assert db.resolve_session_by_title("nonexistent") is None

    def test_next_title_no_existing(self, db):
        """With no existing sessions, base title is returned as-is."""
        assert db.get_next_title_in_lineage("my project") == "my project"

    def test_next_title_first_continuation(self, db):
        """First continuation after the original gets #2."""
        db.create_session("s1", "cli")
        db.set_session_title("s1", "my project")
        assert db.get_next_title_in_lineage("my project") == "my project #2"

    def test_next_title_increments(self, db):
        """Each continuation increments the number."""
        db.create_session("s1", "cli")
        db.set_session_title("s1", "my project")
        db.create_session("s2", "cli")
        db.set_session_title("s2", "my project #2")
        db.create_session("s3", "cli")
        db.set_session_title("s3", "my project #3")
        assert db.get_next_title_in_lineage("my project") == "my project #4"

    def test_next_title_strips_existing_number(self, db):
        """Passing a numbered title strips the number and finds the base."""
        db.create_session("s1", "cli")
        db.set_session_title("s1", "my project")
        db.create_session("s2", "cli")
        db.set_session_title("s2", "my project #2")
        # Even when called with "my project #2", it should return #3
        assert db.get_next_title_in_lineage("my project #2") == "my project #3"
|
||||
|
||||
|
||||
class TestTitleSqlWildcards:
    """Titles containing SQL LIKE wildcards (%, _) must not cause false matches."""

    def test_resolve_title_with_underscore(self, db):
        """A title like 'test_project' should not match 'testXproject #2'."""
        db.create_session("s1", "cli")
        db.set_session_title("s1", "test_project")
        db.create_session("s2", "cli")
        # 'X' sits where LIKE's '_' wildcard would match any single char
        db.set_session_title("s2", "testXproject #2")
        # Resolving "test_project" should return s1 (exact), not s2
        assert db.resolve_session_by_title("test_project") == "s1"

    def test_resolve_title_with_percent(self, db):
        """A title with '%' should not wildcard-match unrelated sessions."""
        db.create_session("s1", "cli")
        db.set_session_title("s1", "100% done")
        db.create_session("s2", "cli")
        db.set_session_title("s2", "100X done #2")
        # Should resolve to s1 (exact), not s2
        assert db.resolve_session_by_title("100% done") == "s1"

    def test_next_lineage_with_underscore(self, db):
        """get_next_title_in_lineage with underscores doesn't match wrong sessions."""
        db.create_session("s1", "cli")
        db.set_session_title("s1", "test_project")
        db.create_session("s2", "cli")
        db.set_session_title("s2", "testXproject #2")
        # Only "test_project" exists, so next should be "test_project #2"
        assert db.get_next_title_in_lineage("test_project") == "test_project #2"
|
||||
|
||||
|
||||
class TestListSessionsRich:
    """Tests for enhanced session listing with preview and last_active."""

    def test_preview_from_first_user_message(self, db):
        """Preview is drawn from the first user message, not system/assistant."""
        db.create_session("s1", "cli")
        db.append_message("s1", "system", "You are a helpful assistant.")
        db.append_message("s1", "user", "Help me refactor the auth module please")
        db.append_message("s1", "assistant", "Sure, let me look at it.")
        sessions = db.list_sessions_rich()
        assert len(sessions) == 1
        assert "Help me refactor the auth module" in sessions[0]["preview"]

    def test_preview_truncated_at_60(self, db):
        """Previews longer than 60 chars are cut and suffixed with '...'."""
        db.create_session("s1", "cli")
        long_msg = "A" * 100
        db.append_message("s1", "user", long_msg)
        sessions = db.list_sessions_rich()
        assert len(sessions[0]["preview"]) == 63  # 60 chars + "..."
        assert sessions[0]["preview"].endswith("...")

    def test_preview_empty_when_no_user_messages(self, db):
        """With no user messages the preview is an empty string."""
        db.create_session("s1", "cli")
        db.append_message("s1", "system", "System prompt")
        sessions = db.list_sessions_rich()
        assert sessions[0]["preview"] == ""

    def test_last_active_from_latest_message(self, db):
        """last_active tracks the most recent message timestamp."""
        import time
        db.create_session("s1", "cli")
        db.append_message("s1", "user", "Hello")
        time.sleep(0.01)
        db.append_message("s1", "assistant", "Hi there!")
        sessions = db.list_sessions_rich()
        # last_active should be close to now (the assistant message)
        assert sessions[0]["last_active"] > sessions[0]["started_at"]

    def test_last_active_fallback_to_started_at(self, db):
        """Sessions with no messages fall back to started_at."""
        db.create_session("s1", "cli")
        sessions = db.list_sessions_rich()
        # No messages, so last_active falls back to started_at
        assert sessions[0]["last_active"] == sessions[0]["started_at"]

    def test_rich_list_includes_title(self, db):
        """The rich listing carries each session's title through."""
        db.create_session("s1", "cli")
        db.set_session_title("s1", "refactoring auth")
        sessions = db.list_sessions_rich()
        assert sessions[0]["title"] == "refactoring auth"

    def test_rich_list_source_filter(self, db):
        """source= filters the listing to sessions from that origin."""
        db.create_session("s1", "cli")
        db.create_session("s2", "telegram")
        sessions = db.list_sessions_rich(source="cli")
        assert len(sessions) == 1
        assert sessions[0]["id"] == "s1"

    def test_preview_newlines_collapsed(self, db):
        """Newlines in the source message are collapsed to spaces in the preview."""
        db.create_session("s1", "cli")
        db.append_message("s1", "user", "Line one\nLine two\nLine three")
        sessions = db.list_sessions_rich()
        assert "\n" not in sessions[0]["preview"]
        assert "Line one Line two" in sessions[0]["preview"]
|
||||
|
||||
|
||||
class TestResolveSessionByNameOrId:
    """Tests for the main.py helper that resolves names or IDs."""

    def test_resolve_by_id(self, db):
        """A raw session id is looked up directly via get_session."""
        db.create_session("test-id-123", "cli")
        session = db.get_session("test-id-123")
        assert session is not None
        assert session["id"] == "test-id-123"

    def test_resolve_by_title_falls_back(self, db):
        """When the argument is not an id, resolution falls back to the title."""
        db.create_session("s1", "cli")
        db.set_session_title("s1", "my project")
        result = db.resolve_session_by_title("my project")
        assert result == "s1"
|
||||
|
|
|
|||
488
tests/test_resume_display.py
Normal file
488
tests/test_resume_display.py
Normal file
|
|
@ -0,0 +1,488 @@
|
|||
"""Tests for session resume history display — _display_resumed_history() and
|
||||
_preload_resumed_session().
|
||||
|
||||
Verifies that resuming a session shows a compact recap of the previous
|
||||
conversation with correct formatting, truncation, and config behavior.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from io import StringIO
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
|
||||
|
||||
def _make_cli(config_overrides=None, env_overrides=None, **kwargs):
    """Create a HermesCLI instance with minimal mocking.

    Builds a clean baseline CLI_CONFIG, applies shallow per-section
    ``config_overrides``, blanks out model-related environment variables
    (optionally extended/overridden via ``env_overrides``), and constructs
    HermesCLI with tool definitions stubbed to an empty list. Extra
    ``kwargs`` (e.g. ``resume=``) are forwarded to the HermesCLI constructor.
    """
    import cli as _cli_mod
    from cli import HermesCLI

    _clean_config = {
        "model": {
            "default": "anthropic/claude-opus-4.6",
            "base_url": "https://openrouter.ai/api/v1",
            "provider": "auto",
        },
        "display": {"compact": False, "tool_progress": "all", "resume_display": "full"},
        "agent": {},
        "terminal": {"env_type": "local"},
    }
    if config_overrides:
        for k, v in config_overrides.items():
            # Dict-valued sections are merged shallowly; scalars replace outright.
            if isinstance(v, dict) and k in _clean_config and isinstance(_clean_config[k], dict):
                _clean_config[k].update(v)
            else:
                _clean_config[k] = v

    # Blank these so ambient env vars can't leak into model selection.
    clean_env = {"LLM_MODEL": "", "HERMES_MAX_ITERATIONS": ""}
    if env_overrides:
        clean_env.update(env_overrides)
    with (
        patch("cli.get_tool_definitions", return_value=[]),
        patch.dict("os.environ", clean_env, clear=False),
        # Swap the module-level CLI_CONFIG for the duration of construction.
        patch.dict(_cli_mod.__dict__, {"CLI_CONFIG": _clean_config}),
    ):
        return HermesCLI(**kwargs)
|
||||
|
||||
|
||||
# ── Sample conversation histories for tests ──────────────────────────
|
||||
|
||||
|
||||
def _simple_history():
|
||||
"""Two-turn conversation: user → assistant → user → assistant."""
|
||||
return [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "What is Python?"},
|
||||
{"role": "assistant", "content": "Python is a high-level programming language."},
|
||||
{"role": "user", "content": "How do I install it?"},
|
||||
{"role": "assistant", "content": "You can install Python from python.org."},
|
||||
]
|
||||
|
||||
|
||||
def _tool_call_history():
|
||||
"""Conversation with tool calls and tool results."""
|
||||
return [
|
||||
{"role": "system", "content": "system prompt"},
|
||||
{"role": "user", "content": "Search for Python tutorials"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {"name": "web_search", "arguments": '{"query":"python tutorials"}'},
|
||||
},
|
||||
{
|
||||
"id": "call_2",
|
||||
"type": "function",
|
||||
"function": {"name": "web_extract", "arguments": '{"urls":["https://example.com"]}'},
|
||||
},
|
||||
],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_1", "content": "Found 5 results..."},
|
||||
{"role": "tool", "tool_call_id": "call_2", "content": "Page content..."},
|
||||
{"role": "assistant", "content": "Here are some great Python tutorials I found."},
|
||||
]
|
||||
|
||||
|
||||
def _large_history(n_exchanges=15):
|
||||
"""Build a history with many exchanges to test truncation."""
|
||||
msgs = [{"role": "system", "content": "system prompt"}]
|
||||
for i in range(n_exchanges):
|
||||
msgs.append({"role": "user", "content": f"Question #{i + 1}: What is item {i + 1}?"})
|
||||
msgs.append({"role": "assistant", "content": f"Answer #{i + 1}: Item {i + 1} is great."})
|
||||
return msgs
|
||||
|
||||
|
||||
def _multimodal_history():
|
||||
"""Conversation with multimodal (image) content."""
|
||||
return [
|
||||
{"role": "system", "content": "system prompt"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{"type": "image_url", "image_url": {"url": "https://example.com/cat.jpg"}},
|
||||
],
|
||||
},
|
||||
{"role": "assistant", "content": "I see a cat in the image."},
|
||||
]
|
||||
|
||||
|
||||
# ── Tests for _display_resumed_history ───────────────────────────────
|
||||
|
||||
|
||||
class TestDisplayResumedHistory:
    """_display_resumed_history() renders a Rich panel with conversation recap."""

    def _capture_display(self, cli_obj):
        """Run _display_resumed_history and capture the Rich console output."""
        buf = StringIO()
        # Redirect the Rich console's output stream to an in-memory buffer.
        cli_obj.console.file = buf
        cli_obj._display_resumed_history()
        return buf.getvalue()

    def test_simple_history_shows_user_and_assistant(self):
        """Both user and assistant turns appear with their role labels."""
        cli = _make_cli()
        cli.conversation_history = _simple_history()
        output = self._capture_display(cli)

        assert "You:" in output
        assert "Hermes:" in output
        assert "What is Python?" in output
        assert "Python is a high-level programming language." in output
        assert "How do I install it?" in output

    def test_system_messages_hidden(self):
        """System prompts are never shown in the recap."""
        cli = _make_cli()
        cli.conversation_history = _simple_history()
        output = self._capture_display(cli)

        assert "You are a helpful assistant" not in output

    def test_tool_messages_hidden(self):
        """Raw tool-result messages are omitted from the recap."""
        cli = _make_cli()
        cli.conversation_history = _tool_call_history()
        output = self._capture_display(cli)

        # Tool result content should NOT appear
        assert "Found 5 results" not in output
        assert "Page content" not in output

    def test_tool_calls_shown_as_summary(self):
        """Tool calls collapse to a count plus the tool names."""
        cli = _make_cli()
        cli.conversation_history = _tool_call_history()
        output = self._capture_display(cli)

        assert "2 tool calls" in output
        assert "web_search" in output
        assert "web_extract" in output

    def test_long_user_message_truncated(self):
        """User messages are truncated (to ~300 chars) with an ellipsis."""
        cli = _make_cli()
        long_text = "A" * 500
        cli.conversation_history = [
            {"role": "user", "content": long_text},
            {"role": "assistant", "content": "OK."},
        ]
        output = self._capture_display(cli)

        # Should have truncation indicator and NOT contain the full 500 chars
        assert "..." in output
        assert "A" * 500 not in output
        # The 300-char truncated text is present but may be line-wrapped by
        # Rich's panel renderer, so check the total A count in the output
        a_count = output.count("A")
        assert 200 <= a_count <= 310  # roughly 300 chars (±panel padding)

    def test_long_assistant_message_truncated(self):
        """Assistant messages are likewise truncated with an ellipsis."""
        cli = _make_cli()
        long_text = "B" * 400
        cli.conversation_history = [
            {"role": "user", "content": "Tell me a lot."},
            {"role": "assistant", "content": long_text},
        ]
        output = self._capture_display(cli)

        assert "..." in output
        assert "B" * 400 not in output

    def test_multiline_assistant_truncated(self):
        """Multi-line assistant replies are cut after the first few lines."""
        cli = _make_cli()
        multi = "\n".join([f"Line {i}" for i in range(20)])
        cli.conversation_history = [
            {"role": "user", "content": "Show me lines."},
            {"role": "assistant", "content": multi},
        ]
        output = self._capture_display(cli)

        # First 3 lines should be there
        assert "Line 0" in output
        assert "Line 1" in output
        assert "Line 2" in output
        # Line 19 should NOT be there (truncated after 3 lines)
        assert "Line 19" not in output

    def test_large_history_shows_truncation_indicator(self):
        """Older exchanges are elided behind an 'earlier messages' note."""
        cli = _make_cli()
        cli.conversation_history = _large_history(n_exchanges=15)
        output = self._capture_display(cli)

        # Should show "earlier messages" indicator
        assert "earlier messages" in output
        # Last question should still be visible
        assert "Question #15" in output

    def test_multimodal_content_handled(self):
        """List-form content shows the text part plus an [image] placeholder."""
        cli = _make_cli()
        cli.conversation_history = _multimodal_history()
        output = self._capture_display(cli)

        assert "What's in this image?" in output
        assert "[image]" in output

    def test_empty_history_no_output(self):
        """An empty history renders nothing at all."""
        cli = _make_cli()
        cli.conversation_history = []
        output = self._capture_display(cli)

        assert output.strip() == ""

    def test_minimal_config_suppresses_display(self):
        """resume_display='minimal' disables the recap panel entirely."""
        cli = _make_cli(config_overrides={"display": {"resume_display": "minimal"}})
        # resume_display is captured as an instance variable during __init__
        assert cli.resume_display == "minimal"
        cli.conversation_history = _simple_history()
        output = self._capture_display(cli)

        assert output.strip() == ""

    def test_panel_has_title(self):
        """The recap panel is titled 'Previous Conversation'."""
        cli = _make_cli()
        cli.conversation_history = _simple_history()
        output = self._capture_display(cli)

        assert "Previous Conversation" in output

    def test_assistant_with_no_content_no_tools_skipped(self):
        """Assistant messages with no visible output (e.g. pure reasoning)
        are skipped in the recap."""
        cli = _make_cli()
        cli.conversation_history = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": None},
        ]
        output = self._capture_display(cli)

        # The assistant entry should be skipped, only the user message shown
        assert "You:" in output
        assert "Hermes:" not in output

    def test_only_system_messages_no_output(self):
        """A history of only system messages renders nothing."""
        cli = _make_cli()
        cli.conversation_history = [
            {"role": "system", "content": "You are helpful."},
        ]
        output = self._capture_display(cli)

        assert output.strip() == ""

    def test_reasoning_scratchpad_stripped(self):
        """<REASONING_SCRATCHPAD> blocks should be stripped from display."""
        cli = _make_cli()
        cli.conversation_history = [
            {"role": "user", "content": "Think about this"},
            {
                "role": "assistant",
                "content": (
                    "<REASONING_SCRATCHPAD>\nLet me think step by step.\n"
                    "</REASONING_SCRATCHPAD>\n\nThe answer is 42."
                ),
            },
        ]
        output = self._capture_display(cli)

        assert "REASONING_SCRATCHPAD" not in output
        assert "Let me think step by step" not in output
        assert "The answer is 42" in output

    def test_pure_reasoning_message_skipped(self):
        """Assistant messages that are only reasoning should be skipped."""
        cli = _make_cli()
        cli.conversation_history = [
            {"role": "user", "content": "Hello"},
            {
                "role": "assistant",
                "content": "<REASONING_SCRATCHPAD>\nJust thinking...\n</REASONING_SCRATCHPAD>",
            },
            {"role": "assistant", "content": "Hi there!"},
        ]
        output = self._capture_display(cli)

        assert "Just thinking" not in output
        assert "Hi there!" in output

    def test_assistant_with_text_and_tool_calls(self):
        """When an assistant message has both text content AND tool_calls."""
        cli = _make_cli()
        cli.conversation_history = [
            {"role": "user", "content": "Do something complex"},
            {
                "role": "assistant",
                "content": "Let me search for that.",
                "tool_calls": [
                    {
                        "id": "call_1",
                        "type": "function",
                        "function": {"name": "terminal", "arguments": '{"command":"ls"}'},
                    }
                ],
            },
        ]
        output = self._capture_display(cli)

        assert "Let me search for that." in output
        assert "1 tool call" in output
        assert "terminal" in output
|
||||
|
||||
|
||||
# ── Tests for _preload_resumed_session ──────────────────────────────
|
||||
|
||||
|
||||
class TestPreloadResumedSession:
    """_preload_resumed_session() loads session from DB early."""

    def test_returns_false_when_not_resumed(self):
        """Without a resume target there is nothing to preload."""
        cli = _make_cli()
        assert cli._preload_resumed_session() is False

    def test_returns_false_when_no_session_db(self):
        """A resume id without a session DB cannot be preloaded."""
        cli = _make_cli(resume="test_session_id")
        cli._session_db = None
        assert cli._preload_resumed_session() is False

    def test_returns_false_when_session_not_found(self):
        """An unknown session id prints 'Session not found' and returns False."""
        cli = _make_cli(resume="nonexistent_session")
        mock_db = MagicMock()
        mock_db.get_session.return_value = None
        cli._session_db = mock_db

        buf = StringIO()
        cli.console.file = buf
        result = cli._preload_resumed_session()

        assert result is False
        output = buf.getvalue()
        assert "Session not found" in output

    def test_returns_false_when_session_has_no_messages(self):
        """A session with zero messages is reported and not resumed."""
        cli = _make_cli(resume="empty_session")
        mock_db = MagicMock()
        mock_db.get_session.return_value = {"id": "empty_session", "title": None}
        mock_db.get_messages_as_conversation.return_value = []
        cli._session_db = mock_db

        buf = StringIO()
        cli.console.file = buf
        result = cli._preload_resumed_session()

        assert result is False
        output = buf.getvalue()
        assert "no messages" in output

    def test_loads_session_successfully(self):
        """A found session populates conversation_history and prints a summary."""
        cli = _make_cli(resume="good_session")
        messages = _simple_history()
        mock_db = MagicMock()
        mock_db.get_session.return_value = {"id": "good_session", "title": "Test Session"}
        mock_db.get_messages_as_conversation.return_value = messages
        cli._session_db = mock_db

        buf = StringIO()
        cli.console.file = buf
        result = cli._preload_resumed_session()

        assert result is True
        assert cli.conversation_history == messages
        output = buf.getvalue()
        assert "Resumed session" in output
        assert "good_session" in output
        assert "Test Session" in output
        assert "2 user messages" in output

    def test_reopens_session_in_db(self):
        """Preloading clears ended_at so the session becomes active again."""
        cli = _make_cli(resume="reopen_session")
        messages = [{"role": "user", "content": "hi"}]
        mock_db = MagicMock()
        mock_db.get_session.return_value = {"id": "reopen_session", "title": None}
        mock_db.get_messages_as_conversation.return_value = messages
        mock_conn = MagicMock()
        mock_db._conn = mock_conn
        cli._session_db = mock_db

        buf = StringIO()
        cli.console.file = buf
        cli._preload_resumed_session()

        # Should have executed UPDATE to clear ended_at
        mock_conn.execute.assert_called_once()
        call_args = mock_conn.execute.call_args
        assert "ended_at = NULL" in call_args[0][0]
        mock_conn.commit.assert_called_once()

    def test_singular_user_message_grammar(self):
        """1 user message should say 'message' not 'messages'."""
        cli = _make_cli(resume="one_msg_session")
        messages = [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "hi"},
        ]
        mock_db = MagicMock()
        mock_db.get_session.return_value = {"id": "one_msg_session", "title": None}
        mock_db.get_messages_as_conversation.return_value = messages
        mock_db._conn = MagicMock()
        cli._session_db = mock_db

        buf = StringIO()
        cli.console.file = buf
        cli._preload_resumed_session()

        output = buf.getvalue()
        assert "1 user message," in output
        assert "1 user messages" not in output
|
||||
|
||||
|
||||
# ── Integration: _init_agent skips when preloaded ────────────────────
|
||||
|
||||
|
||||
class TestInitAgentSkipsPreloaded:
    """_init_agent() should skip DB load when history is already populated."""

    def test_init_agent_skips_db_when_preloaded(self):
        """If conversation_history is already set, _init_agent should not
        reload from the DB."""
        cli = _make_cli(resume="preloaded_session")
        cli.conversation_history = _simple_history()

        mock_db = MagicMock()
        cli._session_db = mock_db

        # _init_agent will fail at credential resolution (no real API key),
        # but the session-loading block should be skipped entirely
        with patch.object(cli, "_ensure_runtime_credentials", return_value=False):
            cli._init_agent()

        # get_messages_as_conversation should NOT have been called
        mock_db.get_messages_as_conversation.assert_not_called()
|
||||
|
||||
|
||||
# ── Config default tests ─────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestResumeDisplayConfig:
    """resume_display config option defaults and behavior."""

    def test_default_config_has_resume_display(self):
        """DEFAULT_CONFIG in hermes_cli/config.py includes resume_display."""
        from hermes_cli.config import DEFAULT_CONFIG
        display = DEFAULT_CONFIG.get("display", {})
        assert "resume_display" in display
        assert display["resume_display"] == "full"

    def test_cli_defaults_have_resume_display(self):
        """cli.py load_cli_config defaults include resume_display."""
        import cli as _cli_mod
        from cli import load_cli_config

        with (
            # No config file on disk → pure built-in defaults.
            patch("pathlib.Path.exists", return_value=False),
            patch.dict("os.environ", {"LLM_MODEL": ""}, clear=False),
        ):
            config = load_cli_config()

        display = config.get("display", {})
        assert display.get("resume_display") == "full"
|
||||
|
|
@ -1040,3 +1040,136 @@ class TestMaxTokensParam:
|
|||
agent.base_url = "https://openrouter.ai/api/v1/api.openai.com"
|
||||
result = agent._max_tokens_param(4096)
|
||||
assert result == {"max_tokens": 4096}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# System prompt stability for prompt caching
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSystemPromptStability:
    """Verify that the system prompt stays stable across turns for cache hits."""

    def test_stored_prompt_reused_for_continuing_session(self, agent):
        """When conversation_history is non-empty and session DB has a stored
        prompt, it should be reused instead of rebuilding from disk."""
        stored = "You are helpful. [stored from turn 1]"
        mock_db = MagicMock()
        mock_db.get_session.return_value = {"system_prompt": stored}
        agent._session_db = mock_db

        # Simulate a continuing session with history
        history = [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "hi"},
        ]

        # First call — _cached_system_prompt is None, history is non-empty
        agent._cached_system_prompt = None

        # Patch run_conversation internals to just test the system prompt logic.
        # We'll call the prompt caching block directly by simulating what
        # run_conversation does.
        conversation_history = history

        # The block under test (from run_conversation):
        if agent._cached_system_prompt is None:
            stored_prompt = None
            if conversation_history and agent._session_db:
                try:
                    session_row = agent._session_db.get_session(agent.session_id)
                    if session_row:
                        stored_prompt = session_row.get("system_prompt") or None
                except Exception:
                    # best-effort lookup — mirrors run_conversation's tolerance
                    pass

            if stored_prompt:
                agent._cached_system_prompt = stored_prompt

        assert agent._cached_system_prompt == stored
        mock_db.get_session.assert_called_once_with(agent.session_id)

    def test_fresh_build_when_no_history(self, agent):
        """On the first turn (no history), system prompt should be built fresh."""
        mock_db = MagicMock()
        agent._session_db = mock_db

        agent._cached_system_prompt = None
        conversation_history = []

        # The block under test:
        if agent._cached_system_prompt is None:
            stored_prompt = None
            if conversation_history and agent._session_db:
                session_row = agent._session_db.get_session(agent.session_id)
                if session_row:
                    stored_prompt = session_row.get("system_prompt") or None

            if stored_prompt:
                agent._cached_system_prompt = stored_prompt
            else:
                agent._cached_system_prompt = agent._build_system_prompt()

        # Should have built fresh, not queried the DB
        mock_db.get_session.assert_not_called()
        assert agent._cached_system_prompt is not None
        assert "Hermes Agent" in agent._cached_system_prompt

    def test_fresh_build_when_db_has_no_prompt(self, agent):
        """If the session DB has no stored prompt, build fresh even with history."""
        mock_db = MagicMock()
        mock_db.get_session.return_value = {"system_prompt": ""}
        agent._session_db = mock_db

        agent._cached_system_prompt = None
        conversation_history = [{"role": "user", "content": "hi"}]

        if agent._cached_system_prompt is None:
            stored_prompt = None
            if conversation_history and agent._session_db:
                try:
                    session_row = agent._session_db.get_session(agent.session_id)
                    if session_row:
                        stored_prompt = session_row.get("system_prompt") or None
                except Exception:
                    pass

            if stored_prompt:
                agent._cached_system_prompt = stored_prompt
            else:
                agent._cached_system_prompt = agent._build_system_prompt()

        # Empty string is falsy, so should fall through to fresh build
        assert "Hermes Agent" in agent._cached_system_prompt

    def test_honcho_context_baked_into_prompt_on_first_turn(self, agent):
        """Honcho context should be baked into _cached_system_prompt on
        the first turn, not injected separately per API call."""
        agent._honcho_context = "User prefers Python over JavaScript."
        agent._cached_system_prompt = None

        # Simulate first turn: build fresh and bake in Honcho
        agent._cached_system_prompt = agent._build_system_prompt()
        if agent._honcho_context:
            agent._cached_system_prompt = (
                agent._cached_system_prompt + "\n\n" + agent._honcho_context
            ).strip()

        assert "User prefers Python over JavaScript" in agent._cached_system_prompt

    def test_honcho_prefetch_skipped_on_continuing_session(self):
        """Honcho prefetch should not be called when conversation_history
        is non-empty (continuing session)."""
        conversation_history = [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "hi there"},
        ]

        # The guard: `not conversation_history` is False when history exists
        should_prefetch = not conversation_history
        assert should_prefetch is False

    def test_honcho_prefetch_runs_on_first_turn(self):
        """Honcho prefetch should run when conversation_history is empty."""
        conversation_history = []
        should_prefetch = not conversation_history
        assert should_prefetch is True
|
||||
|
|
|
|||
276
tests/tools/test_browser_console.py
Normal file
276
tests/tools/test_browser_console.py
Normal file
|
|
@ -0,0 +1,276 @@
|
|||
"""Tests for browser_console tool and browser_vision annotate param."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
|
||||
|
||||
# ── browser_console ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestBrowserConsole:
    """browser_console() returns console messages + JS errors in one call."""

    def test_returns_console_messages_and_errors(self):
        """Merged payload carries messages, errors, and both counts."""
        from tools.browser_tool import browser_console

        console_response = {
            "success": True,
            "data": {
                "messages": [
                    {"text": "hello", "type": "log", "timestamp": 1},
                    {"text": "oops", "type": "error", "timestamp": 2},
                ]
            },
        }
        errors_response = {
            "success": True,
            "data": {
                "errors": [
                    {"message": "Uncaught TypeError", "timestamp": 3},
                ]
            },
        }

        with patch("tools.browser_tool._run_browser_command") as mock_cmd:
            # First call serves the console payload, second the errors payload
            # — order matters, matching browser_console's internal sequence.
            mock_cmd.side_effect = [console_response, errors_response]
            result = json.loads(browser_console(task_id="test"))

        assert result["success"] is True
        assert result["total_messages"] == 2
        assert result["total_errors"] == 1
        assert result["console_messages"][0]["text"] == "hello"
        assert result["console_messages"][1]["text"] == "oops"
        assert result["js_errors"][0]["message"] == "Uncaught TypeError"

    def test_passes_clear_flag(self):
        """clear=True forwards --clear to both underlying commands."""
        from tools.browser_tool import browser_console

        empty = {"success": True, "data": {"messages": [], "errors": []}}
        with patch("tools.browser_tool._run_browser_command", return_value=empty) as mock_cmd:
            browser_console(clear=True, task_id="test")

        calls = mock_cmd.call_args_list
        # Both console and errors should get --clear
        assert calls[0][0] == ("test", "console", ["--clear"])
        assert calls[1][0] == ("test", "errors", ["--clear"])

    def test_no_clear_by_default(self):
        """Without clear=True, neither command receives extra flags."""
        from tools.browser_tool import browser_console

        empty = {"success": True, "data": {"messages": [], "errors": []}}
        with patch("tools.browser_tool._run_browser_command", return_value=empty) as mock_cmd:
            browser_console(task_id="test")

        calls = mock_cmd.call_args_list
        assert calls[0][0] == ("test", "console", [])
        assert calls[1][0] == ("test", "errors", [])

    def test_empty_console_and_errors(self):
        """Empty console/error buffers yield zero counts and empty lists."""
        from tools.browser_tool import browser_console

        empty = {"success": True, "data": {"messages": [], "errors": []}}
        with patch("tools.browser_tool._run_browser_command", return_value=empty):
            result = json.loads(browser_console(task_id="test"))

        assert result["total_messages"] == 0
        assert result["total_errors"] == 0
        assert result["console_messages"] == []
        assert result["js_errors"] == []

    def test_handles_failed_commands(self):
        """Underlying command failures degrade to empty data, not an error."""
        from tools.browser_tool import browser_console

        failed = {"success": False, "error": "No session"}
        with patch("tools.browser_tool._run_browser_command", return_value=failed):
            result = json.loads(browser_console(task_id="test"))

        # Should still return success with empty data
        assert result["success"] is True
        assert result["total_messages"] == 0
        assert result["total_errors"] == 0
|
||||
|
||||
|
||||
# ── browser_console schema ───────────────────────────────────────────
|
||||
|
||||
|
||||
class TestBrowserConsoleSchema:
    """browser_console is properly registered in the tool registry."""

    def test_schema_in_browser_schemas(self):
        """The registry lists a schema entry named browser_console."""
        from tools.browser_tool import BROWSER_TOOL_SCHEMAS

        registered = [entry["name"] for entry in BROWSER_TOOL_SCHEMAS]
        assert "browser_console" in registered

    def test_schema_has_clear_param(self):
        """The browser_console schema declares a boolean `clear` parameter."""
        from tools.browser_tool import BROWSER_TOOL_SCHEMAS

        console_schema = next(
            entry for entry in BROWSER_TOOL_SCHEMAS
            if entry["name"] == "browser_console"
        )
        params = console_schema["parameters"]["properties"]
        assert "clear" in params
        assert params["clear"]["type"] == "boolean"
|
||||
|
||||
|
||||
# ── browser_vision annotate ──────────────────────────────────────────
|
||||
|
||||
|
||||
class TestBrowserVisionAnnotate:
    """browser_vision supports annotate parameter."""

    def test_schema_has_annotate_param(self):
        """The browser_vision schema declares a boolean `annotate` parameter."""
        from tools.browser_tool import BROWSER_TOOL_SCHEMAS

        schema = next(s for s in BROWSER_TOOL_SCHEMAS if s["name"] == "browser_vision")
        props = schema["parameters"]["properties"]
        assert "annotate" in props
        assert props["annotate"]["type"] == "boolean"

    def test_annotate_false_no_flag(self):
        """Without annotate, screenshot command has no --annotate flag."""
        from tools.browser_tool import browser_vision

        with (
            patch("tools.browser_tool._run_browser_command") as mock_cmd,
            # Patched only to neutralize the vision-client path; unused here.
            patch("tools.browser_tool._aux_vision_client") as mock_client,
            patch("tools.browser_tool._DEFAULT_VISION_MODEL", "test-model"),
            patch("tools.browser_tool._get_vision_model", return_value="test-model"),
        ):
            mock_cmd.return_value = {"success": True, "data": {}}
            # Will fail at screenshot file read, but we can check the command
            try:
                browser_vision("test", annotate=False, task_id="test")
            except Exception:
                pass

        # Guarded: the call may abort before reaching the screenshot command.
        if mock_cmd.called:
            args = mock_cmd.call_args[0]
            cmd_args = args[2] if len(args) > 2 else []
            assert "--annotate" not in cmd_args

    def test_annotate_true_adds_flag(self):
        """With annotate=True, screenshot command includes --annotate."""
        from tools.browser_tool import browser_vision

        with (
            patch("tools.browser_tool._run_browser_command") as mock_cmd,
            # Patched only to neutralize the vision-client path; unused here.
            patch("tools.browser_tool._aux_vision_client") as mock_client,
            patch("tools.browser_tool._DEFAULT_VISION_MODEL", "test-model"),
            patch("tools.browser_tool._get_vision_model", return_value="test-model"),
        ):
            mock_cmd.return_value = {"success": True, "data": {}}
            # Expected to fail past the screenshot step; flag check still valid.
            try:
                browser_vision("test", annotate=True, task_id="test")
            except Exception:
                pass

        # Guarded: the call may abort before reaching the screenshot command.
        if mock_cmd.called:
            args = mock_cmd.call_args[0]
            cmd_args = args[2] if len(args) > 2 else []
            assert "--annotate" in cmd_args
|
||||
|
||||
|
||||
# ── auto-recording config ────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestRecordSessionsConfig:
    """browser.record_sessions config option."""

    def test_default_config_has_record_sessions(self):
        """Default config ships record_sessions and it defaults to False."""
        from hermes_cli.config import DEFAULT_CONFIG

        browser_cfg = DEFAULT_CONFIG.get("browser", {})
        assert "record_sessions" in browser_cfg
        assert browser_cfg["record_sessions"] is False

    def test_maybe_start_recording_disabled(self):
        """Recording doesn't start when config says record_sessions: false."""
        from tools.browser_tool import _maybe_start_recording, _recording_sessions

        with (
            patch("tools.browser_tool._run_browser_command") as mock_cmd,
            # Missing config file means the default (disabled) applies.
            patch("builtins.open", side_effect=FileNotFoundError),
        ):
            _maybe_start_recording("test-task")

        # No browser command issued and no session registered as recording.
        mock_cmd.assert_not_called()
        assert "test-task" not in _recording_sessions

    def test_maybe_stop_recording_noop_when_not_recording(self):
        """Stopping when not recording is a no-op."""
        from tools.browser_tool import _maybe_stop_recording, _recording_sessions

        _recording_sessions.discard("test-task")  # ensure not in set
        with patch("tools.browser_tool._run_browser_command") as mock_cmd:
            _maybe_stop_recording("test-task")

        mock_cmd.assert_not_called()
|
||||
|
||||
|
||||
# ── dogfood skill files ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestDogfoodSkill:
    """Dogfood skill files exist and have correct structure."""

    @pytest.fixture(autouse=True)
    def _skill_dir(self):
        # Resolve the real repo skills dir: tests/tools/../../skills/dogfood.
        self.skill_dir = os.path.join(
            os.path.dirname(__file__), "..", "..", "skills", "dogfood"
        )

    def _read(self, *parts):
        """Return the text of a file located under the skill directory."""
        with open(os.path.join(self.skill_dir, *parts)) as fh:
            return fh.read()

    def test_skill_md_exists(self):
        assert os.path.exists(os.path.join(self.skill_dir, "SKILL.md"))

    def test_taxonomy_exists(self):
        taxonomy = os.path.join(self.skill_dir, "references", "issue-taxonomy.md")
        assert os.path.exists(taxonomy)

    def test_report_template_exists(self):
        template = os.path.join(
            self.skill_dir, "templates", "dogfood-report-template.md"
        )
        assert os.path.exists(template)

    def test_skill_md_has_frontmatter(self):
        text = self._read("SKILL.md")
        assert text.startswith("---")
        assert "name: dogfood" in text
        assert "description:" in text

    def test_skill_references_browser_console(self):
        assert "browser_console" in self._read("SKILL.md")

    def test_skill_references_annotate(self):
        assert "annotate" in self._read("SKILL.md")

    def test_taxonomy_has_severity_levels(self):
        text = self._read("references", "issue-taxonomy.md")
        for severity in ("Critical", "High", "Medium", "Low"):
            assert severity in text

    def test_taxonomy_has_categories(self):
        text = self._read("references", "issue-taxonomy.md")
        for category in ("Functional", "Visual", "Accessibility", "Console"):
            assert category in text
|
||||
|
|
@ -550,14 +550,13 @@ class TestConvertToPng:
|
|||
"""BMP file should still be reported as success if no converter available."""
|
||||
dest = tmp_path / "img.png"
|
||||
dest.write_bytes(FAKE_BMP) # it's a BMP but named .png
|
||||
# Both Pillow and ImageMagick fail
|
||||
with patch("hermes_cli.clipboard.subprocess.run", side_effect=FileNotFoundError):
|
||||
# Pillow import fails
|
||||
with pytest.raises(Exception):
|
||||
from PIL import Image # noqa — this may or may not work
|
||||
# The function should still return True if file exists and has content
|
||||
# (raw BMP is better than nothing)
|
||||
assert dest.exists() and dest.stat().st_size > 0
|
||||
# Both Pillow and ImageMagick unavailable
|
||||
with patch.dict(sys.modules, {"PIL": None, "PIL.Image": None}):
|
||||
with patch("hermes_cli.clipboard.subprocess.run", side_effect=FileNotFoundError):
|
||||
result = _convert_to_png(dest)
|
||||
# Raw BMP is better than nothing — function should return True
|
||||
assert result is True
|
||||
assert dest.exists() and dest.stat().st_size > 0
|
||||
|
||||
|
||||
# ── has_clipboard_image dispatch ─────────────────────────────────────────
|
||||
|
|
|
|||
|
|
@ -259,6 +259,70 @@ class TestShellFileOpsHelpers:
|
|||
assert ops.cwd == "/"
|
||||
|
||||
|
||||
class TestSearchPathValidation:
    """Test that search() returns an error for non-existent paths.

    Fix over the original: the same mock ``side_effect`` closure was
    copy-pasted into all four tests, and the last test carried a dead
    ``call_count`` counter that was never asserted. The closure is now
    built once by :meth:`_fake_shell` and the dead counter is removed.
    """

    @staticmethod
    def _fake_shell(path_exists, rg_returncode=0):
        """Build a ``mock_env.execute`` side_effect simulating shell probes.

        ``test -e`` answers whether the search path exists, ``command -v``
        reports the search binary as available, and any other command
        (the rg invocation itself) returns empty output with
        *rg_returncode*.
        """
        def side_effect(command, **kwargs):
            if "test -e" in command:
                if path_exists:
                    return {"output": "exists", "returncode": 0}
                return {"output": "not_found", "returncode": 1}
            if "command -v" in command:
                return {"output": "yes", "returncode": 0}
            return {"output": "", "returncode": rg_returncode}

        return side_effect

    def test_search_nonexistent_path_returns_error(self, mock_env):
        """search() should return an error when the path doesn't exist."""
        mock_env.execute.side_effect = self._fake_shell(path_exists=False)
        ops = ShellFileOperations(mock_env)
        result = ops.search("pattern", path="/nonexistent/path")
        assert result.error is not None
        assert "not found" in result.error.lower() or "Path not found" in result.error

    def test_search_nonexistent_path_files_mode(self, mock_env):
        """search(target='files') should also return error for bad paths."""
        mock_env.execute.side_effect = self._fake_shell(path_exists=False)
        ops = ShellFileOperations(mock_env)
        result = ops.search("*.py", path="/nonexistent/path", target="files")
        assert result.error is not None
        assert "not found" in result.error.lower() or "Path not found" in result.error

    def test_search_existing_path_proceeds(self, mock_env):
        """search() should proceed normally when the path exists."""
        # rg exit code 1 means "no matches", which is not an error.
        mock_env.execute.side_effect = self._fake_shell(
            path_exists=True, rg_returncode=1
        )
        ops = ShellFileOperations(mock_env)
        result = ops.search("pattern", path="/existing/path")
        assert result.error is None
        assert result.total_count == 0  # No matches but no error

    def test_search_rg_error_exit_code(self, mock_env):
        """search() should report error when rg returns exit code 2."""
        # rg exit code 2 signals a genuine search failure.
        mock_env.execute.side_effect = self._fake_shell(
            path_exists=True, rg_returncode=2
        )
        ops = ShellFileOperations(mock_env)
        result = ops.search("pattern", path="/some/path")
        assert result.error is not None
        assert "search failed" in result.error.lower() or "Search error" in result.error
|
||||
|
||||
|
||||
class TestShellFileOpsWriteDenied:
|
||||
def test_write_file_denied_path(self, file_ops):
|
||||
result = file_ops.write_file("~/.ssh/authorized_keys", "evil key")
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ class TestReadFileHandler:
|
|||
def test_returns_file_content(self, mock_get):
|
||||
mock_ops = MagicMock()
|
||||
result_obj = MagicMock()
|
||||
result_obj.content = "line1\nline2"
|
||||
result_obj.to_dict.return_value = {"content": "line1\nline2", "total_lines": 2}
|
||||
mock_ops.read_file.return_value = result_obj
|
||||
mock_get.return_value = mock_ops
|
||||
|
|
@ -52,6 +53,7 @@ class TestReadFileHandler:
|
|||
def test_custom_offset_and_limit(self, mock_get):
|
||||
mock_ops = MagicMock()
|
||||
result_obj = MagicMock()
|
||||
result_obj.content = "line10"
|
||||
result_obj.to_dict.return_value = {"content": "line10", "total_lines": 50}
|
||||
mock_ops.read_file.return_value = result_obj
|
||||
mock_get.return_value = mock_ops
|
||||
|
|
@ -200,3 +202,96 @@ class TestSearchHandler:
|
|||
from tools.file_tools import search_tool
|
||||
result = json.loads(search_tool(pattern="x"))
|
||||
assert "error" in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool result hint tests (#722)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPatchHints:
    """Patch tool should hint when old_string is not found."""

    @staticmethod
    def _ops_returning(payload):
        """Return a mock file-ops whose patch_replace yields *payload*."""
        ops = MagicMock()
        outcome = MagicMock()
        outcome.to_dict.return_value = payload
        ops.patch_replace.return_value = outcome
        return ops

    @patch("tools.file_tools._get_file_ops")
    def test_no_match_includes_hint(self, mock_get):
        """A no-match error carries a hint suggesting read_file."""
        mock_get.return_value = self._ops_returning(
            {"error": "Could not find match for old_string in foo.py"}
        )

        from tools.file_tools import patch_tool

        raw = patch_tool(mode="replace", path="foo.py", old_string="x", new_string="y")
        assert "[Hint:" in raw
        assert "read_file" in raw

    @patch("tools.file_tools._get_file_ops")
    def test_success_no_hint(self, mock_get):
        """A successful patch result carries no hint text."""
        mock_get.return_value = self._ops_returning(
            {"success": True, "diff": "--- a\n+++ b"}
        )

        from tools.file_tools import patch_tool

        raw = patch_tool(mode="replace", path="foo.py", old_string="x", new_string="y")
        assert "[Hint:" not in raw
|
||||
|
||||
|
||||
class TestSearchHints:
    """Search tool should hint when results are truncated."""

    def setup_method(self):
        """Clear read/search tracker between tests to avoid cross-test state."""
        from tools.file_tools import clear_read_tracker
        clear_read_tracker()

    @patch("tools.file_tools._get_file_ops")
    def test_truncated_results_hint(self, mock_get):
        """A truncated result set yields a hint with the next offset."""
        mock_ops = MagicMock()
        result_obj = MagicMock()
        # 100 matches total, 50 returned -> truncated payload.
        result_obj.to_dict.return_value = {
            "total_count": 100,
            "matches": [{"path": "a.py", "line": 1, "content": "x"}] * 50,
            "truncated": True,
        }
        mock_ops.search.return_value = result_obj
        mock_get.return_value = mock_ops

        from tools.file_tools import search_tool
        raw = search_tool(pattern="foo", offset=0, limit=50)
        assert "[Hint:" in raw
        # Next page starts where this one ended: offset 0 + limit 50.
        assert "offset=50" in raw

    @patch("tools.file_tools._get_file_ops")
    def test_non_truncated_no_hint(self, mock_get):
        """A complete (non-truncated) result set produces no hint."""
        mock_ops = MagicMock()
        result_obj = MagicMock()
        result_obj.to_dict.return_value = {
            "total_count": 3,
            "matches": [{"path": "a.py", "line": 1, "content": "x"}] * 3,
        }
        mock_ops.search.return_value = result_obj
        mock_get.return_value = mock_ops

        from tools.file_tools import search_tool
        raw = search_tool(pattern="foo")
        assert "[Hint:" not in raw

    @patch("tools.file_tools._get_file_ops")
    def test_truncated_hint_with_nonzero_offset(self, mock_get):
        """The hint's offset accounts for the current non-zero offset."""
        mock_ops = MagicMock()
        result_obj = MagicMock()
        result_obj.to_dict.return_value = {
            "total_count": 150,
            "matches": [{"path": "a.py", "line": 1, "content": "x"}] * 50,
            "truncated": True,
        }
        mock_ops.search.return_value = result_obj
        mock_get.return_value = mock_ops

        from tools.file_tools import search_tool
        raw = search_tool(pattern="foo", offset=50, limit=50)
        assert "[Hint:" in raw
        # Next page: offset 50 + limit 50.
        assert "offset=100" in raw
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue