add persistent ssh backend

This commit is contained in:
balyan.sid@gmail.com 2026-03-12 23:38:11 +05:30
parent b2bdaecf9b
commit 9d63dcc3f9
2 changed files with 308 additions and 4 deletions

View file

@ -1,10 +1,12 @@
"""SSH remote execution environment with ControlMaster connection persistence.""" """SSH remote execution environment with ControlMaster connection persistence."""
import logging import logging
import shlex
import subprocess import subprocess
import tempfile import tempfile
import threading import threading
import time import time
import uuid
from pathlib import Path from pathlib import Path
from tools.environments.base import BaseEnvironment from tools.environments.base import BaseEnvironment
@ -22,21 +24,44 @@ class SSHEnvironment(BaseEnvironment):
Foreground commands are interruptible: the local ssh process is killed Foreground commands are interruptible: the local ssh process is killed
and a remote kill is attempted over the ControlMaster socket. and a remote kill is attempted over the ControlMaster socket.
When ``persistent=True``, a single long-lived bash shell is kept alive
over SSH and state (cwd, env vars, shell variables) persists across
``execute()`` calls. Output capture uses file-based IPC on the remote
host (stdout/stderr/exit-code written to temp files, polled via fast
ControlMaster one-shot reads).
""" """
def __init__(self, host: str, user: str, cwd: str = "~",
             timeout: int = 60, port: int = 22, key_path: str = "",
             persistent: bool = False):
    """Connect to the remote host and optionally start a persistent shell.

    Args:
        host: Remote host name or address.
        user: Remote login user.
        cwd: Initial working directory, forwarded to the base class.
        timeout: Default per-command timeout, forwarded to the base class.
        port: SSH port.
        key_path: Path to a private key; stored as given.
        persistent: When true, a long-lived shell is started right after
            the connection is established and ``execute()`` routes
            commands through it.
    """
    super().__init__(cwd=cwd, timeout=timeout)
    self.host = host
    self.user = user
    self.port = port
    self.key_path = key_path
    self.persistent = persistent

    # Persistent shell state. Assigned before any network I/O so the
    # instance already carries every attribute if connecting raises.
    self._shell_proc: subprocess.Popen | None = None
    self._shell_lock = threading.Lock()
    self._shell_alive = False
    self._drain_thread: threading.Thread | None = None
    self._session_id: str = ""
    self._remote_stdout: str = ""
    self._remote_stderr: str = ""
    self._remote_status: str = ""
    self._remote_cwd: str = ""
    self._remote_pid: str = ""
    self._remote_shell_pid: int | None = None

    # ControlMaster socket lives in a per-machine temp directory.
    self.control_dir = Path(tempfile.gettempdir()) / "hermes-ssh"
    self.control_dir.mkdir(parents=True, exist_ok=True)
    self.control_socket = self.control_dir / f"{user}@{host}:{port}.sock"
    self._establish_connection()

    if self.persistent:
        self._start_persistent_shell()
