Add Windows compatibility improvements (#105)

* Add Windows compatibility improvements

1. Add error handling for Windows-specific modules
2. Update README with Windows installation instructions
3. Add Windows-specific tests
4. Improve cross-platform support in interactive.py

* Fix: Add missing subprocess import in Windows compatibility tests

* Improve Windows compatibility:

1. Add detailed error handling for Windows I/O operations
2. Enhance timeout messages with more descriptive information
3. Add comprehensive comments explaining Windows-specific code

* Fix WebUI: Improve message display, add syntax highlighting, animations, and fix WebSocket communication
This commit is contained in:
Mark Varkevisser 2025-02-25 10:41:29 -08:00 committed by GitHub
parent 2f132bdeb5
commit d163a74c47
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 866 additions and 520 deletions

View File

@ -97,6 +97,31 @@ What sets RA.Aid apart is its ability to handle complex programming tasks that e
## Installation ## Installation
### Windows Installation
1. Install Python 3.8 or higher from [python.org](https://www.python.org/downloads/)
2. Install required system dependencies:
```powershell
# Install Chocolatey if not already installed (run in admin PowerShell)
Set-ExecutionPolicy Bypass -Scope Process -Force; [System.Net.ServicePointManager]::SecurityProtocol = [System.Net.ServicePointManager]::SecurityProtocol -bor 3072; iex ((New-Object System.Net.WebClient).DownloadString('https://community.chocolatey.org/install.ps1'))
# Install ripgrep using Chocolatey
choco install ripgrep
```
3. Install RA.Aid:
```powershell
pip install ra-aid
```
4. Install Windows-specific dependencies:
```powershell
pip install pywin32
```
5. Set up your API keys in a `.env` file:
```env
ANTHROPIC_API_KEY=your_anthropic_key
OPENAI_API_KEY=your_openai_key
```
### Unix/Linux Installation
RA.Aid can be installed directly using pip: RA.Aid can be installed directly using pip:
```bash ```bash

View File

@ -2,6 +2,7 @@
import os import os
import sys import sys
import subprocess
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from ra_aid import print_error from ra_aid import print_error
@ -21,8 +22,13 @@ class RipGrepDependency(Dependency):
def check(self): def check(self):
"""Check if ripgrep is installed.""" """Check if ripgrep is installed."""
result = os.system("rg --version > /dev/null 2>&1") try:
if result != 0: result = subprocess.run(['rg', '--version'],
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL)
if result.returncode != 0:
raise FileNotFoundError()
except (FileNotFoundError, subprocess.SubprocessError):
print_error("Required dependency 'ripgrep' is not installed.") print_error("Required dependency 'ripgrep' is not installed.")
print("Please install ripgrep:") print("Please install ripgrep:")
print(" - Ubuntu/Debian: sudo apt-get install ripgrep") print(" - Ubuntu/Debian: sudo apt-get install ripgrep")

View File

@ -3,51 +3,123 @@
Module for running interactive subprocesses with output capture, Module for running interactive subprocesses with output capture,
with full raw input passthrough for interactive commands. with full raw input passthrough for interactive commands.
It uses a pseudo-tty and integrates pyte's HistoryScreen to simulate It uses a pseudo-tty on Unix systems and direct pipes on Windows to simulate
a terminal and capture the final scrollback history (non-blank lines). a terminal and capture the final scrollback history (non-blank lines).
The interface remains compatible with external callers expecting a tuple (output, return_code),
where output is a bytes object (UTF-8 encoded).
""" """
import errno import errno
import io import io
import os import os
import select import select
import shutil
import signal import signal
import subprocess import subprocess
import sys import sys
import termios
import time import time
import tty import shutil
from typing import List, Tuple from typing import List, Tuple, Optional
import pyte import pyte
from pyte.screens import HistoryScreen
# Windows-specific imports
if sys.platform == "win32":
try:
# msvcrt: Provides Windows console I/O functionality
import msvcrt
# win32pipe, win32file: For low-level pipe operations
import win32pipe
import win32file
# win32con: Windows API constants
import win32con
# win32process: Process management on Windows
import win32process
except ImportError as e:
print("Error: Required Windows dependencies not found.")
print("Please install the required packages using:")
print(" pip install pywin32")
sys.exit(1)
else:
# Unix-specific imports for terminal handling
import termios
import fcntl
import pty
def render_line(line, columns: int) -> str: def get_terminal_size():
"""Render a single screen line from the pyte buffer (a mapping of column to Char).""" """Get the current terminal size."""
return "".join(line[x].data for x in range(columns)) if sys.platform == "win32":
import shutil
size = shutil.get_terminal_size()
return size.columns, size.lines
else:
import struct
try:
with open(sys.stdout.fileno(), 'wb', buffering=0) as fd:
size = struct.unpack('hh', fcntl.ioctl(fd, termios.TIOCGWINSZ, '1234'))
return size[1], size[0]
except (IOError, AttributeError):
return 80, 24
def create_process(cmd: List[str]) -> Tuple[subprocess.Popen, Optional[int]]:
"""Create a subprocess with appropriate handling for the platform.
On Windows:
- Uses STARTUPINFO to hide the console window
- Creates a new process group for proper signal handling
- Returns direct pipe handles for I/O
On Unix:
- Creates a pseudo-terminal (PTY) for proper terminal emulation
- Sets up process group for signal handling
- Returns master PTY file descriptor for I/O
"""
if sys.platform == "win32":
# Windows process creation with hidden console
startupinfo = subprocess.STARTUPINFO()
startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW # Hide the console window
# Create process with proper pipe handling
proc = subprocess.Popen(
cmd,
stdin=subprocess.PIPE, # Allow writing to stdin
stdout=subprocess.PIPE, # Capture stdout
stderr=subprocess.PIPE, # Capture stderr
startupinfo=startupinfo,
# CREATE_NEW_PROCESS_GROUP allows proper Ctrl+C handling
creationflags=subprocess.CREATE_NEW_PROCESS_GROUP
)
return proc, None # No PTY master_fd needed on Windows
else:
# Unix process creation with PTY
master_fd, slave_fd = pty.openpty()
proc = subprocess.Popen(
cmd,
stdin=slave_fd,
stdout=slave_fd,
stderr=slave_fd,
preexec_fn=os.setsid
)
os.close(slave_fd)
return proc, master_fd
def run_interactive_command( def run_interactive_command(
cmd: List[str], expected_runtime_seconds: int = 30 cmd: List[str],
expected_runtime_seconds: int = 1800,
ratio: float = 0.5
) -> Tuple[bytes, int]: ) -> Tuple[bytes, int]:
""" """
Runs an interactive command with a pseudo-tty, capturing final scrollback history. Runs an interactive command with a pseudo-tty, capturing final scrollback history.
Assumptions and constraints: Assumptions and constraints:
- Running on a Linux system. - Running on a Linux system or Windows.
- `cmd` is a non-empty list where cmd[0] is the executable. - `cmd` is a non-empty list where cmd[0] is the executable.
- The executable is on PATH. - The executable is on PATH.
Args: Args:
cmd: A list containing the command and its arguments. cmd: A list containing the command and its arguments.
expected_runtime_seconds: Expected runtime in seconds, defaults to 30. expected_runtime_seconds: Expected runtime in seconds, defaults to 1800.
If process exceeds 2x this value, it will be terminated gracefully. If process exceeds 2x this value, it will be terminated gracefully.
If process exceeds 3x this value, it will be killed forcefully. If process exceeds 3x this value, it will be killed forcefully.
Must be between 1 and 1800 seconds (30 minutes). Must be between 1 and 1800 seconds (30 minutes).
ratio: Ratio of history to keep from top vs bottom (default: 0.5)
Returns: Returns:
A tuple of (captured_output, return_code), where captured_output is a UTF-8 encoded A tuple of (captured_output, return_code), where captured_output is a UTF-8 encoded
@ -69,154 +141,164 @@ def run_interactive_command(
) )
try: try:
term_size = os.get_terminal_size() term_size = get_terminal_size()
cols, rows = term_size.columns, term_size.lines cols, rows = term_size.columns, term_size.lines
except OSError: except OSError:
cols, rows = 80, 24 cols, rows = 80, 24
# Set up pyte screen and stream to capture terminal output. # Set up pyte screen and stream to capture terminal output.
screen = HistoryScreen(cols, rows, history=2000, ratio=0.5) screen = pyte.HistoryScreen(cols, rows, history=2000, ratio=ratio)
stream = pyte.Stream(screen) stream = pyte.Stream(screen)
# Open a new pseudo-tty. proc, master_fd = create_process(cmd)
master_fd, slave_fd = os.openpty()
# Set master_fd to non-blocking to avoid indefinite blocking.
os.set_blocking(master_fd, False)
try:
stdin_fd = sys.stdin.fileno()
except (AttributeError, io.UnsupportedOperation):
stdin_fd = None
# Set up environment variables for the subprocess using detected terminal size.
env = os.environ.copy()
env.update(
{
"DEBIAN_FRONTEND": "noninteractive",
"GIT_PAGER": "",
"PYTHONUNBUFFERED": "1",
"CI": "true",
"LANG": "C.UTF-8",
"LC_ALL": "C.UTF-8",
"COLUMNS": str(cols),
"LINES": str(rows),
"FORCE_COLOR": "1",
"GIT_TERMINAL_PROMPT": "0",
"PYTHONDONTWRITEBYTECODE": "1",
"NODE_OPTIONS": "--unhandled-rejections=strict",
}
)
proc = subprocess.Popen(
cmd,
stdin=slave_fd,
stdout=slave_fd,
stderr=slave_fd,
bufsize=0,
close_fds=True,
env=env,
preexec_fn=os.setsid, # Create new process group for proper signal handling.
)
os.close(slave_fd) # Close slave end in the parent process.
captured_data = [] captured_data = []
start_time = time.time() start_time = time.time()
was_terminated = False was_terminated = False
timeout_type = None
def check_timeout(): def check_timeout():
nonlocal timeout_type
elapsed = time.time() - start_time elapsed = time.time() - start_time
if elapsed > 3 * expected_runtime_seconds: if elapsed > 3 * expected_runtime_seconds:
os.killpg(os.getpgid(proc.pid), signal.SIGKILL) if sys.platform == "win32":
print("\nProcess exceeded hard timeout limit, forcefully terminating...")
proc.terminate()
time.sleep(0.5)
if proc.poll() is None:
print("Process did not respond to termination, killing...")
proc.kill()
else:
print("\nProcess exceeded hard timeout limit, sending SIGKILL...")
os.killpg(os.getpgid(proc.pid), signal.SIGKILL)
timeout_type = "hard_timeout"
return True return True
elif elapsed > 2 * expected_runtime_seconds: elif elapsed > 2 * expected_runtime_seconds:
os.killpg(os.getpgid(proc.pid), signal.SIGTERM) if sys.platform == "win32":
print("\nProcess exceeded soft timeout limit, attempting graceful termination...")
proc.terminate()
else:
print("\nProcess exceeded soft timeout limit, sending SIGTERM...")
os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
timeout_type = "soft_timeout"
return True return True
return False return False
# Interactive mode: forward input if running in a TTY. # Interactive mode: forward input if running in a TTY.
if stdin_fd is not None and sys.stdin.isatty(): if sys.platform == "win32":
old_settings = termios.tcgetattr(stdin_fd) # Windows handling
tty.setraw(stdin_fd)
try: try:
while True: while True:
if check_timeout(): if check_timeout():
was_terminated = True was_terminated = True
break break
# Use a finite timeout to avoid indefinite blocking.
rlist, _, _ = select.select([master_fd, stdin_fd], [], [], 1.0) try:
if master_fd in rlist: # Check stdout with proper error handling
try: stdout_data = proc.stdout.read1(1024)
data = os.read(master_fd, 1024) if stdout_data:
except OSError as e: captured_data.append(stdout_data)
if e.errno == errno.EIO: try:
stream.feed(stdout_data.decode(errors='ignore'))
except Exception as e:
print(f"Warning: Error processing stdout: {e}")
# Check stderr with proper error handling
stderr_data = proc.stderr.read1(1024)
if stderr_data:
captured_data.append(stderr_data)
try:
stream.feed(stderr_data.decode(errors='ignore'))
except Exception as e:
print(f"Warning: Error processing stderr: {e}")
# Check for input with proper error handling
if msvcrt.kbhit():
try:
char = msvcrt.getch()
proc.stdin.write(char)
proc.stdin.flush()
except (IOError, OSError) as e:
print(f"Warning: Error handling keyboard input: {e}")
break break
else:
raise except (IOError, OSError) as e:
if not data: # EOF detected. if isinstance(e, OSError) and e.winerror == 6: # Invalid handle
break break
captured_data.append(data) print(f"Warning: I/O error during process communication: {e}")
decoded = data.decode("utf-8", errors="ignore") break
stream.feed(decoded)
os.write(1, data) except Exception as e:
if stdin_fd in rlist: print(f"Error in Windows process handling: {e}")
proc.terminate()
else:
# Unix handling
import tty
try:
old_settings = termios.tcgetattr(sys.stdin)
tty.setraw(sys.stdin)
while True:
if check_timeout():
was_terminated = True
break
rlist, _, _ = select.select([master_fd, sys.stdin], [], [], 0.1)
for fd in rlist:
try: try:
input_data = os.read(stdin_fd, 1024) if fd == master_fd:
except OSError: data = os.read(master_fd, 1024)
input_data = b"" if not data:
if input_data: break
os.write(master_fd, input_data) captured_data.append(data)
stream.feed(data.decode(errors='ignore'))
else:
data = os.read(fd, 1024)
os.write(master_fd, data)
except (IOError, OSError):
break
except KeyboardInterrupt: except KeyboardInterrupt:
proc.terminate() proc.terminate()
finally: finally:
termios.tcsetattr(stdin_fd, termios.TCSADRAIN, old_settings) termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings)
else:
# Non-interactive mode. if not sys.platform == "win32" and master_fd is not None:
try: os.close(master_fd)
while True:
if check_timeout():
was_terminated = True
break
rlist, _, _ = select.select([master_fd], [], [], 1.0)
if not rlist:
continue
try:
data = os.read(master_fd, 1024)
except OSError as e:
if e.errno == errno.EIO:
break
else:
raise
if not data: # EOF detected.
break
captured_data.append(data)
decoded = data.decode("utf-8", errors="ignore")
stream.feed(decoded)
os.write(1, data)
except KeyboardInterrupt:
proc.terminate()
os.close(master_fd)
proc.wait() proc.wait()
# Assemble full scrollback: combine history.top, the current display, and history.bottom. # Assemble full scrollback: combine history.top, the current display, and history.bottom.
top_lines = [render_line(line, cols) for line in screen.history.top] def render_line(line, width):
return ''.join(char.data for char in line[:width]).rstrip()
# Combine history and current screen content
final_output = []
# Add lines from history
history_lines = [render_line(line, cols) for line in screen.history.top]
final_output.extend(line for line in history_lines if line.strip())
# Add current screen content
screen_lines = [render_line(line, cols) for line in screen.display]
final_output.extend(line for line in screen_lines if line.strip())
# Add bottom history
bottom_lines = [render_line(line, cols) for line in screen.history.bottom] bottom_lines = [render_line(line, cols) for line in screen.history.bottom]
display_lines = screen.display # List of strings representing the current screen. final_output.extend(line for line in bottom_lines if line.strip())
all_lines = top_lines + display_lines + bottom_lines
# Trim out empty lines to get only meaningful "history" lines. # Add timeout message if process was terminated
trimmed_lines = [line for line in all_lines if line.strip()]
final_output = "\n".join(trimmed_lines)
# Add timeout message if process was terminated due to timeout.
if was_terminated: if was_terminated:
timeout_msg = f"\n[Process exceeded timeout ({expected_runtime_seconds} seconds expected)]" if timeout_type == "hard_timeout":
final_output += timeout_msg timeout_msg = f"\n[Process forcefully terminated after exceeding {3 * expected_runtime_seconds:.1f} seconds (expected: {expected_runtime_seconds} seconds)]"
else:
timeout_msg = f"\n[Process gracefully terminated after exceeding {2 * expected_runtime_seconds:.1f} seconds (expected: {expected_runtime_seconds} seconds)]"
final_output.append(timeout_msg)
# Limit output to the last 8000 bytes. # Limit output size
final_output = final_output[-8000:] final_output = final_output[-8000:]
final_output = '\n'.join(final_output)
return final_output.encode("utf-8"), proc.returncode return final_output.encode('utf-8'), proc.returncode
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,29 +1,37 @@
"""Web interface server implementation for RA.Aid.""" #!/usr/bin/env python3
import asyncio
import logging
import shutil
import sys import sys
import os
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List import asyncio
from typing import List
import json
import threading
import queue
import traceback
import shutil
import logging
import uvicorn # Configure logging
from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect logging.basicConfig(
from fastapi.middleware.cors import CORSMiddleware level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler(sys.__stderr__) # Use the real stderr
]
)
logger = logging.getLogger(__name__)
# Add project root to Python path
project_root = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
if project_root not in sys.path:
sys.path.insert(0, project_root)
from fastapi import FastAPI, WebSocket, Request, WebSocketDisconnect
from fastapi.responses import HTMLResponse from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from fastapi.middleware.cors import CORSMiddleware
# Configure logging import uvicorn
logging.basicConfig(level=logging.DEBUG) # Set to DEBUG for more info
logger = logging.getLogger(__name__)
# Verify ra-aid is available
if not shutil.which("ra-aid"):
logger.error(
"ra-aid command not found. Please ensure it's installed and in your PATH"
)
sys.exit(1)
app = FastAPI() app = FastAPI()
@ -36,67 +44,105 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
# Get the directory containing static files # Setup templates and static files directories
STATIC_DIR = Path(__file__).parent / "static" CURRENT_DIR = Path(__file__).parent
if not STATIC_DIR.exists(): templates = Jinja2Templates(directory=CURRENT_DIR)
logger.error(f"Static directory not found at {STATIC_DIR}")
sys.exit(1)
logger.info(f"Using static directory: {STATIC_DIR}") # Mount static files for js and other assets
static_dir = CURRENT_DIR / "static"
app.mount("/static", StaticFiles(directory=str(static_dir)), name="static")
# Setup templates # Store active WebSocket connections
templates = Jinja2Templates(directory=str(STATIC_DIR)) active_connections: List[WebSocket] = []
def run_ra_aid(message_content, output_queue):
"""Run ra-aid in a separate thread"""
try:
import ra_aid.__main__
logger.info("Successfully imported ra_aid.__main__")
class ConnectionManager: # Override sys.argv
def __init__(self): sys.argv = [sys.argv[0], "-m", message_content, "--cowboy-mode"]
self.active_connections: List[WebSocket] = [] logger.info(f"Set sys.argv to: {sys.argv}")
# Create custom output handler
class QueueHandler:
def __init__(self, queue):
self.queue = queue
self.buffer = []
self.box_start = False
self._real_stderr = sys.__stderr__
def write(self, text):
# Always log raw output for debugging
logger.debug(f"Raw output: {repr(text)}")
# Check if this is a box drawing character
if any(c in text for c in '╭╮╰╯│─'):
self.box_start = True
self.buffer.append(text)
elif self.box_start and text.strip():
self.buffer.append(text)
if '' in text: # End of box
full_text = ''.join(self.buffer)
# Extract content from inside the box
lines = full_text.split('\n')
content_lines = []
for line in lines:
# Remove box characters and leading/trailing spaces
clean_line = line.strip('╭╮╰╯│─ ')
if clean_line:
content_lines.append(clean_line)
if content_lines:
self.queue.put('\n'.join(content_lines))
self.buffer = []
self.box_start = False
elif not self.box_start and text.strip():
self.queue.put(text.strip())
def flush(self):
if self.buffer:
full_text = ''.join(self.buffer)
# Extract content from partial box
lines = full_text.split('\n')
content_lines = []
for line in lines:
# Remove box characters and leading/trailing spaces
clean_line = line.strip('╭╮╰╯│─ ')
if clean_line:
content_lines.append(clean_line)
if content_lines:
self.queue.put('\n'.join(content_lines))
self.buffer = []
self.box_start = False
# Replace stdout and stderr
old_stdout = sys.stdout
old_stderr = sys.stderr
queue_handler = QueueHandler(output_queue)
sys.stdout = queue_handler
sys.stderr = queue_handler
async def connect(self, websocket: WebSocket) -> bool:
try: try:
logger.debug("Accepting WebSocket connection...") logger.info("Starting ra_aid.main()")
await websocket.accept() ra_aid.__main__.main()
logger.debug("WebSocket connection accepted") logger.info("Finished ra_aid.main()")
self.active_connections.append(websocket) except SystemExit:
return True logger.info("Caught SystemExit - this is normal")
except Exception as e: except Exception as e:
logger.error(f"Error accepting WebSocket connection: {e}") logger.error(f"Error in main: {str(e)}")
return False traceback.print_exc(file=sys.__stderr__)
finally:
def disconnect(self, websocket: WebSocket): # Flush any remaining output
if websocket in self.active_connections: queue_handler.flush()
self.active_connections.remove(websocket) # Restore stdout and stderr
sys.stdout = old_stdout
async def send_message(self, websocket: WebSocket, message: Dict[str, Any]): sys.stderr = old_stderr
try:
await websocket.send_json(message)
except Exception as e:
logger.error(f"Error sending message: {e}")
await self.handle_error(websocket, str(e))
async def handle_error(self, websocket: WebSocket, error_message: str):
try:
await websocket.send_json(
{
"type": "chunk",
"chunk": {
"tools": {
"messages": [
{
"content": f"Error: {error_message}",
"status": "error",
}
]
}
},
}
)
except Exception as e:
logger.error(f"Error sending error message: {e}")
manager = ConnectionManager()
except Exception as e:
logger.error(f"Error in thread: {str(e)}")
traceback.print_exc(file=sys.__stderr__)
output_queue.put(f"Error: {str(e)}")
@app.get("/", response_class=HTMLResponse) @app.get("/", response_class=HTMLResponse)
async def get_root(request: Request): async def get_root(request: Request):
@ -105,135 +151,108 @@ async def get_root(request: Request):
"index.html", {"request": request, "server_port": request.url.port or 8080} "index.html", {"request": request, "server_port": request.url.port or 8080}
) )
# Mount static files for js and other assets
app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
@app.websocket("/ws") @app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket): async def websocket_endpoint(websocket: WebSocket):
client_id = id(websocket) await websocket.accept()
logger.info(f"New WebSocket connection attempt from client {client_id}") logger.info("WebSocket connection established")
active_connections.append(websocket)
if not await manager.connect(websocket):
logger.error(f"Failed to accept WebSocket connection for client {client_id}")
return
logger.info(f"WebSocket connection accepted for client {client_id}")
try: try:
# Send initial connection success message
await manager.send_message(
websocket,
{
"type": "chunk",
"chunk": {
"agent": {
"messages": [
{"content": "Connected to RA.Aid server", "status": "info"}
]
}
},
},
)
while True: while True:
try: message = await websocket.receive_json()
message = await websocket.receive_json() logger.info(f"Received message: {message}")
logger.debug(f"Received message from client {client_id}: {message}")
if message["type"] == "request": if message["type"] == "request":
await manager.send_message(websocket, {"type": "stream_start"}) content = message["content"]
logger.info(f"Processing request: {content}")
try: # Create queue for output
# Run ra-aid with the request output_queue = queue.Queue()
cmd = ["ra-aid", "-m", message["content"], "--cowboy-mode"]
logger.info(f"Executing command: {' '.join(cmd)}")
process = await asyncio.create_subprocess_exec( # Create and start thread
*cmd, thread = threading.Thread(target=run_ra_aid, args=(content, output_queue))
stdout=asyncio.subprocess.PIPE, thread.start()
stderr=asyncio.subprocess.PIPE,
)
logger.info(f"Process started with PID: {process.pid}")
async def read_stream(stream, is_error=False): try:
while True: # Send stream start
line = await stream.readline() await websocket.send_json({"type": "stream_start"})
if not line:
break
try: while thread.is_alive() or not output_queue.empty():
decoded_line = line.decode().strip() try:
if decoded_line: # Get output with timeout to allow checking thread status
await manager.send_message( line = output_queue.get(timeout=0.1)
websocket, if line and line.strip(): # Only send non-empty messages
{ logger.debug(f"WebSocket sending: {repr(line)}")
"type": "chunk", await websocket.send_json({
"chunk": { "type": "chunk",
"tools" if is_error else "agent": { "chunk": {
"messages": [ "agent": {
{ "messages": [{
"content": decoded_line, "content": line.strip(),
"status": "error" "status": "info"
if is_error }]
else "info", }
} }
] })
} except queue.Empty:
}, await asyncio.sleep(0.1)
}, except Exception as e:
) logger.error(f"WebSocket error: {e}")
except Exception as e: traceback.print_exc(file=sys.__stderr__)
logger.error(f"Error processing output: {e}")
# Create tasks for reading stdout and stderr # Wait for thread to finish
stdout_task = asyncio.create_task(read_stream(process.stdout)) thread.join()
stderr_task = asyncio.create_task( logger.info("Thread finished")
read_stream(process.stderr, True)
)
# Wait for both streams to complete # Send stream end
await asyncio.gather(stdout_task, stderr_task) await websocket.send_json({"type": "stream_end"})
logger.info("Sent stream_end message")
# Wait for process to complete except Exception as e:
return_code = await process.wait() error_msg = f"Error running ra-aid: {str(e)}"
logger.error(error_msg)
await websocket.send_json({
"type": "error",
"message": error_msg
})
if return_code != 0: logger.info("Waiting for message...")
await manager.handle_error(
websocket, f"Process exited with code {return_code}"
)
await manager.send_message(
websocket,
{"type": "stream_end", "request": message["content"]},
)
except Exception as e:
logger.error(f"Error executing ra-aid: {e}")
await manager.handle_error(websocket, str(e))
except Exception as e:
logger.error(f"Error processing message: {e}")
await manager.handle_error(websocket, str(e))
except WebSocketDisconnect: except WebSocketDisconnect:
logger.info(f"WebSocket client {client_id} disconnected") logger.info("WebSocket client disconnected")
active_connections.remove(websocket)
except Exception as e: except Exception as e:
logger.error(f"WebSocket error for client {client_id}: {e}") logger.error(f"WebSocket error: {e}")
traceback.print_exc()
finally: finally:
manager.disconnect(websocket) if websocket in active_connections:
logger.info(f"WebSocket connection cleaned up for client {client_id}") active_connections.remove(websocket)
logger.info("WebSocket connection closed")
@app.get("/config")
async def get_config(request: Request):
"""Return server configuration including host and port."""
return {"host": request.client.host, "port": request.scope.get("server")[1]}
def run_server(host: str = "0.0.0.0", port: int = 8080): def run_server(host: str = "0.0.0.0", port: int = 8080):
"""Run the FastAPI server.""" """Run the FastAPI server."""
logger.info(f"Starting server on {host}:{port}") uvicorn.run(app, host=host, port=port)
uvicorn.run(
app,
host=host, if __name__ == "__main__":
port=port, import argparse
log_level="debug", parser = argparse.ArgumentParser(description="RA.Aid Web Interface Server")
ws_max_size=16777216, # 16MB parser.add_argument(
timeout_keep_alive=0, # Disable keep-alive timeout "--port", type=int, default=8080, help="Port to listen on (default: 8080)"
) )
parser.add_argument(
"--host",
type=str,
default="0.0.0.0",
help="Host to listen on (default: 0.0.0.0)",
)
args = parser.parse_args()
run_server(host=args.host, port=args.port)

