diff --git a/ra_aid/proc/interactive.py b/ra_aid/proc/interactive.py index 0942166..2cfb8e9 100644 --- a/ra_aid/proc/interactive.py +++ b/ra_aid/proc/interactive.py @@ -19,6 +19,8 @@ import subprocess import select import termios import tty +import time +import signal from typing import List, Tuple import pyte @@ -28,7 +30,7 @@ def render_line(line, columns: int) -> str: """Render a single screen line from the pyte buffer (a mapping of column to Char).""" return "".join(line[x].data for x in range(columns)) -def run_interactive_command(cmd: List[str]) -> Tuple[bytes, int]: +def run_interactive_command(cmd: List[str], expected_runtime_seconds: int = 30) -> Tuple[bytes, int]: """ Runs an interactive command with a pseudo-tty, capturing final scrollback history. @@ -37,6 +39,12 @@ def run_interactive_command(cmd: List[str]) -> Tuple[bytes, int]: - `cmd` is a non-empty list where cmd[0] is the executable. - The executable is on PATH. + Args: + cmd: A list containing the command and its arguments. + expected_runtime_seconds: Expected runtime in seconds, defaults to 30. + If process exceeds 2x this value, it will be terminated gracefully. + If process exceeds 3x this value, it will be killed forcefully. + Returns: A tuple of (captured_output, return_code), where captured_output is a UTF-8 encoded bytes object containing the trimmed non-empty history lines from the terminal session. @@ -93,11 +101,26 @@ def run_interactive_command(cmd: List[str]) -> Tuple[bytes, int]: stderr=slave_fd, bufsize=0, close_fds=True, - env=env + env=env, + preexec_fn=os.setsid # Create new process group for proper signal handling ) os.close(slave_fd) # Close slave end in the parent process. captured_data = [] + start_time = time.time() + was_terminated = False + + def check_timeout(): + elapsed = time.time() - start_time + if elapsed > 3 * expected_runtime_seconds: + # Force kill after 3x expected runtime + os.killpg(os.getpgid(proc.pid), signal.SIGKILL) + return True + elif elapsed > 2 * expected_runtime_seconds: + # Graceful termination after 2x expected runtime + os.killpg(os.getpgid(proc.pid), signal.SIGTERM) + return True + return False # If we're in an interactive TTY, set raw mode and forward input. if stdin_fd is not None and sys.stdin.isatty(): @@ -105,7 +128,10 @@ def run_interactive_command(cmd: List[str]) -> Tuple[bytes, int]: tty.setraw(stdin_fd) try: while True: - rlist, _, _ = select.select([master_fd, stdin_fd], [], []) + if check_timeout(): + was_terminated = True + break + rlist, _, _ = select.select([master_fd, stdin_fd], [], [], 1.0) # 1 second timeout for select if master_fd in rlist: try: data = os.read(master_fd, 1024) @@ -137,7 +163,13 @@ def run_interactive_command(cmd: List[str]) -> Tuple[bytes, int]: # Non-interactive mode (e.g., during unit tests). try: while True: + if check_timeout(): + was_terminated = True + break try: + rlist, _, _ = select.select([master_fd], [], [], 1.0) # 1 second timeout for select + if not rlist: + continue data = os.read(master_fd, 1024) except OSError as e: if e.errno == errno.EIO: @@ -165,6 +197,11 @@ def run_interactive_command(cmd: List[str]) -> Tuple[bytes, int]: trimmed_lines = [line for line in all_lines if line.strip()] final_output = "\n".join(trimmed_lines) + # Add timeout message if process was terminated + if was_terminated: + timeout_msg = f"\n[Process exceeded timeout ({expected_runtime_seconds} seconds expected)]" + final_output += timeout_msg + # Limit output to last 8000 bytes final_output = final_output[-8000:]