From 804ebd76a500bc7c06bd285dff030cf41f0bff3e Mon Sep 17 00:00:00 2001 From: AI Christianson Date: Tue, 25 Feb 2025 14:15:28 -0500 Subject: [PATCH] fix windows --- ra_aid/proc/interactive.py | 298 ++++++++++++++++++++++++++++--------- 1 file changed, 226 insertions(+), 72 deletions(-) diff --git a/ra_aid/proc/interactive.py b/ra_aid/proc/interactive.py index 2d5beac..af66818 100644 --- a/ra_aid/proc/interactive.py +++ b/ra_aid/proc/interactive.py @@ -12,19 +12,86 @@ where output is a bytes object (UTF-8 encoded). import errno import io import os -import select import shutil import signal import subprocess import sys import time -from typing import List, Tuple +from typing import List, Optional, Tuple, Union import pyte from pyte.screens import HistoryScreen -import termios -import tty +# Platform-specific imports +if sys.platform == "win32": + import msvcrt + import threading +else: + import select + import termios + import tty + + +def create_process( + cmd: List[str], env: Optional[dict] = None, cols: Optional[int] = None, rows: Optional[int] = None +) -> Tuple[subprocess.Popen, Optional[int]]: + """ + Create a subprocess with appropriate settings for the current platform. + + Args: + cmd: Command to execute as a list of strings + env: Environment variables dictionary, defaults to os.environ.copy() + cols: Number of columns for the terminal, defaults to current terminal width + rows: Number of rows for the terminal, defaults to current terminal height + + Returns: + On Unix: (process, master_fd) where master_fd is the file descriptor for the pty master + On Windows: (process, None) as Windows doesn't use ptys + """ + # Set default values if not provided + if env is None: + env = os.environ.copy() + if cols is None or rows is None: + default_cols, default_rows = get_terminal_size() + if cols is None: + cols = default_cols + if rows is None: + rows = default_rows + if sys.platform == "win32": + # Windows-specific process creation + startupinfo = subprocess.STARTUPINFO() + startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW + + proc = subprocess.Popen( + cmd, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + bufsize=0, + env=env, + startupinfo=startupinfo, + universal_newlines=False, + ) + return proc, None + else: + # Unix-specific process creation with pty + master_fd, slave_fd = os.openpty() + # Set master_fd to non-blocking to avoid indefinite blocking + os.set_blocking(master_fd, False) + + proc = subprocess.Popen( + cmd, + stdin=slave_fd, + stdout=slave_fd, + stderr=slave_fd, + bufsize=0, + close_fds=True, + env=env, + preexec_fn=os.setsid, # Create new process group for proper signal handling + ) + os.close(slave_fd) # Close slave end in the parent process + + return proc, master_fd def get_terminal_size() -> Tuple[int, int]: @@ -121,16 +188,6 @@ def run_interactive_command( screen = HistoryScreen(cols, rows, history=2000, ratio=0.5) stream = pyte.Stream(screen) - # Open a new pseudo-tty. - master_fd, slave_fd = os.openpty() - # Set master_fd to non-blocking to avoid indefinite blocking. - os.set_blocking(master_fd, False) - - try: - stdin_fd = sys.stdin.fileno() - except (AttributeError, io.UnsupportedOperation): - stdin_fd = None - # Set up environment variables for the subprocess using detected terminal size. env = os.environ.copy() env.update( @@ -150,17 +207,8 @@ def run_interactive_command( } ) - proc = subprocess.Popen( - cmd, - stdin=slave_fd, - stdout=slave_fd, - stderr=slave_fd, - bufsize=0, - close_fds=True, - env=env, - preexec_fn=os.setsid, # Create new process group for proper signal handling. - ) - os.close(slave_fd) # Close slave end in the parent process. + # Create process based on platform + proc, master_fd = create_process(cmd, env, cols, rows) captured_data = [] start_time = time.time() @@ -169,25 +217,165 @@ def run_interactive_command( def check_timeout(): elapsed = time.time() - start_time if elapsed > 3 * expected_runtime_seconds: - os.killpg(os.getpgid(proc.pid), signal.SIGKILL) + if sys.platform == "win32": + proc.kill() + else: + os.killpg(os.getpgid(proc.pid), signal.SIGKILL) return True elif elapsed > 2 * expected_runtime_seconds: - os.killpg(os.getpgid(proc.pid), signal.SIGTERM) + if sys.platform == "win32": + proc.terminate() + else: + os.killpg(os.getpgid(proc.pid), signal.SIGTERM) return True return False - # Interactive mode: forward input if running in a TTY. - if stdin_fd is not None and sys.stdin.isatty(): - old_settings = termios.tcgetattr(stdin_fd) - tty.setraw(stdin_fd) + if sys.platform == "win32": + # Windows implementation using threads for I/O + running = True + stdin_thread = None + + def read_stdout(): + nonlocal running + while running and proc.poll() is None: + try: + data = proc.stdout.read(1024) + if not data: + break + captured_data.append(data) + decoded = data.decode("utf-8", errors="ignore") + stream.feed(decoded) + sys.stdout.buffer.write(data) + sys.stdout.buffer.flush() + except (OSError, IOError): + break + except Exception as e: + print(f"Error reading stdout: {e}", file=sys.stderr) + break + + def read_stderr(): + nonlocal running + while running and proc.poll() is None: + try: + data = proc.stderr.read(1024) + if not data: + break + captured_data.append(data) + decoded = data.decode("utf-8", errors="ignore") + stream.feed(decoded) + sys.stderr.buffer.write(data) + sys.stderr.buffer.flush() + except (OSError, IOError): + break + except Exception as e: + print(f"Error reading stderr: {e}", file=sys.stderr) + break + + def handle_input(): + nonlocal running + try: + while running and proc.poll() is None: + if msvcrt.kbhit(): + char = msvcrt.getch() + proc.stdin.write(char) + proc.stdin.flush() + time.sleep(0.01) # Small sleep to prevent CPU hogging + except (OSError, IOError): + pass + except Exception as e: + print(f"Error handling input: {e}", file=sys.stderr) + + # Start I/O threads + stdout_thread = threading.Thread(target=read_stdout) + stderr_thread = threading.Thread(target=read_stderr) + stdout_thread.daemon = True + stderr_thread.daemon = True + stdout_thread.start() + stderr_thread.start() + + # Only start stdin thread if we're in an interactive terminal + if sys.stdin.isatty(): + stdin_thread = threading.Thread(target=handle_input) + stdin_thread.daemon = True + stdin_thread.start() + try: - while True: + # Main thread monitors timeout + while proc.poll() is None: if check_timeout(): was_terminated = True break - # Use a finite timeout to avoid indefinite blocking. - rlist, _, _ = select.select([master_fd, stdin_fd], [], [], 1.0) - if master_fd in rlist: + time.sleep(0.1) + except KeyboardInterrupt: + proc.terminate() + finally: + running = False + # Wait for threads to finish + stdout_thread.join(1.0) + stderr_thread.join(1.0) + if stdin_thread: + stdin_thread.join(1.0) + + # Close pipes + if proc.stdout: + proc.stdout.close() + if proc.stderr: + proc.stderr.close() + if proc.stdin: + proc.stdin.close() + else: + # Unix implementation using select and pty + try: + stdin_fd = sys.stdin.fileno() + except (AttributeError, io.UnsupportedOperation): + stdin_fd = None + + # Interactive mode: forward input if running in a TTY. + if stdin_fd is not None and sys.stdin.isatty(): + old_settings = termios.tcgetattr(stdin_fd) + tty.setraw(stdin_fd) + try: + while True: + if check_timeout(): + was_terminated = True + break + # Use a finite timeout to avoid indefinite blocking. + rlist, _, _ = select.select([master_fd, stdin_fd], [], [], 1.0) + if master_fd in rlist: + try: + data = os.read(master_fd, 1024) + except OSError as e: + if e.errno == errno.EIO: + break + else: + raise + if not data: # EOF detected. + break + captured_data.append(data) + decoded = data.decode("utf-8", errors="ignore") + stream.feed(decoded) + os.write(1, data) + if stdin_fd in rlist: + try: + input_data = os.read(stdin_fd, 1024) + except OSError: + input_data = b"" + if input_data: + os.write(master_fd, input_data) + except KeyboardInterrupt: + proc.terminate() + finally: + termios.tcsetattr(stdin_fd, termios.TCSADRAIN, old_settings) + else: + # Non-interactive mode. + try: + while True: + if check_timeout(): + was_terminated = True + break + rlist, _, _ = select.select([master_fd], [], [], 1.0) + if not rlist: + continue try: data = os.read(master_fd, 1024) except OSError as e: @@ -201,44 +389,10 @@ def run_interactive_command( decoded = data.decode("utf-8", errors="ignore") stream.feed(decoded) os.write(1, data) - if stdin_fd in rlist: - try: - input_data = os.read(stdin_fd, 1024) - except OSError: - input_data = b"" - if input_data: - os.write(master_fd, input_data) - except KeyboardInterrupt: - proc.terminate() - finally: - termios.tcsetattr(stdin_fd, termios.TCSADRAIN, old_settings) - else: - # Non-interactive mode. - try: - while True: - if check_timeout(): - was_terminated = True - break - rlist, _, _ = select.select([master_fd], [], [], 1.0) - if not rlist: - continue - try: - data = os.read(master_fd, 1024) - except OSError as e: - if e.errno == errno.EIO: - break - else: - raise - if not data: # EOF detected. - break - captured_data.append(data) - decoded = data.decode("utf-8", errors="ignore") - stream.feed(decoded) - os.write(1, data) - except KeyboardInterrupt: - proc.terminate() + except KeyboardInterrupt: + proc.terminate() - os.close(master_fd) + os.close(master_fd) # Wait for the process to finish proc.wait()