View File

@ -1,90 +1,214 @@
<!DOCTYPE html> <!DOCTYPE html>
<html lang="en" class="h-full bg-gray-900"> <html lang="en">
<head> <head>
<meta charset="UTF-8"> <meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0"> <meta name="viewport" content="width=device-width, initial-scale=1.0">
<meta name="server-port" content="{{ server_port }}"> <title>RA.Aid Web UI</title>
<title>RA.Aid Web Interface</title>
<script src="https://cdn.tailwindcss.com"></script> <script src="https://cdn.tailwindcss.com"></script>
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/styles/tokyo-night-dark.min.css">
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/highlight.min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/languages/python.min.js"></script>
<script> <script>
tailwind.config = { tailwind.config = {
darkMode: 'class',
theme: { theme: {
extend: { extend: {
colors: { colors: {
'dark-primary': '#1a1b26', 'dark-primary': '#1a1b26',
'dark-secondary': '#24283b', 'dark-secondary': '#24283b',
'dark-accent': '#7aa2f7', 'dark-accent': '#414868',
'dark-text': '#c0caf5' 'dark-text': '#a9b1d6',
'dark-background': '#16161e'
} }
} }
} }
} }
</script> </script>
<style>
/* Core styles */
body {
background-color: #16161e;
color: #a9b1d6;
margin: 0;
padding: 0;
height: 100vh;
display: flex;
flex-direction: column;
}
/* Output area */
#stream-output {
flex: 1;
overflow-y: auto;
padding: 1rem;
background-color: #1a1b26;
margin: 0;
display: block !important;
min-height: 300px;
font-family: monospace;
white-space: pre-wrap;
word-break: break-word;
}
/* Message styling */
.message {
margin-bottom: 1.5rem;
padding: 0.75rem 1rem;
border-radius: 0.375rem;
background-color: #24283b;
color: #a9b1d6;
display: block !important;
border: 1px solid #414868;
opacity: 0;
transform: translateY(20px);
animation: slideIn 0.3s ease forwards;
}
@keyframes slideIn {
to {
opacity: 1;
transform: translateY(0);
}
}
/* Message types */
.message.info {
border-color: #7aa2f7;
}
.message.success {
border-color: #9ece6a;
background-color: rgba(158, 206, 106, 0.1);
}
.message.error {
border-color: #f7768e;
background-color: rgba(247, 118, 142, 0.1);
}
.message.warning {
border-color: #e0af68;
background-color: rgba(224, 175, 104, 0.1);
}
/* Section spacing */
.message.section {
margin-top: 2rem;
margin-bottom: 2rem;
padding: 1rem;
background-color: rgba(122, 162, 247, 0.1);
}
/* Code blocks */
.message pre {
margin: 0.5rem 0;
padding: 0.75rem;
border-radius: 0.25rem;
background-color: #1a1b26 !important;
border: 1px solid #414868;
overflow-x: auto;
}
.message code {
font-family: 'JetBrains Mono', monospace;
font-size: 0.9em;
}
/* Input area */
.input-area {
padding: 1rem;
background-color: #24283b;
border-top: 1px solid #414868;
}
.input-container {
display: flex;
gap: 1rem;
align-items: center;
}
#user-input {
flex: 1;
padding: 0.75rem;
background-color: #1a1b26;
border: 1px solid #414868;
border-radius: 0.375rem;
color: #a9b1d6;
font-family: 'JetBrains Mono', monospace;
transition: border-color 0.2s;
}
#user-input:focus {
outline: none;
border-color: #7aa2f7;
}
/* Buttons */
.button {
padding: 0.75rem 1rem;
background-color: #7aa2f7;
color: white;
border: none;
border-radius: 0.375rem;
cursor: pointer;
min-width: 80px;
transition: background-color 0.2s;
display: flex;
align-items: center;
justify-content: center;
gap: 0.5rem;
}
.button:hover {
background-color: #5d87e6;
}
.button:disabled {
opacity: 0.5;
cursor: not-allowed;
}
.button.clear {
background-color: #414868;
}
.button.clear:hover {
background-color: #363b54;
}
/* Icons */
.icon {
font-size: 1.2em;
line-height: 1;
}
</style>
</head> </head>
<body class="h-full bg-dark-primary text-dark-text"> <body>
<div class="flex h-full"> <!-- Main Container -->
<!-- Sidebar --> <div id="stream-output">
<div class="w-64 bg-dark-secondary border-r border-gray-700 flex flex-col"> <!-- Messages will appear here -->
<div class="p-4 border-b border-gray-700"> </div>
<h2 class="text-xl font-semibold text-dark-accent">History</h2>
</div>
<div id="history-list" class="flex-1 overflow-y-auto p-4 space-y-2"></div>
</div>
<!-- Main Content --> <!-- Input Area -->
<div class="flex-1 flex flex-col min-w-0"> <div class="input-area">
<!-- Chat Container --> <div class="input-container">
<div class="flex-1 overflow-y-auto p-4 space-y-4" id="chat-container"> <input
<div id="chat-messages"></div> type="text"
<div id="stream-output" class="hidden font-mono bg-dark-secondary rounded-lg p-4 text-sm"></div> id="user-input"
</div> placeholder="Type your message..."
>
<!-- Input Area --> <button id="clear-button" class="button clear" title="Clear conversation">
<div class="border-t border-gray-700 p-4 bg-dark-secondary"> <span class="icon">🗑️</span>
<div class="flex space-x-4"> </button>
<textarea <button id="send-button" class="button">
id="user-input" <span class="icon">📤</span>
class="flex-1 bg-dark-primary border border-gray-700 rounded-lg p-3 text-dark-text placeholder-gray-500 focus:outline-none focus:ring-2 focus:ring-dark-accent resize-none" Send
placeholder="Type your request here..." </button>
rows="3"
></textarea>
<button
id="send-button"
class="px-6 py-2 bg-dark-accent text-white rounded-lg hover:bg-opacity-90 focus:outline-none focus:ring-2 focus:ring-offset-2 focus:ring-dark-accent disabled:opacity-50 disabled:cursor-not-allowed h-fit"
>
Send
</button>
</div>
</div>
</div> </div>
</div> </div>
<script>
// Add dynamic styles for messages
const style = document.createElement('style');
style.textContent = `
.message {
@apply mb-4 p-4 rounded-lg max-w-3xl;
}
.user-message {
@apply bg-dark-accent text-white ml-auto;
}
.system-message {
@apply bg-dark-secondary mr-auto;
}
.error-message {
@apply bg-red-900 text-red-100 mr-auto;
}
.history-item {
@apply p-3 rounded-lg hover:bg-dark-primary cursor-pointer transition-colors duration-200 text-sm;
}
#stream-output:not(:empty) {
@apply block;
}
`;
document.head.appendChild(style);
</script>
<script src="/static/script.js"></script> <script src="/static/script.js"></script>
<script>
hljs.highlightAll();
</script>
</body> </body>
</html> </html>

