From ea44e5dd6a158d6c4264672c089bad9f3b30bf10 Mon Sep 17 00:00:00 2001 From: AI Christianson Date: Tue, 25 Feb 2025 13:14:26 -0500 Subject: [PATCH] linux fixes --- ra_aid/proc/interactive.py | 316 ++++++++++++++++++++----------------- 1 file changed, 175 insertions(+), 141 deletions(-) diff --git a/ra_aid/proc/interactive.py b/ra_aid/proc/interactive.py index 4943b51..a425ed0 100644 --- a/ra_aid/proc/interactive.py +++ b/ra_aid/proc/interactive.py @@ -42,6 +42,7 @@ else: import termios import fcntl import pty + import tty def get_terminal_size(): """Get the current terminal size.""" @@ -101,90 +102,105 @@ def create_process(cmd: List[str]) -> Tuple[subprocess.Popen, Optional[int]]: return proc, master_fd def run_interactive_command( - cmd: List[str], - expected_runtime_seconds: int = 1800, + cmd: List[str], + expected_runtime_seconds: int = 1800, ratio: float = 0.5 ) -> Tuple[bytes, int]: """ Runs an interactive command with a pseudo-tty, capturing final scrollback history. - - Assumptions and constraints: - - Running on a Linux system or Windows. - - `cmd` is a non-empty list where cmd[0] is the executable. - - The executable is on PATH. - + Args: - cmd: A list containing the command and its arguments. - expected_runtime_seconds: Expected runtime in seconds, defaults to 1800. - If process exceeds 2x this value, it will be terminated gracefully. - If process exceeds 3x this value, it will be killed forcefully. - Must be between 1 and 1800 seconds (30 minutes). - ratio: Ratio of history to keep from top vs bottom (default: 0.5) - + cmd: A list containing the command and its arguments. + expected_runtime_seconds: Expected runtime in seconds, defaults to 1800. + ratio: Ratio of history to keep from top vs bottom (default: 0.5) + Returns: - A tuple of (captured_output, return_code), where captured_output is a UTF-8 encoded - bytes object containing the trimmed non-empty history lines from the terminal session. - - Raises: - ValueError: If no command is provided. - FileNotFoundError: If the command is not found in PATH. - ValueError: If expected_runtime_seconds is less than or equal to 0 or greater than 1800. - RuntimeError: If an error occurs during execution. + A tuple of (captured_output, return_code) """ if not cmd: - raise ValueError("No command provided.") - if shutil.which(cmd[0]) is None: - raise FileNotFoundError(f"Command '{cmd[0]}' not found in PATH.") - if expected_runtime_seconds <= 0 or expected_runtime_seconds > 1800: - raise ValueError( - "expected_runtime_seconds must be between 1 and 1800 seconds (30 minutes)" - ) - + raise ValueError("No command provided") + if not 0 < expected_runtime_seconds <= 1800: + raise ValueError("Expected runtime must be between 1 and 1800 seconds") + try: - term_size = get_terminal_size() + term_size = os.get_terminal_size() cols, rows = term_size.columns, term_size.lines except OSError: cols, rows = 80, 24 - - # Set up pyte screen and stream to capture terminal output. + screen = pyte.HistoryScreen(cols, rows, history=2000, ratio=ratio) stream = pyte.Stream(screen) - proc, master_fd = create_process(cmd) + # Set up environment variables for the subprocess + env = os.environ.copy() + env.update({ + "DEBIAN_FRONTEND": "noninteractive", + "GIT_PAGER": "", + "PYTHONUNBUFFERED": "1", + "CI": "true", + "LANG": "C.UTF-8", + "LC_ALL": "C.UTF-8", + "COLUMNS": str(cols), + "LINES": str(rows), + "FORCE_COLOR": "1", + "GIT_TERMINAL_PROMPT": "0", + "PYTHONDONTWRITEBYTECODE": "1", + "NODE_OPTIONS": "--unhandled-rejections=strict", + }) + # Create process with proper PTY handling + if sys.platform == "win32": + proc = subprocess.Popen( + cmd, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env=env, + bufsize=0, + creationflags=subprocess.CREATE_NEW_PROCESS_GROUP + ) + master_fd = None + else: + master_fd, slave_fd = os.openpty() + os.set_blocking(master_fd, False) + + proc = subprocess.Popen( + cmd, + stdin=slave_fd, + stdout=slave_fd, + stderr=slave_fd, + env=env, + bufsize=0, + close_fds=True, + preexec_fn=os.setsid + ) + os.close(slave_fd) + + try: + stdin_fd = sys.stdin.fileno() + except (AttributeError, io.UnsupportedOperation): + stdin_fd = None + captured_data = [] start_time = time.time() was_terminated = False - timeout_type = None - + def check_timeout(): - nonlocal timeout_type elapsed = time.time() - start_time if elapsed > 3 * expected_runtime_seconds: if sys.platform == "win32": - print("\nProcess exceeded hard timeout limit, forcefully terminating...") - proc.terminate() - time.sleep(0.5) - if proc.poll() is None: - print("Process did not respond to termination, killing...") - proc.kill() + proc.kill() else: - print("\nProcess exceeded hard timeout limit, sending SIGKILL...") os.killpg(os.getpgid(proc.pid), signal.SIGKILL) - timeout_type = "hard_timeout" return True elif elapsed > 2 * expected_runtime_seconds: if sys.platform == "win32": - print("\nProcess exceeded soft timeout limit, attempting graceful termination...") proc.terminate() else: - print("\nProcess exceeded soft timeout limit, sending SIGTERM...") os.killpg(os.getpgid(proc.pid), signal.SIGTERM) - timeout_type = "soft_timeout" return True return False - # Interactive mode: forward input if running in a TTY. if sys.platform == "win32": # Windows handling try: @@ -192,112 +208,130 @@ def run_interactive_command( if check_timeout(): was_terminated = True break - + + if proc.poll() is not None: + break + try: - # Check stdout with proper error handling - stdout_data = proc.stdout.read1(1024) - if stdout_data: - captured_data.append(stdout_data) - try: - stream.feed(stdout_data.decode(errors='ignore')) - except Exception as e: - print(f"Warning: Error processing stdout: {e}") - - # Check stderr with proper error handling - stderr_data = proc.stderr.read1(1024) - if stderr_data: - captured_data.append(stderr_data) - try: - stream.feed(stderr_data.decode(errors='ignore')) - except Exception as e: - print(f"Warning: Error processing stderr: {e}") - - # Check for input with proper error handling + output = proc.stdout.read1(1024) + if output: + captured_data.append(output) + stream.feed(output.decode('utf-8', errors='replace')) + os.write(1, output) # Write to stdout + if msvcrt.kbhit(): - try: - char = msvcrt.getch() - proc.stdin.write(char) - proc.stdin.flush() - except (IOError, OSError) as e: - print(f"Warning: Error handling keyboard input: {e}") - break - + char = msvcrt.getch() + proc.stdin.write(char) + proc.stdin.flush() except (IOError, OSError) as e: - if isinstance(e, OSError) and e.winerror == 6: # Invalid handle - break - print(f"Warning: I/O error during process communication: {e}") break - - except Exception as e: - print(f"Error in Windows process handling: {e}") - proc.terminate() - else: - # Unix handling - import tty - try: - old_settings = termios.tcgetattr(sys.stdin) - tty.setraw(sys.stdin) - while True: - if check_timeout(): - was_terminated = True - break - rlist, _, _ = select.select([master_fd, sys.stdin], [], [], 0.1) - - for fd in rlist: - try: - if fd == master_fd: - data = os.read(master_fd, 1024) - if not data: - break - captured_data.append(data) - stream.feed(data.decode(errors='ignore')) - else: - data = os.read(fd, 1024) - os.write(master_fd, data) - except (IOError, OSError): - break - + except KeyboardInterrupt: proc.terminate() - finally: - termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings) - - if not sys.platform == "win32" and master_fd is not None: + + else: + # Unix handling with proper TTY passthrough + if stdin_fd is not None and sys.stdin.isatty(): + old_settings = termios.tcgetattr(stdin_fd) + tty.setraw(stdin_fd) + try: + while True: + if check_timeout(): + was_terminated = True + break + + rlist, _, _ = select.select([master_fd, stdin_fd], [], [], 0.1) + + if master_fd in rlist: + try: + data = os.read(master_fd, 1024) + except OSError as e: + if e.errno == errno.EIO: + break + raise + + if not data: + break + + captured_data.append(data) + stream.feed(data.decode('utf-8', errors='replace')) + os.write(1, data) # Write to stdout + + if stdin_fd in rlist: + try: + input_data = os.read(stdin_fd, 1024) + except OSError: + input_data = b"" + if input_data: + os.write(master_fd, input_data) + + except KeyboardInterrupt: + proc.terminate() + finally: + termios.tcsetattr(stdin_fd, termios.TCSADRAIN, old_settings) + else: + # Non-interactive mode + try: + while True: + if check_timeout(): + was_terminated = True + break + + rlist, _, _ = select.select([master_fd], [], [], 0.1) + if not rlist: + continue + + try: + data = os.read(master_fd, 1024) + except OSError as e: + if e.errno == errno.EIO: + break + raise + + if not data: + break + + captured_data.append(data) + stream.feed(data.decode('utf-8', errors='replace')) + os.write(1, data) # Write to stdout + + except KeyboardInterrupt: + proc.terminate() + + # Cleanup + if master_fd is not None: os.close(master_fd) - - proc.wait() - - # Assemble full scrollback: combine history.top, the current display, and history.bottom. + + if proc.poll() is None: + try: + proc.terminate() + proc.wait(timeout=1.0) + except subprocess.TimeoutExpired: + proc.kill() + proc.wait() + + # Get the final screen content def render_line(line, width): return ''.join(char.data for char in line[:width]).rstrip() - + # Combine history and current screen content - final_output = [] - - # Add lines from history - history_lines = [render_line(line, cols) for line in screen.history.top] - final_output.extend(line for line in history_lines if line.strip()) - - # Add current screen content - screen_lines = [render_line(line, cols) for line in screen.display] - final_output.extend(line for line in screen_lines if line.strip()) - - # Add bottom history + top_lines = [render_line(line, cols) for line in screen.history.top] bottom_lines = [render_line(line, cols) for line in screen.history.bottom] - final_output.extend(line for line in bottom_lines if line.strip()) - + display_lines = screen.display + all_lines = top_lines + display_lines + bottom_lines + + # Filter empty lines + trimmed_lines = [line for line in all_lines if line.strip()] + final_output = '\n'.join(trimmed_lines) + # Add timeout message if process was terminated if was_terminated: - if timeout_type == "hard_timeout": - timeout_msg = f"\n[Process forcefully terminated after exceeding {3 * expected_runtime_seconds:.1f} seconds (expected: {expected_runtime_seconds} seconds)]" - else: - timeout_msg = f"\n[Process gracefully terminated after exceeding {2 * expected_runtime_seconds:.1f} seconds (expected: {expected_runtime_seconds} seconds)]" - final_output.append(timeout_msg) - + timeout_msg = f"\n[Process exceeded timeout ({expected_runtime_seconds} seconds expected)]" + final_output += timeout_msg + # Limit output size final_output = final_output[-8000:] - final_output = '\n'.join(final_output) - + return final_output.encode('utf-8'), proc.returncode