def _build_ssh_command(self, extra_args: list = None) -> list: def _build_ssh_command(self, extra_args: list = None) -> list:
cmd = ["ssh"] cmd = ["ssh"]
cmd.extend(["-o", f"ControlPath={self.control_socket}"]) cmd.extend(["-o", f"ControlPath={self.control_socket}"])
@ -65,9 +90,240 @@ class SSHEnvironment(BaseEnvironment):
except subprocess.TimeoutExpired: except subprocess.TimeoutExpired:
raise RuntimeError(f"SSH connection to {self.user}@{self.host} timed out") raise RuntimeError(f"SSH connection to {self.user}@{self.host} timed out")
def execute(self, command: str, cwd: str = "", *, # ------------------------------------------------------------------
timeout: int | None = None, # Persistent shell management
stdin_data: str | None = None) -> dict: # ------------------------------------------------------------------
def _start_persistent_shell(self):
    """Spawn a long-lived login bash over SSH and prepare its IPC files.

    Safe to call again for a restart: a previous shell process is
    terminated, its drain thread is joined, and the previous session's
    remote temp files are removed on a best-effort basis before a new
    session id is generated.

    After the shell is spawned, the remote cwd file (the last file the
    init script writes) is polled for a few seconds instead of sleeping
    for a fixed interval, so slow login profiles do not cause the PID
    and cwd reads to be missed.
    """
    # --- Tear down a previous shell, if any (restart path) -------------
    old_proc = self._shell_proc
    if old_proc is not None and old_proc.poll() is None:
        try:
            old_proc.terminate()
            old_proc.wait(timeout=3)
        except (OSError, subprocess.TimeoutExpired):
            old_proc.kill()
    old_thread = getattr(self, "_drain_thread", None)
    if old_thread is not None and old_thread.is_alive():
        # Let the old reader finish so it cannot flip _shell_alive after
        # the new shell has been marked alive below.
        old_thread.join(timeout=1)
    if self._session_id:
        try:
            rm_cmd = self._build_ssh_command()
            rm_cmd.append(f"rm -f /tmp/hermes-ssh-{self._session_id}-*")
            subprocess.run(rm_cmd, capture_output=True, timeout=5)
        except (OSError, subprocess.SubprocessError):
            pass  # best-effort cleanup only

    # --- New session ----------------------------------------------------
    self._session_id = uuid.uuid4().hex[:12]
    prefix = f"/tmp/hermes-ssh-{self._session_id}"
    self._remote_stdout = f"{prefix}-stdout"
    self._remote_stderr = f"{prefix}-stderr"
    self._remote_status = f"{prefix}-status"
    self._remote_cwd = f"{prefix}-cwd"
    self._remote_pid = f"{prefix}-pid"

    cmd = self._build_ssh_command()
    cmd.append("bash -l")
    self._shell_proc = subprocess.Popen(
        cmd,
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        # Merge stderr into stdout: only stdout is drained, and an
        # unread stderr pipe would eventually fill and block the shell.
        stderr=subprocess.STDOUT,
        text=True,
        # Undecodable bytes must not raise inside the drain thread.
        errors="replace",
    )
    self._shell_alive = True

    # Daemon thread that discards shell output and detects shell death.
    self._drain_thread = threading.Thread(
        target=self._drain_shell_output, daemon=True
    )
    self._drain_thread.start()

    # Create the remote temp files and record the shell's PID and cwd.
    init_script = (
        f"touch {self._remote_stdout} {self._remote_stderr} "
        f"{self._remote_status} {self._remote_cwd} {self._remote_pid}\n"
        f"echo $$ > {self._remote_pid}\n"
        f"pwd > {self._remote_cwd}\n"
    )
    self._send_to_shell(init_script)

    # Wait for the init script to finish. The cwd file is written last,
    # so once it has content the PID file is complete as well.
    startup_deadline = time.monotonic() + 5.0
    remote_cwd = ""
    while True:
        remote_cwd = self._read_remote_file(self._remote_cwd).strip()
        if remote_cwd or not self._shell_alive:
            break
        if time.monotonic() >= startup_deadline:
            break
        time.sleep(0.1)

    pid_str = self._read_remote_file(self._remote_pid).strip()
    if pid_str.isdigit():
        self._remote_shell_pid = int(pid_str)
        logger.info("Persistent shell started (session=%s, pid=%d)",
                    self._session_id, self._remote_shell_pid)
    else:
        logger.warning("Could not read persistent shell PID (got %r)", pid_str)
        self._remote_shell_pid = None

    # Adopt the directory the shell reports as the tracked cwd.
    if remote_cwd:
        self.cwd = remote_cwd