View File

@ -1,206 +1,230 @@
class RAWebUI { class WebSocketHandler {
constructor() { constructor() {
this.messageHistory = []; // Wait for DOM to be ready
this.connectionAttempts = 0; if (document.readyState === 'loading') {
this.maxReconnectAttempts = 5; document.addEventListener('DOMContentLoaded', () => this.initialize());
this.setupElements(); } else {
this.setupEventListeners(); this.initialize();
this.connectWebSocket(); }
} }
setupElements() { initialize() {
this.userInput = document.getElementById('user-input'); // Store DOM elements as instance variables
this.messageInput = document.getElementById('user-input');
this.sendButton = document.getElementById('send-button'); this.sendButton = document.getElementById('send-button');
this.chatMessages = document.getElementById('chat-messages'); this.clearButton = document.getElementById('clear-button');
this.streamOutput = document.getElementById('stream-output'); this.streamOutput = document.getElementById('stream-output');
this.historyList = document.getElementById('history-list');
// Disable send button initially
this.sendButton.disabled = true;
}
setupEventListeners() { // Validate required elements exist
if (!this.messageInput || !this.sendButton || !this.streamOutput) {
console.error('Required elements not found:', {
messageInput: !!this.messageInput,
sendButton: !!this.sendButton,
streamOutput: !!this.streamOutput
});
return;
}
// Remove hidden class if present
this.streamOutput.classList.remove('hidden');
// Initialize WebSocket
this.connectWebSocket();
// Add event listeners
this.sendButton.addEventListener('click', () => this.sendMessage()); this.sendButton.addEventListener('click', () => this.sendMessage());
this.userInput.addEventListener('keypress', (e) => { this.clearButton?.addEventListener('click', () => this.clearConversation());
this.messageInput.addEventListener('keypress', (e) => {
if (e.key === 'Enter' && !e.shiftKey) { if (e.key === 'Enter' && !e.shiftKey) {
e.preventDefault(); e.preventDefault();
this.sendMessage(); this.sendMessage();
} }
}); });
console.log('WebSocketHandler initialized with elements:', {
messageInput: this.messageInput,
sendButton: this.sendButton,
streamOutput: this.streamOutput
});
} }
async connectWebSocket() { connectWebSocket() {
// Don't try to reconnect if we've exceeded the maximum attempts
if (this.connectionAttempts >= this.maxReconnectAttempts) {
this.appendMessage(
'Maximum reconnection attempts reached. Please refresh the page.',
'error'
);
return;
}
try { try {
// Get the server port from the meta tag const wsUrl = `ws://${window.location.host}/ws`;
const serverPort = document.querySelector('meta[name="server-port"]')?.content || '8080';
// Construct WebSocket URL
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
const wsUrl = `${protocol}//${window.location.hostname}:${serverPort}/ws`;
console.log('Attempting to connect to WebSocket URL:', wsUrl); console.log('Attempting to connect to WebSocket URL:', wsUrl);
// Close existing connection if any
if (this.ws) {
this.ws.close();
}
// Create new WebSocket connection
console.log('Creating new WebSocket connection...');
this.ws = new WebSocket(wsUrl); this.ws = new WebSocket(wsUrl);
this.connectionAttempts++;
// Setup WebSocket event handlers
this.ws.onopen = () => { this.ws.onopen = () => {
console.log('WebSocket connection established successfully'); console.log('WebSocket connection established successfully');
this.connectionAttempts = 0; // Reset counter on successful connection
this.sendButton.disabled = false; this.sendButton.disabled = false;
}; };
this.ws.onclose = (event) => { this.ws.onclose = () => {
console.log('WebSocket closed:', event.code, event.reason); console.log('WebSocket connection closed');
this.sendButton.disabled = true; this.sendButton.disabled = true;
// Try to reconnect after a delay
setTimeout(() => this.connectWebSocket(), 2000);
};
// Only attempt reconnect if not a normal closure and within retry limits this.ws.onmessage = (event) => {
if (event.code !== 1000 && this.connectionAttempts < this.maxReconnectAttempts) { try {
const delay = Math.min(1000 * Math.pow(2, this.connectionAttempts), 10000); console.log('Raw WebSocket message:', event.data);
this.appendMessage( const message = JSON.parse(event.data);
`Connection lost. Reconnecting in ${delay/1000} seconds...`, console.log('Parsed WebSocket message:', message);
'error' this.handleMessage(message);
); } catch (error) {
setTimeout(() => this.connectWebSocket(), delay); console.error('Error handling message:', error);
this.appendOutput({
content: `Error: ${error.message}`,
status: 'error'
});
} }
}; };
this.ws.onerror = (error) => { this.ws.onerror = (error) => {
console.error('WebSocket error:', error); console.error('WebSocket error:', error);
this.sendButton.disabled = true;
this.appendOutput({
content: 'Connection error. Retrying...',
status: 'error'
});
}; };
this.ws.onmessage = (event) => {
try {
const data = JSON.parse(event.data);
this.handleServerMessage(data);
} catch (error) {
console.error('Error parsing message:', error);
this.appendMessage('Error processing server message', 'error');
}
};
} catch (error) { } catch (error) {
console.error('Failed to connect to WebSocket:', error); console.error('Error connecting to WebSocket:', error);
this.appendMessage( this.appendOutput({
`Connection error: ${error.message}. Retrying...`, content: `Connection error: ${error.message}`,
'error' status: 'error'
); });
// Attempt to reconnect with exponential backoff
const delay = Math.min(1000 * Math.pow(2, this.connectionAttempts), 10000);
setTimeout(() => this.connectWebSocket(), delay);
} }
} }
handleServerMessage(data) { handleMessage(message) {
if (data.type === 'stream_start') { switch (message.type) {
this.streamOutput.textContent = ''; case 'stream_start':
this.streamOutput.style.display = 'block'; this.handleStreamStart();
} else if (data.type === 'stream_end') { break;
this.streamOutput.style.display = 'none'; case 'chunk':
this.addToHistory(data.request); this.handleChunk(message.chunk);
this.sendButton.disabled = false; break;
} else if (data.type === 'chunk') { case 'stream_end':
this.handleChunk(data.chunk); this.handleStreamEnd();
break;
default:
console.warn('Unknown message type:', message.type);
} }
} }
handleStreamStart() {
console.log('Stream starting');
this.clearStreamOutput();
this.appendOutput({
content: 'Starting new conversation...',
status: 'info'
});
}
handleStreamEnd() {
console.log('Stream ending');
this.appendOutput({
content: 'Conversation complete.',
status: 'success'
});
this.messageInput.disabled = false;
this.sendButton.disabled = false;
}
handleChunk(chunk) { handleChunk(chunk) {
console.log(' Processing chunk:', chunk);
if (chunk.agent && chunk.agent.messages) { if (chunk.agent && chunk.agent.messages) {
chunk.agent.messages.forEach(msg => { chunk.agent.messages.forEach(message => {
if (msg.content) { console.log(' Processing agent message:', message);
if (Array.isArray(msg.content)) { console.log(' Adding agent message:', message.content);
msg.content.forEach(content => { this.appendOutput(message);
if (content.type === 'text' && content.text.trim()) {
this.appendMessage(content.text.trim(), 'system');
}
});
} else if (msg.content.trim()) {
this.appendMessage(msg.content.trim(), 'system');
}
}
});
} else if (chunk.tools && chunk.tools.messages) {
chunk.tools.messages.forEach(msg => {
if (msg.status === 'error' && msg.content) {
this.appendMessage(msg.content.trim(), 'error');
}
}); });
} }
} }
appendMessage(content, type) { clearStreamOutput() {
const messageDiv = document.createElement('div'); console.log('Clearing stream output');
messageDiv.className = `message ${type}-message`; while (this.streamOutput.firstChild) {
messageDiv.textContent = content; this.streamOutput.removeChild(this.streamOutput.firstChild);
this.chatMessages.appendChild(messageDiv); }
this.chatMessages.scrollTop = this.chatMessages.scrollHeight; console.log('Stream output cleared');
} }
addToHistory(request) { clearConversation() {
const historyItem = document.createElement('div'); this.clearStreamOutput();
historyItem.className = 'history-item'; this.appendOutput({
historyItem.textContent = request.slice(0, 50) + (request.length > 50 ? '...' : ''); content: 'Conversation cleared.',
historyItem.title = request; status: 'info'
historyItem.addEventListener('click', () => { });
this.userInput.value = request; }
this.userInput.focus();
appendOutput(message) {
console.log(' Appending output:', message);
const messageDiv = document.createElement('div');
messageDiv.className = `message ${message.status || 'info'}`;
const contentSpan = document.createElement('span');
// Convert ANSI escape codes to HTML
let content = message.content;
content = this.convertAnsiToHtml(content);
// Check for code blocks and apply syntax highlighting
if (content.includes('```')) {
content = this.highlightCodeBlocks(content);
}
contentSpan.innerHTML = content;
messageDiv.appendChild(contentSpan);
this.streamOutput.appendChild(messageDiv);
messageDiv.scrollIntoView({ behavior: 'smooth', block: 'end' });
}
convertAnsiToHtml(text) {
// ANSI color codes to CSS classes
const ansiToClass = {
'[94m': '<span class="text-blue-400">',
'[1;32m': '<span class="text-green-400 font-bold">',
'[0m': '</span>'
};
// Replace ANSI codes with HTML
let html = text;
for (const [ansi, htmlClass] of Object.entries(ansiToClass)) {
html = html.replaceAll('\u001b' + ansi, htmlClass);
}
return html;
}
highlightCodeBlocks(content) {
const codeBlockRegex = /```(\w+)?\n([\s\S]*?)```/g;
return content.replace(codeBlockRegex, (match, lang, code) => {
const language = lang || 'plaintext';
const highlighted = hljs.highlight(code.trim(), { language }).value;
return `<pre><code class="language-${language}">${highlighted}</code></pre>`;
}); });
this.historyList.insertBefore(historyItem, this.historyList.firstChild);
this.messageHistory.push(request);
} }
sendMessage() { sendMessage() {
console.log('Send button clicked'); const message = this.messageInput.value.trim();
const message = this.userInput.value.trim(); if (!message) return;
console.log('Message content:', message);
if (!message) {
console.log('Message is empty, not sending');
return;
}
if (!this.ws || this.ws.readyState !== WebSocket.OPEN) { console.log('Sending message:', message);
console.error('WebSocket is not connected'); this.ws.send(JSON.stringify({
this.appendMessage('Error: Not connected to server. Please wait...', 'error'); type: "request",
return; content: message
} }));
this.messageInput.value = '';
try { this.messageInput.disabled = true;
console.log('Sending message to server'); this.sendButton.disabled = true;
this.appendMessage(message, 'user');
const payload = { type: 'request', content: message };
console.log('Payload:', payload);
this.ws.send(JSON.stringify(payload));
console.log('Message sent successfully');
this.userInput.value = '';
this.sendButton.disabled = true;
} catch (error) {
console.error('Error sending message:', error);
this.appendMessage(`Error sending message: ${error.message}`, 'error');
this.sendButton.disabled = false;
}
} }
} }
// Initialize the UI when the page loads // Initialize WebSocket handler when the page loads
document.addEventListener('DOMContentLoaded', () => { window.addEventListener('load', () => {
window.raWebUI = new RAWebUI(); new WebSocketHandler();
}); });

