simplify docstrings, fix some bugs

This commit is contained in:
balyan.sid@gmail.com 2026-03-15 01:12:16 +05:30
parent 861202b56c
commit 9001b34146
6 changed files with 37 additions and 196 deletions

View file

@ -1,10 +1,4 @@
"""Tests for the local persistent shell backend. """Tests for the local persistent shell backend."""
Unit tests cover config plumbing (no real shell needed).
Integration tests run real commands no external dependencies required.
pytest tests/tools/test_local_persistent.py -v
"""
import glob as glob_mod import glob as glob_mod
@ -14,10 +8,6 @@ from tools.environments.local import LocalEnvironment
from tools.environments.persistent_shell import PersistentShellMixin from tools.environments.persistent_shell import PersistentShellMixin
# ---------------------------------------------------------------------------
# Unit tests — config plumbing
# ---------------------------------------------------------------------------
class TestLocalConfig: class TestLocalConfig:
def test_local_persistent_default_false(self, monkeypatch): def test_local_persistent_default_false(self, monkeypatch):
monkeypatch.delenv("TERMINAL_LOCAL_PERSISTENT", raising=False) monkeypatch.delenv("TERMINAL_LOCAL_PERSISTENT", raising=False)
@ -36,8 +26,6 @@ class TestLocalConfig:
class TestMergeOutput: class TestMergeOutput:
"""Test the shared _merge_output static method."""
def test_stdout_only(self): def test_stdout_only(self):
assert PersistentShellMixin._merge_output("out", "") == "out" assert PersistentShellMixin._merge_output("out", "") == "out"
@ -54,13 +42,7 @@ class TestMergeOutput:
assert PersistentShellMixin._merge_output("out\n\n", "err\n") == "out\nerr" assert PersistentShellMixin._merge_output("out\n\n", "err\n") == "out\nerr"
# ---------------------------------------------------------------------------
# One-shot regression tests — ensure refactor didn't break anything
# ---------------------------------------------------------------------------
class TestLocalOneShotRegression: class TestLocalOneShotRegression:
"""Verify one-shot mode still works after adding the mixin."""
def test_echo(self): def test_echo(self):
env = LocalEnvironment(persistent=False) env = LocalEnvironment(persistent=False)
r = env.execute("echo hello") r = env.execute("echo hello")
@ -75,22 +57,14 @@ class TestLocalOneShotRegression:
env.cleanup() env.cleanup()
def test_state_does_not_persist(self): def test_state_does_not_persist(self):
"""Env vars set in one command should NOT survive in one-shot mode."""
env = LocalEnvironment(persistent=False) env = LocalEnvironment(persistent=False)
env.execute("export HERMES_ONESHOT_LOCAL=yes") env.execute("export HERMES_ONESHOT_LOCAL=yes")
r = env.execute("echo $HERMES_ONESHOT_LOCAL") r = env.execute("echo $HERMES_ONESHOT_LOCAL")
# In one-shot mode, env var should not persist
assert r["output"].strip() == "" assert r["output"].strip() == ""
env.cleanup() env.cleanup()
# ---------------------------------------------------------------------------
# Persistent shell integration tests
# ---------------------------------------------------------------------------
class TestLocalPersistent: class TestLocalPersistent:
"""Persistent mode: state persists across execute() calls."""
@pytest.fixture @pytest.fixture
def env(self): def env(self):
e = LocalEnvironment(persistent=True) e = LocalEnvironment(persistent=True)
@ -128,8 +102,7 @@ class TestLocalPersistent:
def test_timeout_then_recovery(self, env): def test_timeout_then_recovery(self, env):
r = env.execute("sleep 999", timeout=2) r = env.execute("sleep 999", timeout=2)
assert r["returncode"] in (124, 130) # timeout or interrupted assert r["returncode"] in (124, 130)
# Shell should survive — next command works
r = env.execute("echo alive") r = env.execute("echo alive")
assert r["returncode"] == 0 assert r["returncode"] == 0
assert "alive" in r["output"] assert "alive" in r["output"]
@ -143,7 +116,6 @@ class TestLocalPersistent:
assert lines[-1] == "1000" assert lines[-1] == "1000"
def test_shell_variable_persists(self, env): def test_shell_variable_persists(self, env):
"""Shell variables (not exported) should also persist."""
env.execute("MY_LOCAL_VAR=hello123") env.execute("MY_LOCAL_VAR=hello123")
r = env.execute("echo $MY_LOCAL_VAR") r = env.execute("echo $MY_LOCAL_VAR")
assert r["output"].strip() == "hello123" assert r["output"].strip() == "hello123"
@ -151,14 +123,12 @@ class TestLocalPersistent:
def test_cleanup_removes_temp_files(self, env): def test_cleanup_removes_temp_files(self, env):
env.execute("echo warmup") env.execute("echo warmup")
prefix = env._temp_prefix prefix = env._temp_prefix
# Temp files should exist
assert len(glob_mod.glob(f"{prefix}-*")) > 0 assert len(glob_mod.glob(f"{prefix}-*")) > 0
env.cleanup() env.cleanup()
remaining = glob_mod.glob(f"{prefix}-*") remaining = glob_mod.glob(f"{prefix}-*")
assert remaining == [] assert remaining == []
def test_state_does_not_leak_between_instances(self): def test_state_does_not_leak_between_instances(self):
"""Two separate persistent instances don't share state."""
env1 = LocalEnvironment(persistent=True) env1 = LocalEnvironment(persistent=True)
env2 = LocalEnvironment(persistent=True) env2 = LocalEnvironment(persistent=True)
try: try:
@ -170,7 +140,6 @@ class TestLocalPersistent:
env2.cleanup() env2.cleanup()
def test_special_characters_in_command(self, env): def test_special_characters_in_command(self, env):
"""Commands with quotes and special chars should work."""
r = env.execute("echo 'hello world'") r = env.execute("echo 'hello world'")
assert r["output"].strip() == "hello world" assert r["output"].strip() == "hello world"