def _drain_shell_output(self):
"""Drain the shell's stdout/stderr to prevent pipe deadlock.
Also detects when the shell process dies.
"""
try:
for _ in self._shell_proc.stdout:
pass # Discard — real output goes to temp files
except Exception:
pass
self._shell_alive = False
def _send_to_shell(self, text: str):
"""Write text to the persistent shell's stdin."""
if not self._shell_alive or self._shell_proc is None:
return
try:
self._shell_proc.stdin.write(text)
self._shell_proc.stdin.flush()
except (BrokenPipeError, OSError):
self._shell_alive = False
def _read_remote_file(self, path: str) -> str:
"""Read a file on the remote host via a one-shot SSH command.
Uses ControlMaster so this is very fast (~5ms on LAN).
"""
cmd = self._build_ssh_command()
cmd.append(f"cat {path} 2>/dev/null")
try:
result = subprocess.run(
cmd, capture_output=True, text=True, timeout=10
)
return result.stdout
except (subprocess.TimeoutExpired, OSError):
return ""
def _kill_shell_children(self):
"""Kill children of the persistent shell (the running command),
but not the shell itself."""
if self._remote_shell_pid is None:
return
cmd = self._build_ssh_command()
cmd.append(f"pkill -P {self._remote_shell_pid} 2>/dev/null; true")
try:
subprocess.run(cmd, capture_output=True, timeout=5)
except (subprocess.TimeoutExpired, OSError):
pass
def _execute_persistent(self, command: str, cwd: str, *,
timeout: int | None = None,
stdin_data: str | None = None) -> dict:
"""Execute a command in the persistent shell."""
# If shell is dead, restart it
if not self._shell_alive:
logger.info("Persistent shell died, restarting...")
self._start_persistent_shell()
exec_command, sudo_stdin = self._prepare_command(command)
effective_timeout = timeout or self.timeout
# Fall back to one-shot for commands needing piped stdin
if stdin_data or sudo_stdin:
return self._execute_oneshot(
command, cwd, timeout=timeout, stdin_data=stdin_data
)
with self._shell_lock:
return self._execute_persistent_locked(
exec_command, cwd, effective_timeout
)
def _execute_persistent_locked(self, command: str, cwd: str,
                               timeout: int) -> dict:
    """Inner persistent execution — caller must hold ``_shell_lock``.

    Sends a small control script to the shell's stdin. It changes into
    the working directory, evaluates *command* with stdout/stderr
    redirected to the session's remote temp files, then writes the cwd
    and ``"<exit code> <token>"`` to the cwd/status files. The status
    file is polled through one-shot reads until the token generated for
    this call appears, so a status left behind by an earlier command is
    never mistaken for completion of this one.

    Args:
        command: Shell command text to evaluate.
        cwd: Directory to run in; falls back to ``self.cwd`` when empty.
        timeout: Seconds to wait before giving up.

    Returns:
        ``{"output": ..., "returncode": ...}``. The return code is 130
        on interrupt, 124 on timeout with partial output, 1 if the shell
        died or the exit code could not be parsed; a timeout without
        output returns whatever ``_timeout_result`` produces.
    """
    work_dir = cwd or self.cwd

    # Quote the directory but keep a leading "~" unquoted so the remote
    # shell still performs tilde expansion.
    if work_dir == "~":
        cd_target = "~"
    elif work_dir.startswith("~/"):
        cd_target = "~/" + shlex.quote(work_dir[2:])
    else:
        cd_target = shlex.quote(work_dir)

    # Escape command for eval — single quotes with embedded quotes closed,
    # escaped and reopened.
    escaped = command.replace("'", "'\\''")

    # Unique per-call marker that identifies *this* command's status line.
    token = uuid.uuid4().hex[:12]

    # The brace group runs in the shell itself (no subshell), so cwd and
    # variable changes persist. `cd ... &&` mirrors the one-shot path: the
    # command is skipped when the directory change fails, and cd's error
    # message lands in the stderr file.
    ipc_script = (
        f"{{ cd {cd_target} && eval '{escaped}'; }} "
        f"< /dev/null > {self._remote_stdout} 2> {self._remote_stderr}\n"
        f"__EC=$?\n"
        f"pwd > {self._remote_cwd}\n"
        f"echo \"$__EC {token}\" > {self._remote_status}\n"
    )
    self._send_to_shell(ipc_script)

    def read_output() -> str:
        """Fetch and merge whatever the command has written so far."""
        stdout = self._read_remote_file(self._remote_stdout)
        stderr = self._read_remote_file(self._remote_stderr)
        return self._merge_output(stdout, stderr)

    # Poll the status file
    deadline = time.monotonic() + timeout
    poll_interval = 0.05  # 50ms
    while True:
        if is_interrupted():
            self._kill_shell_children()
            return {
                "output": read_output() + "\n[Command interrupted]",
                "returncode": 130,
            }

        if time.monotonic() > deadline:
            self._kill_shell_children()
            output = read_output()
            if output:
                return {
                    "output": output + f"\n[Command timed out after {timeout}s]",
                    "returncode": 124,
                }
            return self._timeout_result(timeout)

        if not self._shell_alive:
            return {
                "output": "Persistent shell died during execution",
                "returncode": 1,
            }

        # Done only when the status line carries this call's token.
        status_fields = self._read_remote_file(self._remote_status).split()
        if len(status_fields) == 2 and status_fields[1] == token:
            exit_code_str = status_fields[0]
            break

        time.sleep(poll_interval)

    # Read results
    output = read_output()
    new_cwd = self._read_remote_file(self._remote_cwd).strip()

    # Parse exit code
    try:
        exit_code = int(exit_code_str)
    except ValueError:
        exit_code = 1

    # Update cwd
    if new_cwd:
        self.cwd = new_cwd

    return {"output": output, "returncode": exit_code}