View File

@ -0,0 +1,66 @@
"""Tests for Windows-specific functionality."""
import os
import sys
import subprocess
import pytest
from unittest.mock import patch, MagicMock
from ra_aid.proc.interactive import get_terminal_size, create_process, run_interactive_command
@pytest.mark.skipif(sys.platform != "win32", reason="Windows-specific tests")
class TestWindowsCompatibility:
"""Test suite for Windows-specific functionality."""
def test_get_terminal_size(self):
"""Test terminal size detection on Windows."""
with patch('shutil.get_terminal_size') as mock_get_size:
mock_get_size.return_value = MagicMock(columns=120, lines=30)
cols, rows = get_terminal_size()
assert cols == 120
assert rows == 30
mock_get_size.assert_called_once()
def test_create_process(self):
"""Test process creation on Windows."""
with patch('subprocess.Popen') as mock_popen:
mock_process = MagicMock()
mock_process.returncode = 0
mock_popen.return_value = mock_process
proc, _ = create_process(['echo', 'test'])
assert mock_popen.called
args, kwargs = mock_popen.call_args
assert kwargs['stdin'] == subprocess.PIPE
assert kwargs['stdout'] == subprocess.PIPE
assert kwargs['stderr'] == subprocess.PIPE
assert 'startupinfo' in kwargs
assert kwargs['startupinfo'].dwFlags & subprocess.STARTF_USESHOWWINDOW
def test_run_interactive_command(self):
"""Test running an interactive command on Windows."""
test_output = "Test output\n"
with patch('subprocess.Popen') as mock_popen:
mock_process = MagicMock()
mock_process.stdout = MagicMock()
mock_process.stdout.read.return_value = test_output.encode()
mock_process.wait.return_value = 0
mock_popen.return_value = mock_process
output, return_code = run_interactive_command(['echo', 'test'])
assert return_code == 0
assert "Test output" in output.decode()
def test_windows_dependencies(self):
"""Test that required Windows dependencies are available."""
if sys.platform == "win32":
import msvcrt
import win32pipe
import win32file
import win32con
import win32process
# If we get here without ImportError, the test passes
assert True