View file

@ -1,15 +1,4 @@
"""Tests for the SSH remote execution environment backend. """Tests for the SSH remote execution environment backend."""
Unit tests (no SSH required) cover pure logic: command building, output merging,
config plumbing.
Integration tests require a real SSH target. Set TERMINAL_SSH_HOST and
TERMINAL_SSH_USER to enable them. In CI, start an sshd container or enable
the localhost SSH service.
TERMINAL_SSH_HOST=localhost TERMINAL_SSH_USER=$(whoami) \
pytest tests/tools/test_ssh_environment.py -v
"""
import json import json
import os import os
@ -20,11 +9,6 @@ import pytest
from tools.environments.ssh import SSHEnvironment from tools.environments.ssh import SSHEnvironment
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_SSH_HOST = os.getenv("TERMINAL_SSH_HOST", "") _SSH_HOST = os.getenv("TERMINAL_SSH_HOST", "")
_SSH_USER = os.getenv("TERMINAL_SSH_USER", "") _SSH_USER = os.getenv("TERMINAL_SSH_USER", "")
_SSH_PORT = int(os.getenv("TERMINAL_SSH_PORT", "22")) _SSH_PORT = int(os.getenv("TERMINAL_SSH_PORT", "22"))
@ -39,7 +23,6 @@ requires_ssh = pytest.mark.skipif(
def _run(command, task_id="ssh_test", **kwargs): def _run(command, task_id="ssh_test", **kwargs):
"""Call terminal_tool like an LLM would, return parsed JSON."""
from tools.terminal_tool import terminal_tool from tools.terminal_tool import terminal_tool
return json.loads(terminal_tool(command, task_id=task_id, **kwargs)) return json.loads(terminal_tool(command, task_id=task_id, **kwargs))
@ -49,12 +32,7 @@ def _cleanup(task_id="ssh_test"):
cleanup_vm(task_id) cleanup_vm(task_id)
# ---------------------------------------------------------------------------
# Unit tests — no SSH connection needed
# ---------------------------------------------------------------------------
class TestBuildSSHCommand: class TestBuildSSHCommand:
"""Pure logic: verify the ssh command list is assembled correctly."""
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def _mock_connection(self, monkeypatch): def _mock_connection(self, monkeypatch):
@ -100,12 +78,7 @@ class TestTerminalToolConfig:
assert _get_env_config()["ssh_persistent"] is True assert _get_env_config()["ssh_persistent"] is True
# ---------------------------------------------------------------------------
# Integration tests — real SSH, through terminal_tool() interface
# ---------------------------------------------------------------------------
def _setup_ssh_env(monkeypatch, persistent: bool): def _setup_ssh_env(monkeypatch, persistent: bool):
"""Configure env vars for SSH integration tests."""
monkeypatch.setenv("TERMINAL_ENV", "ssh") monkeypatch.setenv("TERMINAL_ENV", "ssh")
monkeypatch.setenv("TERMINAL_SSH_HOST", _SSH_HOST) monkeypatch.setenv("TERMINAL_SSH_HOST", _SSH_HOST)
monkeypatch.setenv("TERMINAL_SSH_USER", _SSH_USER) monkeypatch.setenv("TERMINAL_SSH_USER", _SSH_USER)
@ -118,7 +91,6 @@ def _setup_ssh_env(monkeypatch, persistent: bool):
@requires_ssh @requires_ssh
class TestOneShotSSH: class TestOneShotSSH:
"""One-shot mode: each command is a fresh ssh invocation."""
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def _setup(self, monkeypatch): def _setup(self, monkeypatch):
@ -136,7 +108,6 @@ class TestOneShotSSH:
assert r["exit_code"] == 42 assert r["exit_code"] == 42
def test_state_does_not_persist(self): def test_state_does_not_persist(self):
"""Env vars set in one command should NOT survive to the next."""
_run("export HERMES_ONESHOT_TEST=yes") _run("export HERMES_ONESHOT_TEST=yes")
r = _run("echo $HERMES_ONESHOT_TEST") r = _run("echo $HERMES_ONESHOT_TEST")
assert r["output"].strip() == "" assert r["output"].strip() == ""
@ -144,7 +115,6 @@ class TestOneShotSSH:
@requires_ssh @requires_ssh
class TestPersistentSSH: class TestPersistentSSH:
"""Persistent mode: single long-lived shell, state persists."""
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def _setup(self, monkeypatch): def _setup(self, monkeypatch):
@ -184,7 +154,6 @@ class TestPersistentSSH:
def test_timeout_then_recovery(self): def test_timeout_then_recovery(self):
r = _run("sleep 999", timeout=2) r = _run("sleep 999", timeout=2)
assert r["exit_code"] == 124 assert r["exit_code"] == 124
# Shell should survive — next command works
r = _run("echo alive") r = _run("echo alive")
assert r["exit_code"] == 0 assert r["exit_code"] == 0
assert "alive" in r["output"] assert "alive" in r["output"]

View file

@ -1,5 +1,6 @@
"""Local execution environment with interrupt support and non-blocking I/O.""" """Local execution environment with interrupt support and non-blocking I/O."""
import glob
import os import os
import platform import platform
import shutil import shutil
@ -226,10 +227,6 @@ class LocalEnvironment(PersistentShellMixin, BaseEnvironment):
if self.persistent: if self.persistent:
self._init_persistent_shell() self._init_persistent_shell()
# ------------------------------------------------------------------
# PersistentShellMixin: backend-specific implementations
# ------------------------------------------------------------------
@property @property
def _temp_prefix(self) -> str: def _temp_prefix(self) -> str:
return f"/tmp/hermes-local-{self._session_id}" return f"/tmp/hermes-local-{self._session_id}"
@ -241,14 +238,13 @@ class LocalEnvironment(PersistentShellMixin, BaseEnvironment):
[user_shell, "-l"], [user_shell, "-l"],
stdin=subprocess.PIPE, stdin=subprocess.PIPE,
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, stderr=subprocess.DEVNULL,
text=True, text=True,
env=run_env, env=run_env,
preexec_fn=None if _IS_WINDOWS else os.setsid, preexec_fn=None if _IS_WINDOWS else os.setsid,
) )
def _read_temp_files(self, *paths: str) -> list[str]: def _read_temp_files(self, *paths: str) -> list[str]:
"""Read local files directly."""
results = [] results = []
for path in paths: for path in paths:
try: try:
@ -259,7 +255,6 @@ class LocalEnvironment(PersistentShellMixin, BaseEnvironment):
return results return results
def _kill_shell_children(self): def _kill_shell_children(self):
"""Kill children of the persistent shell via pkill -P."""
if self._shell_pid is None: if self._shell_pid is None:
return return
try: try:
@ -270,9 +265,12 @@ class LocalEnvironment(PersistentShellMixin, BaseEnvironment):
except (subprocess.TimeoutExpired, OSError, FileNotFoundError): except (subprocess.TimeoutExpired, OSError, FileNotFoundError):
pass pass
# ------------------------------------------------------------------ def _cleanup_temp_files(self):
# One-shot execution (original behavior) for f in glob.glob(f"{self._temp_prefix}-*"):
# ------------------------------------------------------------------ try:
os.remove(f)
except OSError:
pass
def _execute_oneshot(self, command: str, cwd: str = "", *, def _execute_oneshot(self, command: str, cwd: str = "", *,
timeout: int | None = None, timeout: int | None = None,
@ -281,7 +279,6 @@ class LocalEnvironment(PersistentShellMixin, BaseEnvironment):
effective_timeout = timeout or self.timeout effective_timeout = timeout or self.timeout
exec_command, sudo_stdin = self._prepare_command(command) exec_command, sudo_stdin = self._prepare_command(command)
# Merge the sudo password (if any) with caller-supplied stdin_data.
if sudo_stdin is not None and stdin_data is not None: if sudo_stdin is not None and stdin_data is not None:
effective_stdin = sudo_stdin + stdin_data effective_stdin = sudo_stdin + stdin_data
elif sudo_stdin is not None: elif sudo_stdin is not None:
@ -378,10 +375,6 @@ class LocalEnvironment(PersistentShellMixin, BaseEnvironment):
except Exception as e: except Exception as e:
return {"output": f"Execution error: {str(e)}", "returncode": 1} return {"output": f"Execution error: {str(e)}", "returncode": 1}
# ------------------------------------------------------------------
# Public interface
# ------------------------------------------------------------------
def execute(self, command: str, cwd: str = "", *, def execute(self, command: str, cwd: str = "", *,
timeout: int | None = None, timeout: int | None = None,
stdin_data: str | None = None) -> dict: stdin_data: str | None = None) -> dict:

View file

@ -1,18 +1,6 @@
"""Persistent shell mixin: file-based IPC protocol for long-lived bash shells. """Persistent shell mixin: file-based IPC protocol for long-lived bash shells."""
Provides the shared logic for maintaining a persistent bash shell across
execute() calls. Backend-specific operations (spawning the shell, reading
temp files, killing child processes) are implemented by subclasses via
abstract methods.
The IPC protocol writes each command's stdout/stderr/exit-code/cwd to temp
files, then polls the status file for completion. A daemon thread drains
the shell's stdout to prevent pipe deadlock and detect shell death.
"""
import glob as glob_mod
import logging import logging
import os
import shlex import shlex
import subprocess import subprocess
import threading import threading
@ -28,65 +16,42 @@ logger = logging.getLogger(__name__)
class PersistentShellMixin: class PersistentShellMixin:
"""Mixin that adds persistent shell capability to any BaseEnvironment. """Mixin that adds persistent shell capability to any BaseEnvironment.
Subclasses MUST implement: Subclasses must implement ``_spawn_shell_process()``, ``_read_temp_files()``,
_spawn_shell_process() -> subprocess.Popen ``_kill_shell_children()``, and ``_execute_oneshot()`` (stdin fallback).
_read_temp_files(*paths) -> list[str]
_kill_shell_children()
Subclasses MUST also provide ``_execute_oneshot()`` for the stdin_data
fallback path (commands with piped stdin cannot use the persistent shell).
""" """
# -- State (initialized by _init_persistent_shell) --------------------- @abstractmethod
_shell_proc: subprocess.Popen | None = None def _spawn_shell_process(self) -> subprocess.Popen: ...
_shell_alive: bool = False
_shell_pid: int | None = None @abstractmethod
def _read_temp_files(self, *paths: str) -> list[str]: ...
@abstractmethod
def _kill_shell_children(self): ...
@abstractmethod
def _execute_oneshot(self, command: str, cwd: str, *,
timeout: int | None = None,
stdin_data: str | None = None) -> dict: ...
@abstractmethod
def _cleanup_temp_files(self): ...
_session_id: str = "" _session_id: str = ""
# -- Abstract methods (backend-specific) -------------------------------
@abstractmethod
def _spawn_shell_process(self) -> subprocess.Popen:
"""Spawn a long-lived bash shell and return the Popen handle.
Must use ``stdin=PIPE, stdout=PIPE, stderr=PIPE, text=True``.
"""
...
@abstractmethod
def _read_temp_files(self, *paths: str) -> list[str]:
"""Read temp files from the execution context.
Returns contents in the same order as *paths*. Falls back to
empty strings on failure.
"""
...
@abstractmethod
def _kill_shell_children(self):
"""Kill the running command's processes but keep the shell alive."""
...
# -- Overridable properties --------------------------------------------
@property @property
def _temp_prefix(self) -> str: def _temp_prefix(self) -> str:
"""Base path for temp files. Override per backend."""
return f"/tmp/hermes-persistent-{self._session_id}" return f"/tmp/hermes-persistent-{self._session_id}"
# -- Shared implementation ---------------------------------------------
def _init_persistent_shell(self): def _init_persistent_shell(self):
"""Call from ``__init__`` when ``persistent=True``."""
self._shell_lock = threading.Lock() self._shell_lock = threading.Lock()
self._session_id = "" self._session_id: str = ""
self._shell_proc = None self._shell_proc: subprocess.Popen | None = None
self._shell_alive = False self._shell_alive: bool = False
self._shell_pid = None self._shell_pid: int | None = None
self._start_persistent_shell() self._start_persistent_shell()
def _start_persistent_shell(self): def _start_persistent_shell(self):
"""Spawn the shell, create temp files, capture PID."""
self._session_id = uuid.uuid4().hex[:12] self._session_id = uuid.uuid4().hex[:12]
p = self._temp_prefix p = self._temp_prefix
self._pshell_stdout = f"{p}-stdout" self._pshell_stdout = f"{p}-stdout"
@ -103,7 +68,6 @@ class PersistentShellMixin:
) )
self._drain_thread.start() self._drain_thread.start()
# Initialize temp files and capture shell PID
init_script = ( init_script = (
f"touch {self._pshell_stdout} {self._pshell_stderr} " f"touch {self._pshell_stdout} {self._pshell_stderr} "
f"{self._pshell_status} {self._pshell_cwd} {self._pshell_pid_file}\n" f"{self._pshell_status} {self._pshell_cwd} {self._pshell_pid_file}\n"
@ -112,7 +76,6 @@ class PersistentShellMixin:
) )
self._send_to_shell(init_script) self._send_to_shell(init_script)
# Poll for PID file
deadline = time.monotonic() + 3.0 deadline = time.monotonic() + 3.0
while time.monotonic() < deadline: while time.monotonic() < deadline:
pid_str = self._read_temp_files(self._pshell_pid_file)[0].strip() pid_str = self._read_temp_files(self._pshell_pid_file)[0].strip()
@ -130,22 +93,19 @@ class PersistentShellMixin:
self._session_id, self._shell_pid, self._session_id, self._shell_pid,
) )
# Update cwd from what the shell reports
reported_cwd = self._read_temp_files(self._pshell_cwd)[0].strip() reported_cwd = self._read_temp_files(self._pshell_cwd)[0].strip()
if reported_cwd: if reported_cwd:
self.cwd = reported_cwd self.cwd = reported_cwd
def _drain_shell_output(self): def _drain_shell_output(self):
"""Drain stdout to prevent pipe deadlock; detect shell death."""
try: try:
for _ in self._shell_proc.stdout: for _ in self._shell_proc.stdout:
pass # Real output goes to temp files pass
except Exception: except Exception:
pass pass
self._shell_alive = False self._shell_alive = False
def _send_to_shell(self, text: str): def _send_to_shell(self, text: str):
"""Write text to the persistent shell's stdin."""
if not self._shell_alive or self._shell_proc is None: if not self._shell_alive or self._shell_proc is None:
return return
try: try:
@ -155,13 +115,11 @@ class PersistentShellMixin:
self._shell_alive = False self._shell_alive = False
def _read_persistent_output(self) -> tuple[str, int, str]: def _read_persistent_output(self) -> tuple[str, int, str]:
"""Read stdout, stderr, status, cwd. Returns (output, exit_code, cwd)."""
stdout, stderr, status_raw, cwd = self._read_temp_files( stdout, stderr, status_raw, cwd = self._read_temp_files(
self._pshell_stdout, self._pshell_stderr, self._pshell_stdout, self._pshell_stderr,
self._pshell_status, self._pshell_cwd, self._pshell_status, self._pshell_cwd,
) )
output = self._merge_output(stdout, stderr) output = self._merge_output(stdout, stderr)
# Status format: "cmd_id:exit_code" — strip the ID prefix
status = status_raw.strip() status = status_raw.strip()
if ":" in status: if ":" in status:
status = status.split(":", 1)[1] status = status.split(":", 1)[1]
@ -174,15 +132,12 @@ class PersistentShellMixin:
def _execute_persistent(self, command: str, cwd: str, *, def _execute_persistent(self, command: str, cwd: str, *,
timeout: int | None = None, timeout: int | None = None,
stdin_data: str | None = None) -> dict: stdin_data: str | None = None) -> dict:
"""Execute a command in the persistent shell."""
if not self._shell_alive: if not self._shell_alive:
logger.info("Persistent shell died, restarting...") logger.info("Persistent shell died, restarting...")
self._start_persistent_shell() self._start_persistent_shell()
exec_command, sudo_stdin = self._prepare_command(command) exec_command, sudo_stdin = self._prepare_command(command)
effective_timeout = timeout or self.timeout effective_timeout = timeout or self.timeout
# Fall back to one-shot for commands needing piped stdin
if stdin_data or sudo_stdin: if stdin_data or sudo_stdin:
return self._execute_oneshot( return self._execute_oneshot(
command, cwd, timeout=timeout, stdin_data=stdin_data, command, cwd, timeout=timeout, stdin_data=stdin_data,
@ -195,25 +150,14 @@ class PersistentShellMixin:
def _execute_persistent_locked(self, command: str, cwd: str, def _execute_persistent_locked(self, command: str, cwd: str,
timeout: int) -> dict: timeout: int) -> dict:
"""Inner persistent execution — caller must hold ``_shell_lock``."""
work_dir = cwd or self.cwd work_dir = cwd or self.cwd
# Each command gets a unique ID written into the status file so the
# poll loop can distinguish the *current* command's result from a
# stale value left over from the previous command. This eliminates
# the race where a fast local file read sees the old status before
# the shell has processed the truncation.
cmd_id = uuid.uuid4().hex[:8] cmd_id = uuid.uuid4().hex[:8]
# Truncate temp files
truncate = ( truncate = (
f": > {self._pshell_stdout}\n" f": > {self._pshell_stdout}\n"
f": > {self._pshell_stderr}\n" f": > {self._pshell_stderr}\n"
f": > {self._pshell_status}\n" f": > {self._pshell_status}\n"
) )
self._send_to_shell(truncate) self._send_to_shell(truncate)
# Escape command for eval
escaped = command.replace("'", "'\\''") escaped = command.replace("'", "'\\''")
ipc_script = ( ipc_script = (
@ -224,8 +168,6 @@ class PersistentShellMixin:
f"echo {cmd_id}:$__EC > {self._pshell_status}\n" f"echo {cmd_id}:$__EC > {self._pshell_status}\n"
) )
self._send_to_shell(ipc_script) self._send_to_shell(ipc_script)
# Poll the status file for current command's ID
deadline = time.monotonic() + timeout deadline = time.monotonic() + timeout
poll_interval = 0.15 poll_interval = 0.15
@ -267,7 +209,6 @@ class PersistentShellMixin:
@staticmethod @staticmethod
def _merge_output(stdout: str, stderr: str) -> str: def _merge_output(stdout: str, stderr: str) -> str:
"""Combine stdout and stderr into a single output string."""
parts = [] parts = []
if stdout.strip(): if stdout.strip():
parts.append(stdout.rstrip("\n")) parts.append(stdout.rstrip("\n"))
@ -276,7 +217,6 @@ class PersistentShellMixin:
return "\n".join(parts) return "\n".join(parts)
def _cleanup_persistent_shell(self): def _cleanup_persistent_shell(self):
"""Clean up persistent shell resources. Call from ``cleanup()``."""
if self._shell_proc is None: if self._shell_proc is None:
return return
@ -299,10 +239,3 @@ class PersistentShellMixin:
if hasattr(self, "_drain_thread") and self._drain_thread.is_alive(): if hasattr(self, "_drain_thread") and self._drain_thread.is_alive():
self._drain_thread.join(timeout=1.0) self._drain_thread.join(timeout=1.0)
def _cleanup_temp_files(self):
"""Remove local temp files. Override for remote backends (SSH, Docker)."""
for f in glob_mod.glob(f"{self._temp_prefix}-*"):
try:
os.remove(f)
except OSError:
pass

View file

@ -77,10 +77,6 @@ class SSHEnvironment(PersistentShellMixin, BaseEnvironment):
except subprocess.TimeoutExpired: except subprocess.TimeoutExpired:
raise RuntimeError(f"SSH connection to {self.user}@{self.host} timed out") raise RuntimeError(f"SSH connection to {self.user}@{self.host} timed out")
# ------------------------------------------------------------------
# PersistentShellMixin: backend-specific implementations
# ------------------------------------------------------------------
@property @property
def _temp_prefix(self) -> str: def _temp_prefix(self) -> str:
return f"/tmp/hermes-ssh-{self._session_id}" return f"/tmp/hermes-ssh-{self._session_id}"
@ -92,12 +88,11 @@ class SSHEnvironment(PersistentShellMixin, BaseEnvironment):
cmd, cmd,
stdin=subprocess.PIPE, stdin=subprocess.PIPE,
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, stderr=subprocess.DEVNULL,
text=True, text=True,
) )
def _read_temp_files(self, *paths: str) -> list[str]: def _read_temp_files(self, *paths: str) -> list[str]:
"""Read remote files via ControlMaster one-shot SSH calls."""
if len(paths) == 1: if len(paths) == 1:
cmd = self._build_ssh_command() cmd = self._build_ssh_command()
cmd.append(f"cat {paths[0]} 2>/dev/null") cmd.append(f"cat {paths[0]} 2>/dev/null")
@ -135,7 +130,6 @@ class SSHEnvironment(PersistentShellMixin, BaseEnvironment):
pass pass
def _cleanup_temp_files(self): def _cleanup_temp_files(self):
"""Remove remote temp files via SSH."""
try: try:
cmd = self._build_ssh_command() cmd = self._build_ssh_command()
cmd.append(f"rm -f {self._temp_prefix}-*") cmd.append(f"rm -f {self._temp_prefix}-*")
@ -143,20 +137,14 @@ class SSHEnvironment(PersistentShellMixin, BaseEnvironment):
except (OSError, subprocess.SubprocessError): except (OSError, subprocess.SubprocessError):
pass pass
# ------------------------------------------------------------------
# One-shot execution (original behavior)
# ------------------------------------------------------------------
def _execute_oneshot(self, command: str, cwd: str = "", *, def _execute_oneshot(self, command: str, cwd: str = "", *,
timeout: int | None = None, timeout: int | None = None,
stdin_data: str | None = None) -> dict: stdin_data: str | None = None) -> dict:
"""Execute a command via a fresh one-shot SSH invocation."""
work_dir = cwd or self.cwd work_dir = cwd or self.cwd
exec_command, sudo_stdin = self._prepare_command(command) exec_command, sudo_stdin = self._prepare_command(command)
wrapped = f'cd {work_dir} && {exec_command}' wrapped = f'cd {work_dir} && {exec_command}'
effective_timeout = timeout or self.timeout effective_timeout = timeout or self.timeout
# Merge sudo password (if any) with caller-supplied stdin_data.
if sudo_stdin is not None and stdin_data is not None: if sudo_stdin is not None and stdin_data is not None:
effective_stdin = sudo_stdin + stdin_data effective_stdin = sudo_stdin + stdin_data
elif sudo_stdin is not None: elif sudo_stdin is not None:
@ -169,11 +157,8 @@ class SSHEnvironment(PersistentShellMixin, BaseEnvironment):
try: try:
kwargs = self._build_run_kwargs(timeout, effective_stdin) kwargs = self._build_run_kwargs(timeout, effective_stdin)
# Remove timeout from kwargs -- we handle it in the poll loop
kwargs.pop("timeout", None) kwargs.pop("timeout", None)
_output_chunks = [] _output_chunks = []
proc = subprocess.Popen( proc = subprocess.Popen(
cmd, cmd,
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
@ -224,10 +209,6 @@ class SSHEnvironment(PersistentShellMixin, BaseEnvironment):
except Exception as e: except Exception as e:
return {"output": f"SSH execution error: {str(e)}", "returncode": 1} return {"output": f"SSH execution error: {str(e)}", "returncode": 1}
# ------------------------------------------------------------------
# Public interface
# ------------------------------------------------------------------
def execute(self, command: str, cwd: str = "", *, def execute(self, command: str, cwd: str = "", *,
timeout: int | None = None, timeout: int | None = None,
stdin_data: str | None = None) -> dict: stdin_data: str | None = None) -> dict:
@ -240,11 +221,8 @@ class SSHEnvironment(PersistentShellMixin, BaseEnvironment):
) )
def cleanup(self): def cleanup(self):
# Persistent shell teardown (via mixin)
if self.persistent: if self.persistent:
self._cleanup_persistent_shell() self._cleanup_persistent_shell()
# ControlMaster cleanup
if self.control_socket.exists(): if self.control_socket.exists():
try: try:
cmd = ["ssh", "-o", f"ControlPath={self.control_socket}", cmd = ["ssh", "-o", f"ControlPath={self.control_socket}",

View file

@ -504,7 +504,6 @@ def _get_env_config() -> Dict[str, Any]:
"ssh_port": _parse_env_var("TERMINAL_SSH_PORT", "22"), "ssh_port": _parse_env_var("TERMINAL_SSH_PORT", "22"),
"ssh_key": os.getenv("TERMINAL_SSH_KEY", ""), "ssh_key": os.getenv("TERMINAL_SSH_KEY", ""),
"ssh_persistent": os.getenv("TERMINAL_SSH_PERSISTENT", "false").lower() in ("true", "1", "yes"), "ssh_persistent": os.getenv("TERMINAL_SSH_PERSISTENT", "false").lower() in ("true", "1", "yes"),
# Local persistent shell (cwd/env vars survive across calls)
"local_persistent": os.getenv("TERMINAL_LOCAL_PERSISTENT", "false").lower() in ("true", "1", "yes"), "local_persistent": os.getenv("TERMINAL_LOCAL_PERSISTENT", "false").lower() in ("true", "1", "yes"),
# Container resource config (applies to docker, singularity, modal, daytona -- ignored for local/ssh) # Container resource config (applies to docker, singularity, modal, daytona -- ignored for local/ssh)
"container_cpu": _parse_env_var("TERMINAL_CONTAINER_CPU", "1", float, "number"), "container_cpu": _parse_env_var("TERMINAL_CONTAINER_CPU", "1", float, "number"),