@staticmethod
def _merge_output(stdout: str, stderr: str) -> str:
"""Combine stdout and stderr into a single output string."""
parts = []
if stdout.strip():
parts.append(stdout.rstrip("\n"))
if stderr.strip():
parts.append(stderr.rstrip("\n"))
return "\n".join(parts)
# ------------------------------------------------------------------
# One-shot execution (original behavior)
# ------------------------------------------------------------------
def _execute_oneshot(self, command: str, cwd: str = "", *,
timeout: int | None = None,
stdin_data: str | None = None) -> dict:
"""Execute a command via a fresh one-shot SSH invocation."""
work_dir = cwd or self.cwd work_dir = cwd or self.cwd
exec_command, sudo_stdin = self._prepare_command(command) exec_command, sudo_stdin = self._prepare_command(command)
wrapped = f'cd {work_dir} && {exec_command}' wrapped = f'cd {work_dir} && {exec_command}'
@ -141,7 +397,52 @@ class SSHEnvironment(BaseEnvironment):
except Exception as e: except Exception as e:
return {"output": f"SSH execution error: {str(e)}", "returncode": 1} return {"output": f"SSH execution error: {str(e)}", "returncode": 1}
# ------------------------------------------------------------------
# Public interface
# ------------------------------------------------------------------
def execute(self, command: str, cwd: str = "", *,
timeout: int | None = None,
stdin_data: str | None = None) -> dict:
if self.persistent:
return self._execute_persistent(
command, cwd, timeout=timeout, stdin_data=stdin_data
)
return self._execute_oneshot(
command, cwd, timeout=timeout, stdin_data=stdin_data
)
def cleanup(self): def cleanup(self):
# Persistent shell teardown
if self.persistent and self._shell_proc is not None:
# Remove remote temp files
if self._session_id:
try:
cmd = self._build_ssh_command()
cmd.append(
f"rm -f /tmp/hermes-ssh-{self._session_id}-*"
)
subprocess.run(cmd, capture_output=True, timeout=5)
except (OSError, subprocess.SubprocessError):
pass
# Close the shell
try:
self._shell_proc.stdin.close()
except Exception:
pass
try:
self._shell_proc.terminate()
self._shell_proc.wait(timeout=3)
except Exception:
try:
self._shell_proc.kill()
except Exception:
pass
self._shell_alive = False
self._shell_proc = None
# ControlMaster cleanup
if self.control_socket.exists(): if self.control_socket.exists():
try: try:
cmd = ["ssh", "-o", f"ControlPath={self.control_socket}", cmd = ["ssh", "-o", f"ControlPath={self.control_socket}",

View file

@ -503,6 +503,7 @@ def _get_env_config() -> Dict[str, Any]:
"ssh_user": os.getenv("TERMINAL_SSH_USER", ""), "ssh_user": os.getenv("TERMINAL_SSH_USER", ""),
"ssh_port": _parse_env_var("TERMINAL_SSH_PORT", "22"), "ssh_port": _parse_env_var("TERMINAL_SSH_PORT", "22"),
"ssh_key": os.getenv("TERMINAL_SSH_KEY", ""), "ssh_key": os.getenv("TERMINAL_SSH_KEY", ""),
"ssh_persistent": os.getenv("TERMINAL_SSH_PERSISTENT", "false").lower() in ("true", "1", "yes"),
# Container resource config (applies to docker, singularity, modal, daytona -- ignored for local/ssh) # Container resource config (applies to docker, singularity, modal, daytona -- ignored for local/ssh)
"container_cpu": _parse_env_var("TERMINAL_CONTAINER_CPU", "1", float, "number"), "container_cpu": _parse_env_var("TERMINAL_CONTAINER_CPU", "1", float, "number"),
"container_memory": _parse_env_var("TERMINAL_CONTAINER_MEMORY", "5120"), # MB (default 5GB) "container_memory": _parse_env_var("TERMINAL_CONTAINER_MEMORY", "5120"), # MB (default 5GB)
@ -594,6 +595,7 @@ def _create_environment(env_type: str, image: str, cwd: str, timeout: int,
key_path=ssh_config.get("key", ""), key_path=ssh_config.get("key", ""),
cwd=cwd, cwd=cwd,
timeout=timeout, timeout=timeout,
persistent=ssh_config.get("persistent", False),
) )
else: else:
@ -923,6 +925,7 @@ def terminal_tool(
"user": config.get("ssh_user", ""), "user": config.get("ssh_user", ""),
"port": config.get("ssh_port", 22), "port": config.get("ssh_port", 22),
"key": config.get("ssh_key", ""), "key": config.get("ssh_key", ""),
"persistent": config.get("ssh_persistent", False),
} }
container_config = None container_config = None