diff --git a/ra_aid/proc/interactive.py b/ra_aid/proc/interactive.py index 37dd3b2..b9af2c8 100644 --- a/ra_aid/proc/interactive.py +++ b/ra_aid/proc/interactive.py @@ -75,13 +75,15 @@ def run_interactive_command(cmd: List[str], expected_runtime_seconds: int = 30) # Open a new pseudo-tty. master_fd, slave_fd = os.openpty() + # Set master_fd to non-blocking to avoid indefinite blocking. + os.set_blocking(master_fd, False) try: stdin_fd = sys.stdin.fileno() except (AttributeError, io.UnsupportedOperation): stdin_fd = None - # Set up environment variables for the subprocess using detected terminal size + # Set up environment variables for the subprocess using detected terminal size. env = os.environ.copy() env.update({ 'DEBIAN_FRONTEND': 'noninteractive', @@ -106,7 +108,7 @@ def run_interactive_command(cmd: List[str], expected_runtime_seconds: int = 30) bufsize=0, close_fds=True, env=env, - preexec_fn=os.setsid # Create new process group for proper signal handling + preexec_fn=os.setsid # Create new process group for proper signal handling. ) os.close(slave_fd) # Close slave end in the parent process. @@ -117,16 +119,14 @@ def run_interactive_command(cmd: List[str], expected_runtime_seconds: int = 30) def check_timeout(): elapsed = time.time() - start_time if elapsed > 3 * expected_runtime_seconds: - # Force kill after 3x expected runtime os.killpg(os.getpgid(proc.pid), signal.SIGKILL) return True elif elapsed > 2 * expected_runtime_seconds: - # Graceful termination after 2x expected runtime os.killpg(os.getpgid(proc.pid), signal.SIGTERM) return True return False - # If we're in an interactive TTY, set raw mode and forward input. + # Interactive mode: forward input if running in a TTY. if stdin_fd is not None and sys.stdin.isatty(): old_settings = termios.tcgetattr(stdin_fd) tty.setraw(stdin_fd) @@ -135,7 +135,8 @@ def run_interactive_command(cmd: List[str], expected_runtime_seconds: int = 30) if check_timeout(): was_terminated = True break - rlist, _, _ = select.select([master_fd, stdin_fd], [], [], 1.0) # 1 second timeout for select + # Use a finite timeout to avoid indefinite blocking. + rlist, _, _ = select.select([master_fd, stdin_fd], [], [], 1.0) if master_fd in rlist: try: data = os.read(master_fd, 1024) @@ -144,12 +145,11 @@ def run_interactive_command(cmd: List[str], expected_runtime_seconds: int = 30) break else: raise - if not data: + if not data: # EOF detected. break captured_data.append(data) - # Update pyte's screen state. - stream.feed(data.decode("utf-8", errors="ignore")) - # Write to stdout for live output. + decoded = data.decode("utf-8", errors="ignore") + stream.feed(decoded) os.write(1, data) if stdin_fd in rlist: try: @@ -157,33 +157,33 @@ def run_interactive_command(cmd: List[str], expected_runtime_seconds: int = 30) except OSError: input_data = b"" if input_data: - # Forward raw keystrokes directly to the subprocess. os.write(master_fd, input_data) except KeyboardInterrupt: proc.terminate() finally: termios.tcsetattr(stdin_fd, termios.TCSADRAIN, old_settings) else: - # Non-interactive mode (e.g., during unit tests). + # Non-interactive mode. try: while True: if check_timeout(): was_terminated = True break + rlist, _, _ = select.select([master_fd], [], [], 1.0) + if not rlist: + continue try: - rlist, _, _ = select.select([master_fd], [], [], 1.0) # 1 second timeout for select - if not rlist: - continue data = os.read(master_fd, 1024) except OSError as e: if e.errno == errno.EIO: break else: raise - if not data: + if not data: # EOF detected. break captured_data.append(data) - stream.feed(data.decode("utf-8", errors="ignore")) + decoded = data.decode("utf-8", errors="ignore") + stream.feed(decoded) os.write(1, data) except KeyboardInterrupt: proc.terminate() @@ -201,12 +201,12 @@ def run_interactive_command(cmd: List[str], expected_runtime_seconds: int = 30) trimmed_lines = [line for line in all_lines if line.strip()] final_output = "\n".join(trimmed_lines) - # Add timeout message if process was terminated + # Add timeout message if process was terminated due to timeout. if was_terminated: timeout_msg = f"\n[Process exceeded timeout ({expected_runtime_seconds} seconds expected)]" final_output += timeout_msg - # Limit output to last 8000 bytes + # Limit output to the last 8000 bytes. final_output = final_output[-8000:] return final_output.encode("utf-8"), proc.returncode