add expected_runtime_seconds and shutdown processes w/ grace period that run too long.

This commit is contained in:
AI Christianson 2025-02-13 08:29:17 -05:00
parent c5c27c9f87
commit abfa5a1d6a
1 changed files with 40 additions and 3 deletions

View File

@ -19,6 +19,8 @@ import subprocess
import select import select
import termios import termios
import tty import tty
import time
import signal
from typing import List, Tuple from typing import List, Tuple
import pyte import pyte
@ -28,7 +30,7 @@ def render_line(line, columns: int) -> str:
"""Render a single screen line from the pyte buffer (a mapping of column to Char).""" """Render a single screen line from the pyte buffer (a mapping of column to Char)."""
return "".join(line[x].data for x in range(columns)) return "".join(line[x].data for x in range(columns))
def run_interactive_command(cmd: List[str]) -> Tuple[bytes, int]: def run_interactive_command(cmd: List[str], expected_runtime_seconds: int = 30) -> Tuple[bytes, int]:
""" """
Runs an interactive command with a pseudo-tty, capturing final scrollback history. Runs an interactive command with a pseudo-tty, capturing final scrollback history.
@ -37,6 +39,12 @@ def run_interactive_command(cmd: List[str]) -> Tuple[bytes, int]:
- `cmd` is a non-empty list where cmd[0] is the executable. - `cmd` is a non-empty list where cmd[0] is the executable.
- The executable is on PATH. - The executable is on PATH.
Args:
cmd: A list containing the command and its arguments.
expected_runtime_seconds: Expected runtime in seconds, defaults to 30.
If process exceeds 2x this value, it will be terminated gracefully.
If process exceeds 3x this value, it will be killed forcefully.
Returns: Returns:
A tuple of (captured_output, return_code), where captured_output is a UTF-8 encoded A tuple of (captured_output, return_code), where captured_output is a UTF-8 encoded
bytes object containing the trimmed non-empty history lines from the terminal session. bytes object containing the trimmed non-empty history lines from the terminal session.
@ -93,11 +101,26 @@ def run_interactive_command(cmd: List[str]) -> Tuple[bytes, int]:
stderr=slave_fd, stderr=slave_fd,
bufsize=0, bufsize=0,
close_fds=True, close_fds=True,
env=env env=env,
preexec_fn=os.setsid # Create new process group for proper signal handling
) )
os.close(slave_fd) # Close slave end in the parent process. os.close(slave_fd) # Close slave end in the parent process.
captured_data = [] captured_data = []
start_time = time.time()
was_terminated = False
def check_timeout():
elapsed = time.time() - start_time
if elapsed > 3 * expected_runtime_seconds:
# Force kill after 3x expected runtime
os.killpg(os.getpgid(proc.pid), signal.SIGKILL)
return True
elif elapsed > 2 * expected_runtime_seconds:
# Graceful termination after 2x expected runtime
os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
return True
return False
# If we're in an interactive TTY, set raw mode and forward input. # If we're in an interactive TTY, set raw mode and forward input.
if stdin_fd is not None and sys.stdin.isatty(): if stdin_fd is not None and sys.stdin.isatty():
@ -105,7 +128,10 @@ def run_interactive_command(cmd: List[str]) -> Tuple[bytes, int]:
tty.setraw(stdin_fd) tty.setraw(stdin_fd)
try: try:
while True: while True:
rlist, _, _ = select.select([master_fd, stdin_fd], [], []) if check_timeout():
was_terminated = True
break
rlist, _, _ = select.select([master_fd, stdin_fd], [], [], 1.0) # 1 second timeout for select
if master_fd in rlist: if master_fd in rlist:
try: try:
data = os.read(master_fd, 1024) data = os.read(master_fd, 1024)
@ -137,7 +163,13 @@ def run_interactive_command(cmd: List[str]) -> Tuple[bytes, int]:
# Non-interactive mode (e.g., during unit tests). # Non-interactive mode (e.g., during unit tests).
try: try:
while True: while True:
if check_timeout():
was_terminated = True
break
try: try:
rlist, _, _ = select.select([master_fd], [], [], 1.0) # 1 second timeout for select
if not rlist:
continue
data = os.read(master_fd, 1024) data = os.read(master_fd, 1024)
except OSError as e: except OSError as e:
if e.errno == errno.EIO: if e.errno == errno.EIO:
@ -165,6 +197,11 @@ def run_interactive_command(cmd: List[str]) -> Tuple[bytes, int]:
trimmed_lines = [line for line in all_lines if line.strip()] trimmed_lines = [line for line in all_lines if line.strip()]
final_output = "\n".join(trimmed_lines) final_output = "\n".join(trimmed_lines)
# Add timeout message if process was terminated
if was_terminated:
timeout_msg = f"\n[Process exceeded timeout ({expected_runtime_seconds} seconds expected)]"
final_output += timeout_msg
# Limit output to last 8000 bytes # Limit output to last 8000 bytes
final_output = final_output[-8000:] final_output = final_output[-8000:]