From d163a74c472828e511bf5871339736d0ac1ada8f Mon Sep 17 00:00:00 2001 From: Mark Varkevisser Date: Tue, 25 Feb 2025 10:41:29 -0800 Subject: [PATCH] Add Windows compatibility improvements (#105) * Add Windows compatibility improvements 1. Add error handling for Windows-specific modules 2. Update README with Windows installation instructions 3. Add Windows-specific tests 4. Improve cross-platform support in interactive.py * Fix: Add missing subprocess import in Windows compatibility tests * Improve Windows compatibility: 1. Add detailed error handling for Windows I/O operations 2. Enhance timeout messages with more descriptive information 3. Add comprehensive comments explaining Windows-specific code * Fix WebUI: Improve message display, add syntax highlighting, animations, and fix WebSocket communication --- README.md | 25 ++ ra_aid/dependencies.py | 10 +- ra_aid/proc/interactive.py | 322 +++++++++------ ra_aid/webui/server.py | 383 +++++++++--------- ra_aid/webui/static/index.html | 254 +++++++++--- ra_aid/webui/static/script.js | 326 ++++++++------- .../ra_aid/proc/test_windows_compatibility.py | 66 +++ 7 files changed, 866 insertions(+), 520 deletions(-) create mode 100644 tests/ra_aid/proc/test_windows_compatibility.py diff --git a/README.md b/README.md index 4c23e53..8da3a22 100644 --- a/README.md +++ b/README.md @@ -97,6 +97,31 @@ What sets RA.Aid apart is its ability to handle complex programming tasks that e ## Installation +### Windows Installation +1. Install Python 3.8 or higher from [python.org](https://www.python.org/downloads/) +2. Install required system dependencies: + ```powershell + # Install Chocolatey if not already installed (run in admin PowerShell) + Set-ExecutionPolicy Bypass -Scope Process -Force; [System.Net.ServicePointManager]::SecurityProtocol = [System.Net.ServicePointManager]::SecurityProtocol -bor 3072; iex ((New-Object System.Net.WebClient).DownloadString('https://community.chocolatey.org/install.ps1')) + + # Install ripgrep using Chocolatey + choco install ripgrep + ``` +3. Install RA.Aid: + ```powershell + pip install ra-aid + ``` +4. Install Windows-specific dependencies: + ```powershell + pip install pywin32 + ``` +5. Set up your API keys in a `.env` file: + ```env + ANTHROPIC_API_KEY=your_anthropic_key + OPENAI_API_KEY=your_openai_key + ``` + +### Unix/Linux Installation RA.Aid can be installed directly using pip: ```bash diff --git a/ra_aid/dependencies.py b/ra_aid/dependencies.py index 483b673..367e826 100644 --- a/ra_aid/dependencies.py +++ b/ra_aid/dependencies.py @@ -2,6 +2,7 @@ import os import sys +import subprocess from abc import ABC, abstractmethod from ra_aid import print_error @@ -21,8 +22,13 @@ class RipGrepDependency(Dependency): def check(self): """Check if ripgrep is installed.""" - result = os.system("rg --version > /dev/null 2>&1") - if result != 0: + try: + result = subprocess.run(['rg', '--version'], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL) + if result.returncode != 0: + raise FileNotFoundError() + except (FileNotFoundError, subprocess.SubprocessError): print_error("Required dependency 'ripgrep' is not installed.") print("Please install ripgrep:") print(" - Ubuntu/Debian: sudo apt-get install ripgrep") diff --git a/ra_aid/proc/interactive.py b/ra_aid/proc/interactive.py index ef21b33..4943b51 100644 --- a/ra_aid/proc/interactive.py +++ b/ra_aid/proc/interactive.py @@ -3,51 +3,123 @@ Module for running interactive subprocesses with output capture, with full raw input passthrough for interactive commands. -It uses a pseudo-tty and integrates pyte's HistoryScreen to simulate +It uses a pseudo-tty on Unix systems and direct pipes on Windows to simulate a terminal and capture the final scrollback history (non-blank lines). -The interface remains compatible with external callers expecting a tuple (output, return_code), -where output is a bytes object (UTF-8 encoded). """ import errno import io import os import select -import shutil import signal import subprocess import sys -import termios import time -import tty -from typing import List, Tuple +import shutil +from typing import List, Tuple, Optional import pyte -from pyte.screens import HistoryScreen +# Windows-specific imports +if sys.platform == "win32": + try: + # msvcrt: Provides Windows console I/O functionality + import msvcrt + # win32pipe, win32file: For low-level pipe operations + import win32pipe + import win32file + # win32con: Windows API constants + import win32con + # win32process: Process management on Windows + import win32process + except ImportError as e: + print("Error: Required Windows dependencies not found.") + print("Please install the required packages using:") + print(" pip install pywin32") + sys.exit(1) +else: + # Unix-specific imports for terminal handling + import termios + import fcntl + import pty -def render_line(line, columns: int) -> str: - """Render a single screen line from the pyte buffer (a mapping of column to Char).""" - return "".join(line[x].data for x in range(columns)) +def get_terminal_size(): + """Get the current terminal size.""" + if sys.platform == "win32": + import shutil + size = shutil.get_terminal_size() + return size.columns, size.lines + else: + import struct + try: + with open(sys.stdout.fileno(), 'wb', buffering=0) as fd: + size = struct.unpack('hh', fcntl.ioctl(fd, termios.TIOCGWINSZ, '1234')) + return size[1], size[0] + except (IOError, AttributeError): + return 80, 24 +def create_process(cmd: List[str]) -> Tuple[subprocess.Popen, Optional[int]]: + """Create a subprocess with appropriate handling for the platform. + + On Windows: + - Uses STARTUPINFO to hide the console window + - Creates a new process group for proper signal handling + - Returns direct pipe handles for I/O + + On Unix: + - Creates a pseudo-terminal (PTY) for proper terminal emulation + - Sets up process group for signal handling + - Returns master PTY file descriptor for I/O + """ + if sys.platform == "win32": + # Windows process creation with hidden console + startupinfo = subprocess.STARTUPINFO() + startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW # Hide the console window + + # Create process with proper pipe handling + proc = subprocess.Popen( + cmd, + stdin=subprocess.PIPE, # Allow writing to stdin + stdout=subprocess.PIPE, # Capture stdout + stderr=subprocess.PIPE, # Capture stderr + startupinfo=startupinfo, + # CREATE_NEW_PROCESS_GROUP allows proper Ctrl+C handling + creationflags=subprocess.CREATE_NEW_PROCESS_GROUP + ) + return proc, None # No PTY master_fd needed on Windows + else: + # Unix process creation with PTY + master_fd, slave_fd = pty.openpty() + proc = subprocess.Popen( + cmd, + stdin=slave_fd, + stdout=slave_fd, + stderr=slave_fd, + preexec_fn=os.setsid + ) + os.close(slave_fd) + return proc, master_fd def run_interactive_command( - cmd: List[str], expected_runtime_seconds: int = 30 + cmd: List[str], + expected_runtime_seconds: int = 1800, + ratio: float = 0.5 ) -> Tuple[bytes, int]: """ Runs an interactive command with a pseudo-tty, capturing final scrollback history. Assumptions and constraints: - - Running on a Linux system. + - Running on a Linux system or Windows. - `cmd` is a non-empty list where cmd[0] is the executable. - The executable is on PATH. Args: cmd: A list containing the command and its arguments. - expected_runtime_seconds: Expected runtime in seconds, defaults to 30. + expected_runtime_seconds: Expected runtime in seconds, defaults to 1800. If process exceeds 2x this value, it will be terminated gracefully. If process exceeds 3x this value, it will be killed forcefully. Must be between 1 and 1800 seconds (30 minutes). + ratio: Ratio of history to keep from top vs bottom (default: 0.5) Returns: A tuple of (captured_output, return_code), where captured_output is a UTF-8 encoded @@ -69,154 +141,164 @@ def run_interactive_command( ) try: - term_size = os.get_terminal_size() + term_size = get_terminal_size() cols, rows = term_size.columns, term_size.lines except OSError: cols, rows = 80, 24 # Set up pyte screen and stream to capture terminal output. - screen = HistoryScreen(cols, rows, history=2000, ratio=0.5) + screen = pyte.HistoryScreen(cols, rows, history=2000, ratio=ratio) stream = pyte.Stream(screen) - # Open a new pseudo-tty. - master_fd, slave_fd = os.openpty() - # Set master_fd to non-blocking to avoid indefinite blocking. - os.set_blocking(master_fd, False) - - try: - stdin_fd = sys.stdin.fileno() - except (AttributeError, io.UnsupportedOperation): - stdin_fd = None - - # Set up environment variables for the subprocess using detected terminal size. - env = os.environ.copy() - env.update( - { - "DEBIAN_FRONTEND": "noninteractive", - "GIT_PAGER": "", - "PYTHONUNBUFFERED": "1", - "CI": "true", - "LANG": "C.UTF-8", - "LC_ALL": "C.UTF-8", - "COLUMNS": str(cols), - "LINES": str(rows), - "FORCE_COLOR": "1", - "GIT_TERMINAL_PROMPT": "0", - "PYTHONDONTWRITEBYTECODE": "1", - "NODE_OPTIONS": "--unhandled-rejections=strict", - } - ) - - proc = subprocess.Popen( - cmd, - stdin=slave_fd, - stdout=slave_fd, - stderr=slave_fd, - bufsize=0, - close_fds=True, - env=env, - preexec_fn=os.setsid, # Create new process group for proper signal handling. - ) - os.close(slave_fd) # Close slave end in the parent process. + proc, master_fd = create_process(cmd) captured_data = [] start_time = time.time() was_terminated = False + timeout_type = None def check_timeout(): + nonlocal timeout_type elapsed = time.time() - start_time if elapsed > 3 * expected_runtime_seconds: - os.killpg(os.getpgid(proc.pid), signal.SIGKILL) + if sys.platform == "win32": + print("\nProcess exceeded hard timeout limit, forcefully terminating...") + proc.terminate() + time.sleep(0.5) + if proc.poll() is None: + print("Process did not respond to termination, killing...") + proc.kill() + else: + print("\nProcess exceeded hard timeout limit, sending SIGKILL...") + os.killpg(os.getpgid(proc.pid), signal.SIGKILL) + timeout_type = "hard_timeout" return True elif elapsed > 2 * expected_runtime_seconds: - os.killpg(os.getpgid(proc.pid), signal.SIGTERM) + if sys.platform == "win32": + print("\nProcess exceeded soft timeout limit, attempting graceful termination...") + proc.terminate() + else: + print("\nProcess exceeded soft timeout limit, sending SIGTERM...") + os.killpg(os.getpgid(proc.pid), signal.SIGTERM) + timeout_type = "soft_timeout" return True return False # Interactive mode: forward input if running in a TTY. - if stdin_fd is not None and sys.stdin.isatty(): - old_settings = termios.tcgetattr(stdin_fd) - tty.setraw(stdin_fd) + if sys.platform == "win32": + # Windows handling try: while True: if check_timeout(): was_terminated = True break - # Use a finite timeout to avoid indefinite blocking. - rlist, _, _ = select.select([master_fd, stdin_fd], [], [], 1.0) - if master_fd in rlist: - try: - data = os.read(master_fd, 1024) - except OSError as e: - if e.errno == errno.EIO: + + try: + # Check stdout with proper error handling + stdout_data = proc.stdout.read1(1024) + if stdout_data: + captured_data.append(stdout_data) + try: + stream.feed(stdout_data.decode(errors='ignore')) + except Exception as e: + print(f"Warning: Error processing stdout: {e}") + + # Check stderr with proper error handling + stderr_data = proc.stderr.read1(1024) + if stderr_data: + captured_data.append(stderr_data) + try: + stream.feed(stderr_data.decode(errors='ignore')) + except Exception as e: + print(f"Warning: Error processing stderr: {e}") + + # Check for input with proper error handling + if msvcrt.kbhit(): + try: + char = msvcrt.getch() + proc.stdin.write(char) + proc.stdin.flush() + except (IOError, OSError) as e: + print(f"Warning: Error handling keyboard input: {e}") break - else: - raise - if not data: # EOF detected. + + except (IOError, OSError) as e: + if isinstance(e, OSError) and e.winerror == 6: # Invalid handle break - captured_data.append(data) - decoded = data.decode("utf-8", errors="ignore") - stream.feed(decoded) - os.write(1, data) - if stdin_fd in rlist: + print(f"Warning: I/O error during process communication: {e}") + break + + except Exception as e: + print(f"Error in Windows process handling: {e}") + proc.terminate() + else: + # Unix handling + import tty + try: + old_settings = termios.tcgetattr(sys.stdin) + tty.setraw(sys.stdin) + while True: + if check_timeout(): + was_terminated = True + break + rlist, _, _ = select.select([master_fd, sys.stdin], [], [], 0.1) + + for fd in rlist: try: - input_data = os.read(stdin_fd, 1024) - except OSError: - input_data = b"" - if input_data: - os.write(master_fd, input_data) + if fd == master_fd: + data = os.read(master_fd, 1024) + if not data: + break + captured_data.append(data) + stream.feed(data.decode(errors='ignore')) + else: + data = os.read(fd, 1024) + os.write(master_fd, data) + except (IOError, OSError): + break + except KeyboardInterrupt: proc.terminate() finally: - termios.tcsetattr(stdin_fd, termios.TCSADRAIN, old_settings) - else: - # Non-interactive mode. - try: - while True: - if check_timeout(): - was_terminated = True - break - rlist, _, _ = select.select([master_fd], [], [], 1.0) - if not rlist: - continue - try: - data = os.read(master_fd, 1024) - except OSError as e: - if e.errno == errno.EIO: - break - else: - raise - if not data: # EOF detected. - break - captured_data.append(data) - decoded = data.decode("utf-8", errors="ignore") - stream.feed(decoded) - os.write(1, data) - except KeyboardInterrupt: - proc.terminate() + termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings) + + if not sys.platform == "win32" and master_fd is not None: + os.close(master_fd) - os.close(master_fd) proc.wait() # Assemble full scrollback: combine history.top, the current display, and history.bottom. - top_lines = [render_line(line, cols) for line in screen.history.top] + def render_line(line, width): + return ''.join(char.data for char in line[:width]).rstrip() + + # Combine history and current screen content + final_output = [] + + # Add lines from history + history_lines = [render_line(line, cols) for line in screen.history.top] + final_output.extend(line for line in history_lines if line.strip()) + + # Add current screen content + screen_lines = [render_line(line, cols) for line in screen.display] + final_output.extend(line for line in screen_lines if line.strip()) + + # Add bottom history bottom_lines = [render_line(line, cols) for line in screen.history.bottom] - display_lines = screen.display # List of strings representing the current screen. - all_lines = top_lines + display_lines + bottom_lines + final_output.extend(line for line in bottom_lines if line.strip()) - # Trim out empty lines to get only meaningful "history" lines. - trimmed_lines = [line for line in all_lines if line.strip()] - final_output = "\n".join(trimmed_lines) - - # Add timeout message if process was terminated due to timeout. + # Add timeout message if process was terminated if was_terminated: - timeout_msg = f"\n[Process exceeded timeout ({expected_runtime_seconds} seconds expected)]" - final_output += timeout_msg + if timeout_type == "hard_timeout": + timeout_msg = f"\n[Process forcefully terminated after exceeding {3 * expected_runtime_seconds:.1f} seconds (expected: {expected_runtime_seconds} seconds)]" + else: + timeout_msg = f"\n[Process gracefully terminated after exceeding {2 * expected_runtime_seconds:.1f} seconds (expected: {expected_runtime_seconds} seconds)]" + final_output.append(timeout_msg) - # Limit output to the last 8000 bytes. + # Limit output size final_output = final_output[-8000:] + final_output = '\n'.join(final_output) - return final_output.encode("utf-8"), proc.returncode + return final_output.encode('utf-8'), proc.returncode if __name__ == "__main__": diff --git a/ra_aid/webui/server.py b/ra_aid/webui/server.py index fd2bd91..50d6555 100644 --- a/ra_aid/webui/server.py +++ b/ra_aid/webui/server.py @@ -1,29 +1,37 @@ -"""Web interface server implementation for RA.Aid.""" - -import asyncio -import logging -import shutil +#!/usr/bin/env python3 import sys +import os from pathlib import Path -from typing import Any, Dict, List +import asyncio +from typing import List +import json +import threading +import queue +import traceback +import shutil +import logging -import uvicorn -from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect -from fastapi.middleware.cors import CORSMiddleware +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + handlers=[ + logging.StreamHandler(sys.__stderr__) # Use the real stderr + ] +) +logger = logging.getLogger(__name__) + +# Add project root to Python path +project_root = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) +if project_root not in sys.path: + sys.path.insert(0, project_root) + +from fastapi import FastAPI, WebSocket, Request, WebSocketDisconnect from fastapi.responses import HTMLResponse from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates - -# Configure logging -logging.basicConfig(level=logging.DEBUG) # Set to DEBUG for more info -logger = logging.getLogger(__name__) - -# Verify ra-aid is available -if not shutil.which("ra-aid"): - logger.error( - "ra-aid command not found. Please ensure it's installed and in your PATH" - ) - sys.exit(1) +from fastapi.middleware.cors import CORSMiddleware +import uvicorn app = FastAPI() @@ -36,67 +44,105 @@ app.add_middleware( allow_headers=["*"], ) -# Get the directory containing static files -STATIC_DIR = Path(__file__).parent / "static" -if not STATIC_DIR.exists(): - logger.error(f"Static directory not found at {STATIC_DIR}") - sys.exit(1) +# Setup templates and static files directories +CURRENT_DIR = Path(__file__).parent +templates = Jinja2Templates(directory=CURRENT_DIR) -logger.info(f"Using static directory: {STATIC_DIR}") +# Mount static files for js and other assets +static_dir = CURRENT_DIR / "static" +app.mount("/static", StaticFiles(directory=str(static_dir)), name="static") -# Setup templates -templates = Jinja2Templates(directory=str(STATIC_DIR)) +# Store active WebSocket connections +active_connections: List[WebSocket] = [] +def run_ra_aid(message_content, output_queue): + """Run ra-aid in a separate thread""" + try: + import ra_aid.__main__ + logger.info("Successfully imported ra_aid.__main__") -class ConnectionManager: - def __init__(self): - self.active_connections: List[WebSocket] = [] + # Override sys.argv + sys.argv = [sys.argv[0], "-m", message_content, "--cowboy-mode"] + logger.info(f"Set sys.argv to: {sys.argv}") + + # Create custom output handler + class QueueHandler: + def __init__(self, queue): + self.queue = queue + self.buffer = [] + self.box_start = False + self._real_stderr = sys.__stderr__ + + def write(self, text): + # Always log raw output for debugging + logger.debug(f"Raw output: {repr(text)}") + + # Check if this is a box drawing character + if any(c in text for c in '╭╮╰╯│─'): + self.box_start = True + self.buffer.append(text) + elif self.box_start and text.strip(): + self.buffer.append(text) + if '╯' in text: # End of box + full_text = ''.join(self.buffer) + # Extract content from inside the box + lines = full_text.split('\n') + content_lines = [] + for line in lines: + # Remove box characters and leading/trailing spaces + clean_line = line.strip('╭╮╰╯│─ ') + if clean_line: + content_lines.append(clean_line) + if content_lines: + self.queue.put('\n'.join(content_lines)) + self.buffer = [] + self.box_start = False + elif not self.box_start and text.strip(): + self.queue.put(text.strip()) + + def flush(self): + if self.buffer: + full_text = ''.join(self.buffer) + # Extract content from partial box + lines = full_text.split('\n') + content_lines = [] + for line in lines: + # Remove box characters and leading/trailing spaces + clean_line = line.strip('╭╮╰╯│─ ') + if clean_line: + content_lines.append(clean_line) + if content_lines: + self.queue.put('\n'.join(content_lines)) + self.buffer = [] + self.box_start = False + + # Replace stdout and stderr + old_stdout = sys.stdout + old_stderr = sys.stderr + queue_handler = QueueHandler(output_queue) + sys.stdout = queue_handler + sys.stderr = queue_handler - async def connect(self, websocket: WebSocket) -> bool: try: - logger.debug("Accepting WebSocket connection...") - await websocket.accept() - logger.debug("WebSocket connection accepted") - self.active_connections.append(websocket) - return True + logger.info("Starting ra_aid.main()") + ra_aid.__main__.main() + logger.info("Finished ra_aid.main()") + except SystemExit: + logger.info("Caught SystemExit - this is normal") except Exception as e: - logger.error(f"Error accepting WebSocket connection: {e}") - return False - - def disconnect(self, websocket: WebSocket): - if websocket in self.active_connections: - self.active_connections.remove(websocket) - - async def send_message(self, websocket: WebSocket, message: Dict[str, Any]): - try: - await websocket.send_json(message) - except Exception as e: - logger.error(f"Error sending message: {e}") - await self.handle_error(websocket, str(e)) - - async def handle_error(self, websocket: WebSocket, error_message: str): - try: - await websocket.send_json( - { - "type": "chunk", - "chunk": { - "tools": { - "messages": [ - { - "content": f"Error: {error_message}", - "status": "error", - } - ] - } - }, - } - ) - except Exception as e: - logger.error(f"Error sending error message: {e}") - - -manager = ConnectionManager() + logger.error(f"Error in main: {str(e)}") + traceback.print_exc(file=sys.__stderr__) + finally: + # Flush any remaining output + queue_handler.flush() + # Restore stdout and stderr + sys.stdout = old_stdout + sys.stderr = old_stderr + except Exception as e: + logger.error(f"Error in thread: {str(e)}") + traceback.print_exc(file=sys.__stderr__) + output_queue.put(f"Error: {str(e)}") @app.get("/", response_class=HTMLResponse) async def get_root(request: Request): @@ -105,135 +151,108 @@ async def get_root(request: Request): "index.html", {"request": request, "server_port": request.url.port or 8080} ) - -# Mount static files for js and other assets -app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static") - - @app.websocket("/ws") async def websocket_endpoint(websocket: WebSocket): - client_id = id(websocket) - logger.info(f"New WebSocket connection attempt from client {client_id}") - - if not await manager.connect(websocket): - logger.error(f"Failed to accept WebSocket connection for client {client_id}") - return - - logger.info(f"WebSocket connection accepted for client {client_id}") + await websocket.accept() + logger.info("WebSocket connection established") + active_connections.append(websocket) try: - # Send initial connection success message - await manager.send_message( - websocket, - { - "type": "chunk", - "chunk": { - "agent": { - "messages": [ - {"content": "Connected to RA.Aid server", "status": "info"} - ] - } - }, - }, - ) - while True: - try: - message = await websocket.receive_json() - logger.debug(f"Received message from client {client_id}: {message}") + message = await websocket.receive_json() + logger.info(f"Received message: {message}") - if message["type"] == "request": - await manager.send_message(websocket, {"type": "stream_start"}) + if message["type"] == "request": + content = message["content"] + logger.info(f"Processing request: {content}") - try: - # Run ra-aid with the request - cmd = ["ra-aid", "-m", message["content"], "--cowboy-mode"] - logger.info(f"Executing command: {' '.join(cmd)}") + # Create queue for output + output_queue = queue.Queue() - process = await asyncio.create_subprocess_exec( - *cmd, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - ) - logger.info(f"Process started with PID: {process.pid}") + # Create and start thread + thread = threading.Thread(target=run_ra_aid, args=(content, output_queue)) + thread.start() - async def read_stream(stream, is_error=False): - while True: - line = await stream.readline() - if not line: - break + try: + # Send stream start + await websocket.send_json({"type": "stream_start"}) - try: - decoded_line = line.decode().strip() - if decoded_line: - await manager.send_message( - websocket, - { - "type": "chunk", - "chunk": { - "tools" if is_error else "agent": { - "messages": [ - { - "content": decoded_line, - "status": "error" - if is_error - else "info", - } - ] - } - }, - }, - ) - except Exception as e: - logger.error(f"Error processing output: {e}") + while thread.is_alive() or not output_queue.empty(): + try: + # Get output with timeout to allow checking thread status + line = output_queue.get(timeout=0.1) + if line and line.strip(): # Only send non-empty messages + logger.debug(f"WebSocket sending: {repr(line)}") + await websocket.send_json({ + "type": "chunk", + "chunk": { + "agent": { + "messages": [{ + "content": line.strip(), + "status": "info" + }] + } + } + }) + except queue.Empty: + await asyncio.sleep(0.1) + except Exception as e: + logger.error(f"WebSocket error: {e}") + traceback.print_exc(file=sys.__stderr__) - # Create tasks for reading stdout and stderr - stdout_task = asyncio.create_task(read_stream(process.stdout)) - stderr_task = asyncio.create_task( - read_stream(process.stderr, True) - ) + # Wait for thread to finish + thread.join() + logger.info("Thread finished") - # Wait for both streams to complete - await asyncio.gather(stdout_task, stderr_task) + # Send stream end + await websocket.send_json({"type": "stream_end"}) + logger.info("Sent stream_end message") - # Wait for process to complete - return_code = await process.wait() + except Exception as e: + error_msg = f"Error running ra-aid: {str(e)}" + logger.error(error_msg) + await websocket.send_json({ + "type": "error", + "message": error_msg + }) - if return_code != 0: - await manager.handle_error( - websocket, f"Process exited with code {return_code}" - ) - - await manager.send_message( - websocket, - {"type": "stream_end", "request": message["content"]}, - ) - - except Exception as e: - logger.error(f"Error executing ra-aid: {e}") - await manager.handle_error(websocket, str(e)) - - except Exception as e: - logger.error(f"Error processing message: {e}") - await manager.handle_error(websocket, str(e)) + logger.info("Waiting for message...") except WebSocketDisconnect: - logger.info(f"WebSocket client {client_id} disconnected") + logger.info("WebSocket client disconnected") + active_connections.remove(websocket) except Exception as e: - logger.error(f"WebSocket error for client {client_id}: {e}") + logger.error(f"WebSocket error: {e}") + traceback.print_exc() finally: - manager.disconnect(websocket) - logger.info(f"WebSocket connection cleaned up for client {client_id}") + if websocket in active_connections: + active_connections.remove(websocket) + logger.info("WebSocket connection closed") + + +@app.get("/config") +async def get_config(request: Request): + """Return server configuration including host and port.""" + return {"host": request.client.host, "port": request.scope.get("server")[1]} def run_server(host: str = "0.0.0.0", port: int = 8080): """Run the FastAPI server.""" - logger.info(f"Starting server on {host}:{port}") - uvicorn.run( - app, - host=host, - port=port, - log_level="debug", - ws_max_size=16777216, # 16MB - timeout_keep_alive=0, # Disable keep-alive timeout + uvicorn.run(app, host=host, port=port) + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(description="RA.Aid Web Interface Server") + parser.add_argument( + "--port", type=int, default=8080, help="Port to listen on (default: 8080)" ) + parser.add_argument( + "--host", + type=str, + default="0.0.0.0", + help="Host to listen on (default: 0.0.0.0)", + ) + + args = parser.parse_args() + run_server(host=args.host, port=args.port) diff --git a/ra_aid/webui/static/index.html b/ra_aid/webui/static/index.html index 9f08e45..c534d91 100644 --- a/ra_aid/webui/static/index.html +++ b/ra_aid/webui/static/index.html @@ -1,90 +1,214 @@ - + - - RA.Aid Web Interface + RA.Aid Web UI + + + + - -
- -
-
-

History

-
-
-
+ + +
+ +
- -
- -
-
- -
- - -
-
- - -
-
+ +
+
+ + +
- + \ No newline at end of file diff --git a/ra_aid/webui/static/script.js b/ra_aid/webui/static/script.js index 6c6b5af..3057fca 100644 --- a/ra_aid/webui/static/script.js +++ b/ra_aid/webui/static/script.js @@ -1,206 +1,230 @@ -class RAWebUI { +class WebSocketHandler { constructor() { - this.messageHistory = []; - this.connectionAttempts = 0; - this.maxReconnectAttempts = 5; - this.setupElements(); - this.setupEventListeners(); - this.connectWebSocket(); + // Wait for DOM to be ready + if (document.readyState === 'loading') { + document.addEventListener('DOMContentLoaded', () => this.initialize()); + } else { + this.initialize(); + } } - setupElements() { - this.userInput = document.getElementById('user-input'); + initialize() { + // Store DOM elements as instance variables + this.messageInput = document.getElementById('user-input'); this.sendButton = document.getElementById('send-button'); - this.chatMessages = document.getElementById('chat-messages'); + this.clearButton = document.getElementById('clear-button'); this.streamOutput = document.getElementById('stream-output'); - this.historyList = document.getElementById('history-list'); - - // Disable send button initially - this.sendButton.disabled = true; - } - setupEventListeners() { + // Validate required elements exist + if (!this.messageInput || !this.sendButton || !this.streamOutput) { + console.error('Required elements not found:', { + messageInput: !!this.messageInput, + sendButton: !!this.sendButton, + streamOutput: !!this.streamOutput + }); + return; + } + + // Remove hidden class if present + this.streamOutput.classList.remove('hidden'); + + // Initialize WebSocket + this.connectWebSocket(); + + // Add event listeners this.sendButton.addEventListener('click', () => this.sendMessage()); - this.userInput.addEventListener('keypress', (e) => { + this.clearButton?.addEventListener('click', () => this.clearConversation()); + this.messageInput.addEventListener('keypress', (e) => { if (e.key === 'Enter' && !e.shiftKey) { e.preventDefault(); this.sendMessage(); } }); + + console.log('WebSocketHandler initialized with elements:', { + messageInput: this.messageInput, + sendButton: this.sendButton, + streamOutput: this.streamOutput + }); } - async connectWebSocket() { - // Don't try to reconnect if we've exceeded the maximum attempts - if (this.connectionAttempts >= this.maxReconnectAttempts) { - this.appendMessage( - 'Maximum reconnection attempts reached. Please refresh the page.', - 'error' - ); - return; - } - + connectWebSocket() { try { - // Get the server port from the meta tag - const serverPort = document.querySelector('meta[name="server-port"]')?.content || '8080'; - - // Construct WebSocket URL - const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:'; - const wsUrl = `${protocol}//${window.location.hostname}:${serverPort}/ws`; + const wsUrl = `ws://${window.location.host}/ws`; console.log('Attempting to connect to WebSocket URL:', wsUrl); - - // Close existing connection if any - if (this.ws) { - this.ws.close(); - } - // Create new WebSocket connection - console.log('Creating new WebSocket connection...'); this.ws = new WebSocket(wsUrl); - this.connectionAttempts++; - - // Setup WebSocket event handlers + this.ws.onopen = () => { console.log('WebSocket connection established successfully'); - this.connectionAttempts = 0; // Reset counter on successful connection this.sendButton.disabled = false; }; - this.ws.onclose = (event) => { - console.log('WebSocket closed:', event.code, event.reason); + this.ws.onclose = () => { + console.log('WebSocket connection closed'); this.sendButton.disabled = true; + // Try to reconnect after a delay + setTimeout(() => this.connectWebSocket(), 2000); + }; - // Only attempt reconnect if not a normal closure and within retry limits - if (event.code !== 1000 && this.connectionAttempts < this.maxReconnectAttempts) { - const delay = Math.min(1000 * Math.pow(2, this.connectionAttempts), 10000); - this.appendMessage( - `Connection lost. Reconnecting in ${delay/1000} seconds...`, - 'error' - ); - setTimeout(() => this.connectWebSocket(), delay); + this.ws.onmessage = (event) => { + try { + console.log('Raw WebSocket message:', event.data); + const message = JSON.parse(event.data); + console.log('Parsed WebSocket message:', message); + this.handleMessage(message); + } catch (error) { + console.error('Error handling message:', error); + this.appendOutput({ + content: `Error: ${error.message}`, + status: 'error' + }); } }; this.ws.onerror = (error) => { console.error('WebSocket error:', error); + this.sendButton.disabled = true; + this.appendOutput({ + content: 'Connection error. Retrying...', + status: 'error' + }); }; - - this.ws.onmessage = (event) => { - try { - const data = JSON.parse(event.data); - this.handleServerMessage(data); - } catch (error) { - console.error('Error parsing message:', error); - this.appendMessage('Error processing server message', 'error'); - } - }; - } catch (error) { - console.error('Failed to connect to WebSocket:', error); - this.appendMessage( - `Connection error: ${error.message}. Retrying...`, - 'error' - ); - - // Attempt to reconnect with exponential backoff - const delay = Math.min(1000 * Math.pow(2, this.connectionAttempts), 10000); - setTimeout(() => this.connectWebSocket(), delay); + console.error('Error connecting to WebSocket:', error); + this.appendOutput({ + content: `Connection error: ${error.message}`, + status: 'error' + }); } } - handleServerMessage(data) { - if (data.type === 'stream_start') { - this.streamOutput.textContent = ''; - this.streamOutput.style.display = 'block'; - } else if (data.type === 'stream_end') { - this.streamOutput.style.display = 'none'; - this.addToHistory(data.request); - this.sendButton.disabled = false; - } else if (data.type === 'chunk') { - this.handleChunk(data.chunk); + handleMessage(message) { + switch (message.type) { + case 'stream_start': + this.handleStreamStart(); + break; + case 'chunk': + this.handleChunk(message.chunk); + break; + case 'stream_end': + this.handleStreamEnd(); + break; + default: + console.warn('Unknown message type:', message.type); } } + handleStreamStart() { + console.log('Stream starting'); + this.clearStreamOutput(); + this.appendOutput({ + content: 'Starting new conversation...', + status: 'info' + }); + } + + handleStreamEnd() { + console.log('Stream ending'); + this.appendOutput({ + content: 'Conversation complete.', + status: 'success' + }); + this.messageInput.disabled = false; + this.sendButton.disabled = false; + } + handleChunk(chunk) { + console.log(' Processing chunk:', chunk); if (chunk.agent && chunk.agent.messages) { - chunk.agent.messages.forEach(msg => { - if (msg.content) { - if (Array.isArray(msg.content)) { - msg.content.forEach(content => { - if (content.type === 'text' && content.text.trim()) { - this.appendMessage(content.text.trim(), 'system'); - } - }); - } else if (msg.content.trim()) { - this.appendMessage(msg.content.trim(), 'system'); - } - } - }); - } else if (chunk.tools && chunk.tools.messages) { - chunk.tools.messages.forEach(msg => { - if (msg.status === 'error' && msg.content) { - this.appendMessage(msg.content.trim(), 'error'); - } + chunk.agent.messages.forEach(message => { + console.log(' Processing agent message:', message); + console.log(' Adding agent message:', message.content); + this.appendOutput(message); }); } } - appendMessage(content, type) { - const messageDiv = document.createElement('div'); - messageDiv.className = `message ${type}-message`; - messageDiv.textContent = content; - this.chatMessages.appendChild(messageDiv); - this.chatMessages.scrollTop = this.chatMessages.scrollHeight; + clearStreamOutput() { + console.log('Clearing stream output'); + while (this.streamOutput.firstChild) { + this.streamOutput.removeChild(this.streamOutput.firstChild); + } + console.log('Stream output cleared'); } - addToHistory(request) { - const historyItem = document.createElement('div'); - historyItem.className = 'history-item'; - historyItem.textContent = request.slice(0, 50) + (request.length > 50 ? '...' : ''); - historyItem.title = request; - historyItem.addEventListener('click', () => { - this.userInput.value = request; - this.userInput.focus(); + clearConversation() { + this.clearStreamOutput(); + this.appendOutput({ + content: 'Conversation cleared.', + status: 'info' + }); + } + + appendOutput(message) { + console.log(' Appending output:', message); + const messageDiv = document.createElement('div'); + messageDiv.className = `message ${message.status || 'info'}`; + + const contentSpan = document.createElement('span'); + + // Convert ANSI escape codes to HTML + let content = message.content; + content = this.convertAnsiToHtml(content); + + // Check for code blocks and apply syntax highlighting + if (content.includes('```')) { + content = this.highlightCodeBlocks(content); + } + + contentSpan.innerHTML = content; + messageDiv.appendChild(contentSpan); + + this.streamOutput.appendChild(messageDiv); + messageDiv.scrollIntoView({ behavior: 'smooth', block: 'end' }); + } + + convertAnsiToHtml(text) { + // ANSI color codes to CSS classes + const ansiToClass = { + '[94m': '', + '[1;32m': '', + '[0m': '' + }; + + // Replace ANSI codes with HTML + let html = text; + for (const [ansi, htmlClass] of Object.entries(ansiToClass)) { + html = html.replaceAll('\u001b' + ansi, htmlClass); + } + return html; + } + + highlightCodeBlocks(content) { + const codeBlockRegex = /```(\w+)?\n([\s\S]*?)```/g; + return content.replace(codeBlockRegex, (match, lang, code) => { + const language = lang || 'plaintext'; + const highlighted = hljs.highlight(code.trim(), { language }).value; + return `
${highlighted}
`; }); - this.historyList.insertBefore(historyItem, this.historyList.firstChild); - this.messageHistory.push(request); } sendMessage() { - console.log('Send button clicked'); - const message = this.userInput.value.trim(); - console.log('Message content:', message); - - if (!message) { - console.log('Message is empty, not sending'); - return; - } + const message = this.messageInput.value.trim(); + if (!message) return; - if (!this.ws || this.ws.readyState !== WebSocket.OPEN) { - console.error('WebSocket is not connected'); - this.appendMessage('Error: Not connected to server. Please wait...', 'error'); - return; - } - - try { - console.log('Sending message to server'); - this.appendMessage(message, 'user'); - const payload = { type: 'request', content: message }; - console.log('Payload:', payload); - - this.ws.send(JSON.stringify(payload)); - console.log('Message sent successfully'); - - this.userInput.value = ''; - this.sendButton.disabled = true; - } catch (error) { - console.error('Error sending message:', error); - this.appendMessage(`Error sending message: ${error.message}`, 'error'); - this.sendButton.disabled = false; - } + console.log('Sending message:', message); + this.ws.send(JSON.stringify({ + type: "request", + content: message + })); + this.messageInput.value = ''; + this.messageInput.disabled = true; + this.sendButton.disabled = true; } } -// Initialize the UI when the page loads -document.addEventListener('DOMContentLoaded', () => { - window.raWebUI = new RAWebUI(); +// Initialize WebSocket handler when the page loads +window.addEventListener('load', () => { + new WebSocketHandler(); }); \ No newline at end of file diff --git a/tests/ra_aid/proc/test_windows_compatibility.py b/tests/ra_aid/proc/test_windows_compatibility.py new file mode 100644 index 0000000..842ba39 --- /dev/null +++ b/tests/ra_aid/proc/test_windows_compatibility.py @@ -0,0 +1,66 @@ +"""Tests for Windows-specific functionality.""" + +import os +import sys +import subprocess +import pytest +from unittest.mock import patch, MagicMock + +from ra_aid.proc.interactive import get_terminal_size, create_process, run_interactive_command + +@pytest.mark.skipif(sys.platform != "win32", reason="Windows-specific tests") +class TestWindowsCompatibility: + """Test suite for Windows-specific functionality.""" + + def test_get_terminal_size(self): + """Test terminal size detection on Windows.""" + with patch('shutil.get_terminal_size') as mock_get_size: + mock_get_size.return_value = MagicMock(columns=120, lines=30) + cols, rows = get_terminal_size() + assert cols == 120 + assert rows == 30 + mock_get_size.assert_called_once() + + def test_create_process(self): + """Test process creation on Windows.""" + with patch('subprocess.Popen') as mock_popen: + mock_process = MagicMock() + mock_process.returncode = 0 + mock_popen.return_value = mock_process + + proc, _ = create_process(['echo', 'test']) + + assert mock_popen.called + args, kwargs = mock_popen.call_args + assert kwargs['stdin'] == subprocess.PIPE + assert kwargs['stdout'] == subprocess.PIPE + assert kwargs['stderr'] == subprocess.PIPE + assert 'startupinfo' in kwargs + assert kwargs['startupinfo'].dwFlags & subprocess.STARTF_USESHOWWINDOW + + def test_run_interactive_command(self): + """Test running an interactive command on Windows.""" + test_output = "Test output\n" + + with patch('subprocess.Popen') as mock_popen: + mock_process = MagicMock() + mock_process.stdout = MagicMock() + mock_process.stdout.read.return_value = test_output.encode() + mock_process.wait.return_value = 0 + mock_popen.return_value = mock_process + + output, return_code = run_interactive_command(['echo', 'test']) + assert return_code == 0 + assert "Test output" in output.decode() + + def test_windows_dependencies(self): + """Test that required Windows dependencies are available.""" + if sys.platform == "win32": + import msvcrt + import win32pipe + import win32file + import win32con + import win32process + + # If we get here without ImportError, the test passes + assert True