refactor: deduplicate execute/cleanup, merge init, clean up helpers

- Merge _init_persistent_shell + _start_persistent_shell into single method
- Move execute() dispatcher and cleanup() into PersistentShellMixin
  so LocalEnvironment and SSHEnvironment inherit them
- Remove broad except Exception wrappers from _execute_oneshot in both backends
- Replace try/except with os.path.exists checks in local _read_temp_files
  and _cleanup_temp_files
- Remove redundant bash -c from SSH oneshot (SSH already runs in a shell)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
alt-glitch 2026-03-15 02:39:56 +05:30
parent 7be314c456
commit 9f36483bf4
3 changed files with 195 additions and 203 deletions

View file

@ -247,10 +247,10 @@ class LocalEnvironment(PersistentShellMixin, BaseEnvironment):
def _read_temp_files(self, *paths: str) -> list[str]: def _read_temp_files(self, *paths: str) -> list[str]:
results = [] results = []
for path in paths: for path in paths:
try: if os.path.exists(path):
with open(path) as f: with open(path) as f:
results.append(f.read()) results.append(f.read())
except OSError: else:
results.append("") results.append("")
return results return results
@ -262,15 +262,13 @@ class LocalEnvironment(PersistentShellMixin, BaseEnvironment):
["pkill", "-P", str(self._shell_pid)], ["pkill", "-P", str(self._shell_pid)],
capture_output=True, timeout=5, capture_output=True, timeout=5,
) )
except (subprocess.TimeoutExpired, OSError, FileNotFoundError): except (subprocess.TimeoutExpired, FileNotFoundError):
pass pass
def _cleanup_temp_files(self): def _cleanup_temp_files(self):
for f in glob.glob(f"{self._temp_prefix}-*"): for f in glob.glob(f"{self._temp_prefix}-*"):
try: if os.path.exists(f):
os.remove(f) os.remove(f)
except OSError:
pass
def _execute_oneshot(self, command: str, cwd: str = "", *, def _execute_oneshot(self, command: str, cwd: str = "", *,
timeout: int | None = None, timeout: int | None = None,
@ -286,106 +284,87 @@ class LocalEnvironment(PersistentShellMixin, BaseEnvironment):
else: else:
effective_stdin = stdin_data effective_stdin = stdin_data
try: user_shell = _find_bash()
user_shell = _find_bash() fenced_cmd = (
fenced_cmd = ( f"printf '{_OUTPUT_FENCE}';"
f"printf '{_OUTPUT_FENCE}';" f" {exec_command};"
f" {exec_command};" f" __hermes_rc=$?;"
f" __hermes_rc=$?;" f" printf '{_OUTPUT_FENCE}';"
f" printf '{_OUTPUT_FENCE}';" f" exit $__hermes_rc"
f" exit $__hermes_rc" )
) run_env = _make_run_env(self.env)
run_env = _make_run_env(self.env)
proc = subprocess.Popen( proc = subprocess.Popen(
[user_shell, "-lic", fenced_cmd], [user_shell, "-lic", fenced_cmd],
text=True, text=True,
cwd=work_dir, cwd=work_dir,
env=run_env, env=run_env,
encoding="utf-8", encoding="utf-8",
errors="replace", errors="replace",
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.STDOUT, stderr=subprocess.STDOUT,
stdin=subprocess.PIPE if effective_stdin is not None else subprocess.DEVNULL, stdin=subprocess.PIPE if effective_stdin is not None else subprocess.DEVNULL,
preexec_fn=None if _IS_WINDOWS else os.setsid, preexec_fn=None if _IS_WINDOWS else os.setsid,
)
if effective_stdin is not None:
def _write_stdin():
try:
proc.stdin.write(effective_stdin)
proc.stdin.close()
except (BrokenPipeError, OSError):
pass
threading.Thread(target=_write_stdin, daemon=True).start()
_output_chunks: list[str] = []
def _drain_stdout():
try:
for line in proc.stdout:
_output_chunks.append(line)
except ValueError:
pass
finally:
try:
proc.stdout.close()
except Exception:
pass
reader = threading.Thread(target=_drain_stdout, daemon=True)
reader.start()
deadline = time.monotonic() + effective_timeout
while proc.poll() is None:
if is_interrupted():
try:
if _IS_WINDOWS:
proc.terminate()
else:
pgid = os.getpgid(proc.pid)
os.killpg(pgid, signal.SIGTERM)
try:
proc.wait(timeout=1.0)
except subprocess.TimeoutExpired:
os.killpg(pgid, signal.SIGKILL)
except (ProcessLookupError, PermissionError):
proc.kill()
reader.join(timeout=2)
return {
"output": "".join(_output_chunks) + "\n[Command interrupted — user sent a new message]",
"returncode": 130,
}
if time.monotonic() > deadline:
try:
if _IS_WINDOWS:
proc.terminate()
else:
os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
except (ProcessLookupError, PermissionError):
proc.kill()
reader.join(timeout=2)
return self._timeout_result(effective_timeout)
time.sleep(0.2)
reader.join(timeout=5)
output = _extract_fenced_output("".join(_output_chunks))
return {"output": output, "returncode": proc.returncode}
except Exception as e:
return {"output": f"Execution error: {str(e)}", "returncode": 1}
def execute(self, command: str, cwd: str = "", *,
timeout: int | None = None,
stdin_data: str | None = None) -> dict:
if self.persistent:
return self._execute_persistent(
command, cwd, timeout=timeout, stdin_data=stdin_data,
)
return self._execute_oneshot(
command, cwd, timeout=timeout, stdin_data=stdin_data,
) )
def cleanup(self): if effective_stdin is not None:
if self.persistent: def _write_stdin():
self._cleanup_persistent_shell() try:
proc.stdin.write(effective_stdin)
proc.stdin.close()
except (BrokenPipeError, OSError):
pass
threading.Thread(target=_write_stdin, daemon=True).start()
_output_chunks: list[str] = []
def _drain_stdout():
try:
for line in proc.stdout:
_output_chunks.append(line)
except ValueError:
pass
finally:
try:
proc.stdout.close()
except Exception:
pass
reader = threading.Thread(target=_drain_stdout, daemon=True)
reader.start()
deadline = time.monotonic() + effective_timeout
while proc.poll() is None:
if is_interrupted():
try:
if _IS_WINDOWS:
proc.terminate()
else:
pgid = os.getpgid(proc.pid)
os.killpg(pgid, signal.SIGTERM)
try:
proc.wait(timeout=1.0)
except subprocess.TimeoutExpired:
os.killpg(pgid, signal.SIGKILL)
except (ProcessLookupError, PermissionError):
proc.kill()
reader.join(timeout=2)
return {
"output": "".join(_output_chunks) + "\n[Command interrupted — user sent a new message]",
"returncode": 130,
}
if time.monotonic() > deadline:
try:
if _IS_WINDOWS:
proc.terminate()
else:
os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
except (ProcessLookupError, PermissionError):
proc.kill()
reader.join(timeout=2)
return self._timeout_result(effective_timeout)
time.sleep(0.2)
reader.join(timeout=5)
output = _extract_fenced_output("".join(_output_chunks))
return {"output": output, "returncode": proc.returncode}

View file

@ -17,9 +17,11 @@ class PersistentShellMixin:
"""Mixin that adds persistent shell capability to any BaseEnvironment. """Mixin that adds persistent shell capability to any BaseEnvironment.
Subclasses must implement ``_spawn_shell_process()``, ``_read_temp_files()``, Subclasses must implement ``_spawn_shell_process()``, ``_read_temp_files()``,
``_kill_shell_children()``, and ``_execute_oneshot()`` (stdin fallback). ``_kill_shell_children()``, ``_execute_oneshot()``, and ``_cleanup_temp_files()``.
""" """
persistent: bool
@abstractmethod @abstractmethod
def _spawn_shell_process(self) -> subprocess.Popen: ... def _spawn_shell_process(self) -> subprocess.Popen: ...
@ -43,15 +45,16 @@ class PersistentShellMixin:
def _temp_prefix(self) -> str: def _temp_prefix(self) -> str:
return f"/tmp/hermes-persistent-{self._session_id}" return f"/tmp/hermes-persistent-{self._session_id}"
# ------------------------------------------------------------------
# Lifecycle
# ------------------------------------------------------------------
def _init_persistent_shell(self): def _init_persistent_shell(self):
self._shell_lock = threading.Lock() self._shell_lock = threading.Lock()
self._session_id: str = ""
self._shell_proc: subprocess.Popen | None = None self._shell_proc: subprocess.Popen | None = None
self._shell_alive: bool = False self._shell_alive: bool = False
self._shell_pid: int | None = None self._shell_pid: int | None = None
self._start_persistent_shell()
def _start_persistent_shell(self):
self._session_id = uuid.uuid4().hex[:12] self._session_id = uuid.uuid4().hex[:12]
p = self._temp_prefix p = self._temp_prefix
self._pshell_stdout = f"{p}-stdout" self._pshell_stdout = f"{p}-stdout"
@ -98,6 +101,52 @@ class PersistentShellMixin:
if reported_cwd: if reported_cwd:
self.cwd = reported_cwd self.cwd = reported_cwd
def _cleanup_persistent_shell(self):
if self._shell_proc is None:
return
if self._session_id:
self._cleanup_temp_files()
try:
self._shell_proc.stdin.close()
except Exception:
pass
try:
self._shell_proc.terminate()
self._shell_proc.wait(timeout=3)
except subprocess.TimeoutExpired:
self._shell_proc.kill()
self._shell_alive = False
self._shell_proc = None
if hasattr(self, "_drain_thread") and self._drain_thread.is_alive():
self._drain_thread.join(timeout=1.0)
# ------------------------------------------------------------------
# execute() / cleanup() — shared dispatcher, subclasses inherit
# ------------------------------------------------------------------
def execute(self, command: str, cwd: str = "", *,
timeout: int | None = None,
stdin_data: str | None = None) -> dict:
if self.persistent:
return self._execute_persistent(
command, cwd, timeout=timeout, stdin_data=stdin_data,
)
return self._execute_oneshot(
command, cwd, timeout=timeout, stdin_data=stdin_data,
)
def cleanup(self):
if self.persistent:
self._cleanup_persistent_shell()
# ------------------------------------------------------------------
# Shell I/O
# ------------------------------------------------------------------
def _drain_shell_output(self): def _drain_shell_output(self):
try: try:
for _ in self._shell_proc.stdout: for _ in self._shell_proc.stdout:
@ -130,12 +179,16 @@ class PersistentShellMixin:
exit_code = 1 exit_code = 1
return output, exit_code, cwd.strip() return output, exit_code, cwd.strip()
# ------------------------------------------------------------------
# Execution
# ------------------------------------------------------------------
def _execute_persistent(self, command: str, cwd: str, *, def _execute_persistent(self, command: str, cwd: str, *,
timeout: int | None = None, timeout: int | None = None,
stdin_data: str | None = None) -> dict: stdin_data: str | None = None) -> dict:
if not self._shell_alive: if not self._shell_alive:
logger.info("Persistent shell died, restarting...") logger.info("Persistent shell died, restarting...")
self._start_persistent_shell() self._init_persistent_shell()
exec_command, sudo_stdin = self._prepare_command(command) exec_command, sudo_stdin = self._prepare_command(command)
effective_timeout = timeout or self.timeout effective_timeout = timeout or self.timeout
@ -216,27 +269,3 @@ class PersistentShellMixin:
if stderr.strip(): if stderr.strip():
parts.append(stderr.rstrip("\n")) parts.append(stderr.rstrip("\n"))
return "\n".join(parts) return "\n".join(parts)
def _cleanup_persistent_shell(self):
if self._shell_proc is None:
return
if self._session_id:
self._cleanup_temp_files()
try:
self._shell_proc.stdin.close()
except Exception:
pass
try:
self._shell_proc.terminate()
self._shell_proc.wait(timeout=3)
except subprocess.TimeoutExpired:
self._shell_proc.kill()
self._shell_alive = False
self._shell_proc = None
if hasattr(self, "_drain_thread") and self._drain_thread.is_alive():
self._drain_thread.join(timeout=1.0)

View file

@ -130,11 +130,11 @@ class SSHEnvironment(PersistentShellMixin, BaseEnvironment):
pass pass
def _cleanup_temp_files(self): def _cleanup_temp_files(self):
cmd = self._build_ssh_command()
cmd.append(f"rm -f {self._temp_prefix}-*")
try: try:
cmd = self._build_ssh_command()
cmd.append(f"rm -f {self._temp_prefix}-*")
subprocess.run(cmd, capture_output=True, timeout=5) subprocess.run(cmd, capture_output=True, timeout=5)
except (OSError, subprocess.SubprocessError): except (subprocess.TimeoutExpired, OSError):
pass pass
def _execute_oneshot(self, command: str, cwd: str = "", *, def _execute_oneshot(self, command: str, cwd: str = "", *,
@ -155,74 +155,58 @@ class SSHEnvironment(PersistentShellMixin, BaseEnvironment):
cmd = self._build_ssh_command() cmd = self._build_ssh_command()
cmd.append(wrapped) cmd.append(wrapped)
try: kwargs = self._build_run_kwargs(timeout, effective_stdin)
kwargs = self._build_run_kwargs(timeout, effective_stdin) kwargs.pop("timeout", None)
kwargs.pop("timeout", None) _output_chunks = []
_output_chunks = [] proc = subprocess.Popen(
proc = subprocess.Popen( cmd,
cmd, stdout=subprocess.PIPE,
stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
stderr=subprocess.STDOUT, stdin=subprocess.PIPE if effective_stdin else subprocess.DEVNULL,
stdin=subprocess.PIPE if effective_stdin else subprocess.DEVNULL, text=True,
text=True,
)
if effective_stdin:
try:
proc.stdin.write(effective_stdin)
proc.stdin.close()
except Exception:
pass
def _drain():
try:
for line in proc.stdout:
_output_chunks.append(line)
except Exception:
pass
reader = threading.Thread(target=_drain, daemon=True)
reader.start()
deadline = time.monotonic() + effective_timeout
while proc.poll() is None:
if is_interrupted():
proc.terminate()
try:
proc.wait(timeout=1)
except subprocess.TimeoutExpired:
proc.kill()
reader.join(timeout=2)
return {
"output": "".join(_output_chunks) + "\n[Command interrupted]",
"returncode": 130,
}
if time.monotonic() > deadline:
proc.kill()
reader.join(timeout=2)
return self._timeout_result(effective_timeout)
time.sleep(0.2)
reader.join(timeout=5)
return {"output": "".join(_output_chunks), "returncode": proc.returncode}
except Exception as e:
return {"output": f"SSH execution error: {str(e)}", "returncode": 1}
def execute(self, command: str, cwd: str = "", *,
timeout: int | None = None,
stdin_data: str | None = None) -> dict:
if self.persistent:
return self._execute_persistent(
command, cwd, timeout=timeout, stdin_data=stdin_data
)
return self._execute_oneshot(
command, cwd, timeout=timeout, stdin_data=stdin_data
) )
if effective_stdin:
try:
proc.stdin.write(effective_stdin)
proc.stdin.close()
except (BrokenPipeError, OSError):
pass
def _drain():
try:
for line in proc.stdout:
_output_chunks.append(line)
except Exception:
pass
reader = threading.Thread(target=_drain, daemon=True)
reader.start()
deadline = time.monotonic() + effective_timeout
while proc.poll() is None:
if is_interrupted():
proc.terminate()
try:
proc.wait(timeout=1)
except subprocess.TimeoutExpired:
proc.kill()
reader.join(timeout=2)
return {
"output": "".join(_output_chunks) + "\n[Command interrupted]",
"returncode": 130,
}
if time.monotonic() > deadline:
proc.kill()
reader.join(timeout=2)
return self._timeout_result(effective_timeout)
time.sleep(0.2)
reader.join(timeout=5)
return {"output": "".join(_output_chunks), "returncode": proc.returncode}
def cleanup(self): def cleanup(self):
if self.persistent: super().cleanup()
self._cleanup_persistent_shell()
if self.control_socket.exists(): if self.control_socket.exists():
try: try:
cmd = ["ssh", "-o", f"ControlPath={self.control_socket}", cmd = ["ssh", "-o", f"ControlPath={self.control_socket}",