Add Windows compatibility improvements (#105)

* Add Windows compatibility improvements

1. Add error handling for Windows-specific modules
2. Update README with Windows installation instructions
3. Add Windows-specific tests
4. Improve cross-platform support in interactive.py

* Fix: Add missing subprocess import in Windows compatibility tests

* Improve Windows compatibility:

1. Add detailed error handling for Windows I/O operations
2. Enhance timeout messages with more descriptive information
3. Add comprehensive comments explaining Windows-specific code

* Fix WebUI: Improve message display, add syntax highlighting, animations, and fix WebSocket communication
This commit is contained in:
Mark Varkevisser 2025-02-25 10:41:29 -08:00 committed by GitHub
parent 2f132bdeb5
commit d163a74c47
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 866 additions and 520 deletions

View File

@ -97,6 +97,31 @@ What sets RA.Aid apart is its ability to handle complex programming tasks that e
## Installation
### Windows Installation
1. Install Python 3.8 or higher from [python.org](https://www.python.org/downloads/)
2. Install required system dependencies:
```powershell
# Install Chocolatey if not already installed (run in admin PowerShell)
Set-ExecutionPolicy Bypass -Scope Process -Force; [System.Net.ServicePointManager]::SecurityProtocol = [System.Net.ServicePointManager]::SecurityProtocol -bor 3072; iex ((New-Object System.Net.WebClient).DownloadString('https://community.chocolatey.org/install.ps1'))
# Install ripgrep using Chocolatey
choco install ripgrep
```
3. Install RA.Aid:
```powershell
pip install ra-aid
```
4. Install Windows-specific dependencies:
```powershell
pip install pywin32
```
5. Set up your API keys in a `.env` file:
```env
ANTHROPIC_API_KEY=your_anthropic_key
OPENAI_API_KEY=your_openai_key
```
### Unix/Linux Installation
RA.Aid can be installed directly using pip:
```bash

View File

@ -2,6 +2,7 @@
import os
import sys
import subprocess
from abc import ABC, abstractmethod
from ra_aid import print_error
@ -21,8 +22,13 @@ class RipGrepDependency(Dependency):
def check(self):
"""Check if ripgrep is installed."""
result = os.system("rg --version > /dev/null 2>&1")
if result != 0:
try:
result = subprocess.run(['rg', '--version'],
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL)
if result.returncode != 0:
raise FileNotFoundError()
except (FileNotFoundError, subprocess.SubprocessError):
print_error("Required dependency 'ripgrep' is not installed.")
print("Please install ripgrep:")
print(" - Ubuntu/Debian: sudo apt-get install ripgrep")

View File

@ -3,51 +3,123 @@
Module for running interactive subprocesses with output capture,
with full raw input passthrough for interactive commands.
It uses a pseudo-tty and integrates pyte's HistoryScreen to simulate
It uses a pseudo-tty on Unix systems and direct pipes on Windows to simulate
a terminal and capture the final scrollback history (non-blank lines).
The interface remains compatible with external callers expecting a tuple (output, return_code),
where output is a bytes object (UTF-8 encoded).
"""
import errno
import io
import os
import select
import shutil
import signal
import subprocess
import sys
import termios
import time
import tty
from typing import List, Tuple
import shutil
from typing import List, Tuple, Optional
import pyte
from pyte.screens import HistoryScreen
# Windows-specific imports
if sys.platform == "win32":
try:
# msvcrt: Provides Windows console I/O functionality
import msvcrt
# win32pipe, win32file: For low-level pipe operations
import win32pipe
import win32file
# win32con: Windows API constants
import win32con
# win32process: Process management on Windows
import win32process
except ImportError as e:
print("Error: Required Windows dependencies not found.")
print("Please install the required packages using:")
print(" pip install pywin32")
sys.exit(1)
else:
# Unix-specific imports for terminal handling
import termios
import fcntl
import pty
def render_line(line, columns: int) -> str:
"""Render a single screen line from the pyte buffer (a mapping of column to Char)."""
return "".join(line[x].data for x in range(columns))
def get_terminal_size():
"""Get the current terminal size."""
if sys.platform == "win32":
import shutil
size = shutil.get_terminal_size()
return size.columns, size.lines
else:
import struct
try:
with open(sys.stdout.fileno(), 'wb', buffering=0) as fd:
size = struct.unpack('hh', fcntl.ioctl(fd, termios.TIOCGWINSZ, '1234'))
return size[1], size[0]
except (IOError, AttributeError):
return 80, 24
def create_process(cmd: List[str]) -> Tuple[subprocess.Popen, Optional[int]]:
"""Create a subprocess with appropriate handling for the platform.
On Windows:
- Uses STARTUPINFO to hide the console window
- Creates a new process group for proper signal handling
- Returns direct pipe handles for I/O
On Unix:
- Creates a pseudo-terminal (PTY) for proper terminal emulation
- Sets up process group for signal handling
- Returns master PTY file descriptor for I/O
"""
if sys.platform == "win32":
# Windows process creation with hidden console
startupinfo = subprocess.STARTUPINFO()
startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW # Hide the console window
# Create process with proper pipe handling
proc = subprocess.Popen(
cmd,
stdin=subprocess.PIPE, # Allow writing to stdin
stdout=subprocess.PIPE, # Capture stdout
stderr=subprocess.PIPE, # Capture stderr
startupinfo=startupinfo,
# CREATE_NEW_PROCESS_GROUP allows proper Ctrl+C handling
creationflags=subprocess.CREATE_NEW_PROCESS_GROUP
)
return proc, None # No PTY master_fd needed on Windows
else:
# Unix process creation with PTY
master_fd, slave_fd = pty.openpty()
proc = subprocess.Popen(
cmd,
stdin=slave_fd,
stdout=slave_fd,
stderr=slave_fd,
preexec_fn=os.setsid
)
os.close(slave_fd)
return proc, master_fd
def run_interactive_command(
cmd: List[str], expected_runtime_seconds: int = 30
cmd: List[str],
expected_runtime_seconds: int = 1800,
ratio: float = 0.5
) -> Tuple[bytes, int]:
"""
Runs an interactive command with a pseudo-tty, capturing final scrollback history.
Assumptions and constraints:
- Running on a Linux system.
- Running on a Linux system or Windows.
- `cmd` is a non-empty list where cmd[0] is the executable.
- The executable is on PATH.
Args:
cmd: A list containing the command and its arguments.
expected_runtime_seconds: Expected runtime in seconds, defaults to 30.
expected_runtime_seconds: Expected runtime in seconds, defaults to 1800.
If process exceeds 2x this value, it will be terminated gracefully.
If process exceeds 3x this value, it will be killed forcefully.
Must be between 1 and 1800 seconds (30 minutes).
ratio: Ratio of history to keep from top vs bottom (default: 0.5)
Returns:
A tuple of (captured_output, return_code), where captured_output is a UTF-8 encoded
@ -69,154 +141,164 @@ def run_interactive_command(
)
try:
term_size = os.get_terminal_size()
term_size = get_terminal_size()
cols, rows = term_size.columns, term_size.lines
except OSError:
cols, rows = 80, 24
# Set up pyte screen and stream to capture terminal output.
screen = HistoryScreen(cols, rows, history=2000, ratio=0.5)
screen = pyte.HistoryScreen(cols, rows, history=2000, ratio=ratio)
stream = pyte.Stream(screen)
# Open a new pseudo-tty.
master_fd, slave_fd = os.openpty()
# Set master_fd to non-blocking to avoid indefinite blocking.
os.set_blocking(master_fd, False)
try:
stdin_fd = sys.stdin.fileno()
except (AttributeError, io.UnsupportedOperation):
stdin_fd = None
# Set up environment variables for the subprocess using detected terminal size.
env = os.environ.copy()
env.update(
{
"DEBIAN_FRONTEND": "noninteractive",
"GIT_PAGER": "",
"PYTHONUNBUFFERED": "1",
"CI": "true",
"LANG": "C.UTF-8",
"LC_ALL": "C.UTF-8",
"COLUMNS": str(cols),
"LINES": str(rows),
"FORCE_COLOR": "1",
"GIT_TERMINAL_PROMPT": "0",
"PYTHONDONTWRITEBYTECODE": "1",
"NODE_OPTIONS": "--unhandled-rejections=strict",
}
)
proc = subprocess.Popen(
cmd,
stdin=slave_fd,
stdout=slave_fd,
stderr=slave_fd,
bufsize=0,
close_fds=True,
env=env,
preexec_fn=os.setsid, # Create new process group for proper signal handling.
)
os.close(slave_fd) # Close slave end in the parent process.
proc, master_fd = create_process(cmd)
captured_data = []
start_time = time.time()
was_terminated = False
timeout_type = None
def check_timeout():
nonlocal timeout_type
elapsed = time.time() - start_time
if elapsed > 3 * expected_runtime_seconds:
os.killpg(os.getpgid(proc.pid), signal.SIGKILL)
if sys.platform == "win32":
print("\nProcess exceeded hard timeout limit, forcefully terminating...")
proc.terminate()
time.sleep(0.5)
if proc.poll() is None:
print("Process did not respond to termination, killing...")
proc.kill()
else:
print("\nProcess exceeded hard timeout limit, sending SIGKILL...")
os.killpg(os.getpgid(proc.pid), signal.SIGKILL)
timeout_type = "hard_timeout"
return True
elif elapsed > 2 * expected_runtime_seconds:
os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
if sys.platform == "win32":
print("\nProcess exceeded soft timeout limit, attempting graceful termination...")
proc.terminate()
else:
print("\nProcess exceeded soft timeout limit, sending SIGTERM...")
os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
timeout_type = "soft_timeout"
return True
return False
# Interactive mode: forward input if running in a TTY.
if stdin_fd is not None and sys.stdin.isatty():
old_settings = termios.tcgetattr(stdin_fd)
tty.setraw(stdin_fd)
if sys.platform == "win32":
# Windows handling
try:
while True:
if check_timeout():
was_terminated = True
break
# Use a finite timeout to avoid indefinite blocking.
rlist, _, _ = select.select([master_fd, stdin_fd], [], [], 1.0)
if master_fd in rlist:
try:
data = os.read(master_fd, 1024)
except OSError as e:
if e.errno == errno.EIO:
try:
# Check stdout with proper error handling
stdout_data = proc.stdout.read1(1024)
if stdout_data:
captured_data.append(stdout_data)
try:
stream.feed(stdout_data.decode(errors='ignore'))
except Exception as e:
print(f"Warning: Error processing stdout: {e}")
# Check stderr with proper error handling
stderr_data = proc.stderr.read1(1024)
if stderr_data:
captured_data.append(stderr_data)
try:
stream.feed(stderr_data.decode(errors='ignore'))
except Exception as e:
print(f"Warning: Error processing stderr: {e}")
# Check for input with proper error handling
if msvcrt.kbhit():
try:
char = msvcrt.getch()
proc.stdin.write(char)
proc.stdin.flush()
except (IOError, OSError) as e:
print(f"Warning: Error handling keyboard input: {e}")
break
else:
raise
if not data: # EOF detected.
except (IOError, OSError) as e:
if isinstance(e, OSError) and e.winerror == 6: # Invalid handle
break
captured_data.append(data)
decoded = data.decode("utf-8", errors="ignore")
stream.feed(decoded)
os.write(1, data)
if stdin_fd in rlist:
print(f"Warning: I/O error during process communication: {e}")
break
except Exception as e:
print(f"Error in Windows process handling: {e}")
proc.terminate()
else:
# Unix handling
import tty
try:
old_settings = termios.tcgetattr(sys.stdin)
tty.setraw(sys.stdin)
while True:
if check_timeout():
was_terminated = True
break
rlist, _, _ = select.select([master_fd, sys.stdin], [], [], 0.1)
for fd in rlist:
try:
input_data = os.read(stdin_fd, 1024)
except OSError:
input_data = b""
if input_data:
os.write(master_fd, input_data)
if fd == master_fd:
data = os.read(master_fd, 1024)
if not data:
break
captured_data.append(data)
stream.feed(data.decode(errors='ignore'))
else:
data = os.read(fd, 1024)
os.write(master_fd, data)
except (IOError, OSError):
break
except KeyboardInterrupt:
proc.terminate()
finally:
termios.tcsetattr(stdin_fd, termios.TCSADRAIN, old_settings)
else:
# Non-interactive mode.
try:
while True:
if check_timeout():
was_terminated = True
break
rlist, _, _ = select.select([master_fd], [], [], 1.0)
if not rlist:
continue
try:
data = os.read(master_fd, 1024)
except OSError as e:
if e.errno == errno.EIO:
break
else:
raise
if not data: # EOF detected.
break
captured_data.append(data)
decoded = data.decode("utf-8", errors="ignore")
stream.feed(decoded)
os.write(1, data)
except KeyboardInterrupt:
proc.terminate()
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings)
if not sys.platform == "win32" and master_fd is not None:
os.close(master_fd)
os.close(master_fd)
proc.wait()
# Assemble full scrollback: combine history.top, the current display, and history.bottom.
top_lines = [render_line(line, cols) for line in screen.history.top]
def render_line(line, width):
return ''.join(char.data for char in line[:width]).rstrip()
# Combine history and current screen content
final_output = []
# Add lines from history
history_lines = [render_line(line, cols) for line in screen.history.top]
final_output.extend(line for line in history_lines if line.strip())
# Add current screen content
screen_lines = [render_line(line, cols) for line in screen.display]
final_output.extend(line for line in screen_lines if line.strip())
# Add bottom history
bottom_lines = [render_line(line, cols) for line in screen.history.bottom]
display_lines = screen.display # List of strings representing the current screen.
all_lines = top_lines + display_lines + bottom_lines
final_output.extend(line for line in bottom_lines if line.strip())
# Trim out empty lines to get only meaningful "history" lines.
trimmed_lines = [line for line in all_lines if line.strip()]
final_output = "\n".join(trimmed_lines)
# Add timeout message if process was terminated due to timeout.
# Add timeout message if process was terminated
if was_terminated:
timeout_msg = f"\n[Process exceeded timeout ({expected_runtime_seconds} seconds expected)]"
final_output += timeout_msg
if timeout_type == "hard_timeout":
timeout_msg = f"\n[Process forcefully terminated after exceeding {3 * expected_runtime_seconds:.1f} seconds (expected: {expected_runtime_seconds} seconds)]"
else:
timeout_msg = f"\n[Process gracefully terminated after exceeding {2 * expected_runtime_seconds:.1f} seconds (expected: {expected_runtime_seconds} seconds)]"
final_output.append(timeout_msg)
# Limit output to the last 8000 bytes.
# Limit output size
final_output = final_output[-8000:]
final_output = '\n'.join(final_output)
return final_output.encode("utf-8"), proc.returncode
return final_output.encode('utf-8'), proc.returncode
if __name__ == "__main__":

View File

@ -1,29 +1,37 @@
"""Web interface server implementation for RA.Aid."""
import asyncio
import logging
import shutil
#!/usr/bin/env python3
import sys
import os
from pathlib import Path
from typing import Any, Dict, List
import asyncio
from typing import List
import json
import threading
import queue
import traceback
import shutil
import logging
import uvicorn
from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler(sys.__stderr__) # Use the real stderr
]
)
logger = logging.getLogger(__name__)
# Add project root to Python path
project_root = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
if project_root not in sys.path:
sys.path.insert(0, project_root)
from fastapi import FastAPI, WebSocket, Request, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
# Configure logging
logging.basicConfig(level=logging.DEBUG) # Set to DEBUG for more info
logger = logging.getLogger(__name__)
# Verify ra-aid is available
if not shutil.which("ra-aid"):
logger.error(
"ra-aid command not found. Please ensure it's installed and in your PATH"
)
sys.exit(1)
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
app = FastAPI()
@ -36,67 +44,105 @@ app.add_middleware(
allow_headers=["*"],
)
# Get the directory containing static files
STATIC_DIR = Path(__file__).parent / "static"
if not STATIC_DIR.exists():
logger.error(f"Static directory not found at {STATIC_DIR}")
sys.exit(1)
# Setup templates and static files directories
CURRENT_DIR = Path(__file__).parent
templates = Jinja2Templates(directory=CURRENT_DIR)
logger.info(f"Using static directory: {STATIC_DIR}")
# Mount static files for js and other assets
static_dir = CURRENT_DIR / "static"
app.mount("/static", StaticFiles(directory=str(static_dir)), name="static")
# Setup templates
templates = Jinja2Templates(directory=str(STATIC_DIR))
# Store active WebSocket connections
active_connections: List[WebSocket] = []
def run_ra_aid(message_content, output_queue):
"""Run ra-aid in a separate thread"""
try:
import ra_aid.__main__
logger.info("Successfully imported ra_aid.__main__")
class ConnectionManager:
def __init__(self):
self.active_connections: List[WebSocket] = []
# Override sys.argv
sys.argv = [sys.argv[0], "-m", message_content, "--cowboy-mode"]
logger.info(f"Set sys.argv to: {sys.argv}")
# Create custom output handler
class QueueHandler:
def __init__(self, queue):
self.queue = queue
self.buffer = []
self.box_start = False
self._real_stderr = sys.__stderr__
def write(self, text):
# Always log raw output for debugging
logger.debug(f"Raw output: {repr(text)}")
# Check if this is a box drawing character
if any(c in text for c in '╭╮╰╯│─'):
self.box_start = True
self.buffer.append(text)
elif self.box_start and text.strip():
self.buffer.append(text)
if '' in text: # End of box
full_text = ''.join(self.buffer)
# Extract content from inside the box
lines = full_text.split('\n')
content_lines = []
for line in lines:
# Remove box characters and leading/trailing spaces
clean_line = line.strip('╭╮╰╯│─ ')
if clean_line:
content_lines.append(clean_line)
if content_lines:
self.queue.put('\n'.join(content_lines))
self.buffer = []
self.box_start = False
elif not self.box_start and text.strip():
self.queue.put(text.strip())
def flush(self):
if self.buffer:
full_text = ''.join(self.buffer)
# Extract content from partial box
lines = full_text.split('\n')
content_lines = []
for line in lines:
# Remove box characters and leading/trailing spaces
clean_line = line.strip('╭╮╰╯│─ ')
if clean_line:
content_lines.append(clean_line)
if content_lines:
self.queue.put('\n'.join(content_lines))
self.buffer = []
self.box_start = False
# Replace stdout and stderr
old_stdout = sys.stdout
old_stderr = sys.stderr
queue_handler = QueueHandler(output_queue)
sys.stdout = queue_handler
sys.stderr = queue_handler
async def connect(self, websocket: WebSocket) -> bool:
try:
logger.debug("Accepting WebSocket connection...")
await websocket.accept()
logger.debug("WebSocket connection accepted")
self.active_connections.append(websocket)
return True
logger.info("Starting ra_aid.main()")
ra_aid.__main__.main()
logger.info("Finished ra_aid.main()")
except SystemExit:
logger.info("Caught SystemExit - this is normal")
except Exception as e:
logger.error(f"Error accepting WebSocket connection: {e}")
return False
def disconnect(self, websocket: WebSocket):
if websocket in self.active_connections:
self.active_connections.remove(websocket)
async def send_message(self, websocket: WebSocket, message: Dict[str, Any]):
try:
await websocket.send_json(message)
except Exception as e:
logger.error(f"Error sending message: {e}")
await self.handle_error(websocket, str(e))
async def handle_error(self, websocket: WebSocket, error_message: str):
try:
await websocket.send_json(
{
"type": "chunk",
"chunk": {
"tools": {
"messages": [
{
"content": f"Error: {error_message}",
"status": "error",
}
]
}
},
}
)
except Exception as e:
logger.error(f"Error sending error message: {e}")
manager = ConnectionManager()
logger.error(f"Error in main: {str(e)}")
traceback.print_exc(file=sys.__stderr__)
finally:
# Flush any remaining output
queue_handler.flush()
# Restore stdout and stderr
sys.stdout = old_stdout
sys.stderr = old_stderr
except Exception as e:
logger.error(f"Error in thread: {str(e)}")
traceback.print_exc(file=sys.__stderr__)
output_queue.put(f"Error: {str(e)}")
@app.get("/", response_class=HTMLResponse)
async def get_root(request: Request):
@ -105,135 +151,108 @@ async def get_root(request: Request):
"index.html", {"request": request, "server_port": request.url.port or 8080}
)
# Mount static files for js and other assets
app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
client_id = id(websocket)
logger.info(f"New WebSocket connection attempt from client {client_id}")
if not await manager.connect(websocket):
logger.error(f"Failed to accept WebSocket connection for client {client_id}")
return
logger.info(f"WebSocket connection accepted for client {client_id}")
await websocket.accept()
logger.info("WebSocket connection established")
active_connections.append(websocket)
try:
# Send initial connection success message
await manager.send_message(
websocket,
{
"type": "chunk",
"chunk": {
"agent": {
"messages": [
{"content": "Connected to RA.Aid server", "status": "info"}
]
}
},
},
)
while True:
try:
message = await websocket.receive_json()
logger.debug(f"Received message from client {client_id}: {message}")
message = await websocket.receive_json()
logger.info(f"Received message: {message}")
if message["type"] == "request":
await manager.send_message(websocket, {"type": "stream_start"})
if message["type"] == "request":
content = message["content"]
logger.info(f"Processing request: {content}")
try:
# Run ra-aid with the request
cmd = ["ra-aid", "-m", message["content"], "--cowboy-mode"]
logger.info(f"Executing command: {' '.join(cmd)}")
# Create queue for output
output_queue = queue.Queue()
process = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
logger.info(f"Process started with PID: {process.pid}")
# Create and start thread
thread = threading.Thread(target=run_ra_aid, args=(content, output_queue))
thread.start()
async def read_stream(stream, is_error=False):
while True:
line = await stream.readline()
if not line:
break
try:
# Send stream start
await websocket.send_json({"type": "stream_start"})
try:
decoded_line = line.decode().strip()
if decoded_line:
await manager.send_message(
websocket,
{
"type": "chunk",
"chunk": {
"tools" if is_error else "agent": {
"messages": [
{
"content": decoded_line,
"status": "error"
if is_error
else "info",
}
]
}
},
},
)
except Exception as e:
logger.error(f"Error processing output: {e}")
while thread.is_alive() or not output_queue.empty():
try:
# Get output with timeout to allow checking thread status
line = output_queue.get(timeout=0.1)
if line and line.strip(): # Only send non-empty messages
logger.debug(f"WebSocket sending: {repr(line)}")
await websocket.send_json({
"type": "chunk",
"chunk": {
"agent": {
"messages": [{
"content": line.strip(),
"status": "info"
}]
}
}
})
except queue.Empty:
await asyncio.sleep(0.1)
except Exception as e:
logger.error(f"WebSocket error: {e}")
traceback.print_exc(file=sys.__stderr__)
# Create tasks for reading stdout and stderr
stdout_task = asyncio.create_task(read_stream(process.stdout))
stderr_task = asyncio.create_task(
read_stream(process.stderr, True)
)
# Wait for thread to finish
thread.join()
logger.info("Thread finished")
# Wait for both streams to complete
await asyncio.gather(stdout_task, stderr_task)
# Send stream end
await websocket.send_json({"type": "stream_end"})
logger.info("Sent stream_end message")
# Wait for process to complete
return_code = await process.wait()
except Exception as e:
error_msg = f"Error running ra-aid: {str(e)}"
logger.error(error_msg)
await websocket.send_json({
"type": "error",
"message": error_msg
})
if return_code != 0:
await manager.handle_error(
websocket, f"Process exited with code {return_code}"
)
await manager.send_message(
websocket,
{"type": "stream_end", "request": message["content"]},
)
except Exception as e:
logger.error(f"Error executing ra-aid: {e}")
await manager.handle_error(websocket, str(e))
except Exception as e:
logger.error(f"Error processing message: {e}")
await manager.handle_error(websocket, str(e))
logger.info("Waiting for message...")
except WebSocketDisconnect:
logger.info(f"WebSocket client {client_id} disconnected")
logger.info("WebSocket client disconnected")
active_connections.remove(websocket)
except Exception as e:
logger.error(f"WebSocket error for client {client_id}: {e}")
logger.error(f"WebSocket error: {e}")
traceback.print_exc()
finally:
manager.disconnect(websocket)
logger.info(f"WebSocket connection cleaned up for client {client_id}")
if websocket in active_connections:
active_connections.remove(websocket)
logger.info("WebSocket connection closed")
@app.get("/config")
async def get_config(request: Request):
"""Return server configuration including host and port."""
return {"host": request.client.host, "port": request.scope.get("server")[1]}
def run_server(host: str = "0.0.0.0", port: int = 8080):
"""Run the FastAPI server."""
logger.info(f"Starting server on {host}:{port}")
uvicorn.run(
app,
host=host,
port=port,
log_level="debug",
ws_max_size=16777216, # 16MB
timeout_keep_alive=0, # Disable keep-alive timeout
uvicorn.run(app, host=host, port=port)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="RA.Aid Web Interface Server")
parser.add_argument(
"--port", type=int, default=8080, help="Port to listen on (default: 8080)"
)
parser.add_argument(
"--host",
type=str,
default="0.0.0.0",
help="Host to listen on (default: 0.0.0.0)",
)
args = parser.parse_args()
run_server(host=args.host, port=args.port)

View File

@ -1,90 +1,214 @@
<!DOCTYPE html>
<html lang="en" class="h-full bg-gray-900">
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<meta name="server-port" content="{{ server_port }}">
<title>RA.Aid Web Interface</title>
<title>RA.Aid Web UI</title>
<script src="https://cdn.tailwindcss.com"></script>
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/styles/tokyo-night-dark.min.css">
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/highlight.min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/languages/python.min.js"></script>
<script>
tailwind.config = {
darkMode: 'class',
theme: {
extend: {
colors: {
'dark-primary': '#1a1b26',
'dark-secondary': '#24283b',
'dark-accent': '#7aa2f7',
'dark-text': '#c0caf5'
'dark-accent': '#414868',
'dark-text': '#a9b1d6',
'dark-background': '#16161e'
}
}
}
}
</script>
<style>
/* Core styles */
body {
background-color: #16161e;
color: #a9b1d6;
margin: 0;
padding: 0;
height: 100vh;
display: flex;
flex-direction: column;
}
/* Output area */
#stream-output {
flex: 1;
overflow-y: auto;
padding: 1rem;
background-color: #1a1b26;
margin: 0;
display: block !important;
min-height: 300px;
font-family: monospace;
white-space: pre-wrap;
word-break: break-word;
}
/* Message styling */
.message {
margin-bottom: 1.5rem;
padding: 0.75rem 1rem;
border-radius: 0.375rem;
background-color: #24283b;
color: #a9b1d6;
display: block !important;
border: 1px solid #414868;
opacity: 0;
transform: translateY(20px);
animation: slideIn 0.3s ease forwards;
}
@keyframes slideIn {
to {
opacity: 1;
transform: translateY(0);
}
}
/* Message types */
.message.info {
border-color: #7aa2f7;
}
.message.success {
border-color: #9ece6a;
background-color: rgba(158, 206, 106, 0.1);
}
.message.error {
border-color: #f7768e;
background-color: rgba(247, 118, 142, 0.1);
}
.message.warning {
border-color: #e0af68;
background-color: rgba(224, 175, 104, 0.1);
}
/* Section spacing */
.message.section {
margin-top: 2rem;
margin-bottom: 2rem;
padding: 1rem;
background-color: rgba(122, 162, 247, 0.1);
}
/* Code blocks */
.message pre {
margin: 0.5rem 0;
padding: 0.75rem;
border-radius: 0.25rem;
background-color: #1a1b26 !important;
border: 1px solid #414868;
overflow-x: auto;
}
.message code {
font-family: 'JetBrains Mono', monospace;
font-size: 0.9em;
}
/* Input area */
.input-area {
padding: 1rem;
background-color: #24283b;
border-top: 1px solid #414868;
}
.input-container {
display: flex;
gap: 1rem;
align-items: center;
}
#user-input {
flex: 1;
padding: 0.75rem;
background-color: #1a1b26;
border: 1px solid #414868;
border-radius: 0.375rem;
color: #a9b1d6;
font-family: 'JetBrains Mono', monospace;
transition: border-color 0.2s;
}
#user-input:focus {
outline: none;
border-color: #7aa2f7;
}
/* Buttons */
.button {
padding: 0.75rem 1rem;
background-color: #7aa2f7;
color: white;
border: none;
border-radius: 0.375rem;
cursor: pointer;
min-width: 80px;
transition: background-color 0.2s;
display: flex;
align-items: center;
justify-content: center;
gap: 0.5rem;
}
.button:hover {
background-color: #5d87e6;
}
.button:disabled {
opacity: 0.5;
cursor: not-allowed;
}
.button.clear {
background-color: #414868;
}
.button.clear:hover {
background-color: #363b54;
}
/* Icons */
.icon {
font-size: 1.2em;
line-height: 1;
}
</style>
</head>
<body class="h-full bg-dark-primary text-dark-text">
<div class="flex h-full">
<!-- Sidebar -->
<div class="w-64 bg-dark-secondary border-r border-gray-700 flex flex-col">
<div class="p-4 border-b border-gray-700">
<h2 class="text-xl font-semibold text-dark-accent">History</h2>
</div>
<div id="history-list" class="flex-1 overflow-y-auto p-4 space-y-2"></div>
</div>
<body>
<!-- Main Container -->
<div id="stream-output">
<!-- Messages will appear here -->
</div>
<!-- Main Content -->
<div class="flex-1 flex flex-col min-w-0">
<!-- Chat Container -->
<div class="flex-1 overflow-y-auto p-4 space-y-4" id="chat-container">
<div id="chat-messages"></div>
<div id="stream-output" class="hidden font-mono bg-dark-secondary rounded-lg p-4 text-sm"></div>
</div>
<!-- Input Area -->
<div class="border-t border-gray-700 p-4 bg-dark-secondary">
<div class="flex space-x-4">
<textarea
id="user-input"
class="flex-1 bg-dark-primary border border-gray-700 rounded-lg p-3 text-dark-text placeholder-gray-500 focus:outline-none focus:ring-2 focus:ring-dark-accent resize-none"
placeholder="Type your request here..."
rows="3"
></textarea>
<button
id="send-button"
class="px-6 py-2 bg-dark-accent text-white rounded-lg hover:bg-opacity-90 focus:outline-none focus:ring-2 focus:ring-offset-2 focus:ring-dark-accent disabled:opacity-50 disabled:cursor-not-allowed h-fit"
>
Send
</button>
</div>
</div>
<!-- Input Area -->
<div class="input-area">
<div class="input-container">
<input
type="text"
id="user-input"
placeholder="Type your message..."
>
<button id="clear-button" class="button clear" title="Clear conversation">
<span class="icon">🗑️</span>
</button>
<button id="send-button" class="button">
<span class="icon">📤</span>
Send
</button>
</div>
</div>
<script>
// Add dynamic styles for messages
const style = document.createElement('style');
style.textContent = `
.message {
@apply mb-4 p-4 rounded-lg max-w-3xl;
}
.user-message {
@apply bg-dark-accent text-white ml-auto;
}
.system-message {
@apply bg-dark-secondary mr-auto;
}
.error-message {
@apply bg-red-900 text-red-100 mr-auto;
}
.history-item {
@apply p-3 rounded-lg hover:bg-dark-primary cursor-pointer transition-colors duration-200 text-sm;
}
#stream-output:not(:empty) {
@apply block;
}
`;
document.head.appendChild(style);
</script>
<script src="/static/script.js"></script>
<script>
hljs.highlightAll();
</script>
</body>
</html>

View File

@ -1,206 +1,230 @@
class RAWebUI {
class WebSocketHandler {
constructor() {
this.messageHistory = [];
this.connectionAttempts = 0;
this.maxReconnectAttempts = 5;
this.setupElements();
this.setupEventListeners();
this.connectWebSocket();
// Wait for DOM to be ready
if (document.readyState === 'loading') {
document.addEventListener('DOMContentLoaded', () => this.initialize());
} else {
this.initialize();
}
}
setupElements() {
this.userInput = document.getElementById('user-input');
initialize() {
// Store DOM elements as instance variables
this.messageInput = document.getElementById('user-input');
this.sendButton = document.getElementById('send-button');
this.chatMessages = document.getElementById('chat-messages');
this.clearButton = document.getElementById('clear-button');
this.streamOutput = document.getElementById('stream-output');
this.historyList = document.getElementById('history-list');
// Disable send button initially
this.sendButton.disabled = true;
}
setupEventListeners() {
// Validate required elements exist
if (!this.messageInput || !this.sendButton || !this.streamOutput) {
console.error('Required elements not found:', {
messageInput: !!this.messageInput,
sendButton: !!this.sendButton,
streamOutput: !!this.streamOutput
});
return;
}
// Remove hidden class if present
this.streamOutput.classList.remove('hidden');
// Initialize WebSocket
this.connectWebSocket();
// Add event listeners
this.sendButton.addEventListener('click', () => this.sendMessage());
this.userInput.addEventListener('keypress', (e) => {
this.clearButton?.addEventListener('click', () => this.clearConversation());
this.messageInput.addEventListener('keypress', (e) => {
if (e.key === 'Enter' && !e.shiftKey) {
e.preventDefault();
this.sendMessage();
}
});
console.log('WebSocketHandler initialized with elements:', {
messageInput: this.messageInput,
sendButton: this.sendButton,
streamOutput: this.streamOutput
});
}
async connectWebSocket() {
// Don't try to reconnect if we've exceeded the maximum attempts
if (this.connectionAttempts >= this.maxReconnectAttempts) {
this.appendMessage(
'Maximum reconnection attempts reached. Please refresh the page.',
'error'
);
return;
}
connectWebSocket() {
try {
// Get the server port from the meta tag
const serverPort = document.querySelector('meta[name="server-port"]')?.content || '8080';
// Construct WebSocket URL
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
const wsUrl = `${protocol}//${window.location.hostname}:${serverPort}/ws`;
const wsUrl = `ws://${window.location.host}/ws`;
console.log('Attempting to connect to WebSocket URL:', wsUrl);
// Close existing connection if any
if (this.ws) {
this.ws.close();
}
// Create new WebSocket connection
console.log('Creating new WebSocket connection...');
this.ws = new WebSocket(wsUrl);
this.connectionAttempts++;
// Setup WebSocket event handlers
this.ws.onopen = () => {
console.log('WebSocket connection established successfully');
this.connectionAttempts = 0; // Reset counter on successful connection
this.sendButton.disabled = false;
};
this.ws.onclose = (event) => {
console.log('WebSocket closed:', event.code, event.reason);
this.ws.onclose = () => {
console.log('WebSocket connection closed');
this.sendButton.disabled = true;
// Try to reconnect after a delay
setTimeout(() => this.connectWebSocket(), 2000);
};
// Only attempt reconnect if not a normal closure and within retry limits
if (event.code !== 1000 && this.connectionAttempts < this.maxReconnectAttempts) {
const delay = Math.min(1000 * Math.pow(2, this.connectionAttempts), 10000);
this.appendMessage(
`Connection lost. Reconnecting in ${delay/1000} seconds...`,
'error'
);
setTimeout(() => this.connectWebSocket(), delay);
this.ws.onmessage = (event) => {
try {
console.log('Raw WebSocket message:', event.data);
const message = JSON.parse(event.data);
console.log('Parsed WebSocket message:', message);
this.handleMessage(message);
} catch (error) {
console.error('Error handling message:', error);
this.appendOutput({
content: `Error: ${error.message}`,
status: 'error'
});
}
};
this.ws.onerror = (error) => {
console.error('WebSocket error:', error);
this.sendButton.disabled = true;
this.appendOutput({
content: 'Connection error. Retrying...',
status: 'error'
});
};
this.ws.onmessage = (event) => {
try {
const data = JSON.parse(event.data);
this.handleServerMessage(data);
} catch (error) {
console.error('Error parsing message:', error);
this.appendMessage('Error processing server message', 'error');
}
};
} catch (error) {
console.error('Failed to connect to WebSocket:', error);
this.appendMessage(
`Connection error: ${error.message}. Retrying...`,
'error'
);
// Attempt to reconnect with exponential backoff
const delay = Math.min(1000 * Math.pow(2, this.connectionAttempts), 10000);
setTimeout(() => this.connectWebSocket(), delay);
console.error('Error connecting to WebSocket:', error);
this.appendOutput({
content: `Connection error: ${error.message}`,
status: 'error'
});
}
}
handleServerMessage(data) {
if (data.type === 'stream_start') {
this.streamOutput.textContent = '';
this.streamOutput.style.display = 'block';
} else if (data.type === 'stream_end') {
this.streamOutput.style.display = 'none';
this.addToHistory(data.request);
this.sendButton.disabled = false;
} else if (data.type === 'chunk') {
this.handleChunk(data.chunk);
handleMessage(message) {
switch (message.type) {
case 'stream_start':
this.handleStreamStart();
break;
case 'chunk':
this.handleChunk(message.chunk);
break;
case 'stream_end':
this.handleStreamEnd();
break;
default:
console.warn('Unknown message type:', message.type);
}
}
handleStreamStart() {
console.log('Stream starting');
this.clearStreamOutput();
this.appendOutput({
content: 'Starting new conversation...',
status: 'info'
});
}
handleStreamEnd() {
console.log('Stream ending');
this.appendOutput({
content: 'Conversation complete.',
status: 'success'
});
this.messageInput.disabled = false;
this.sendButton.disabled = false;
}
handleChunk(chunk) {
console.log(' Processing chunk:', chunk);
if (chunk.agent && chunk.agent.messages) {
chunk.agent.messages.forEach(msg => {
if (msg.content) {
if (Array.isArray(msg.content)) {
msg.content.forEach(content => {
if (content.type === 'text' && content.text.trim()) {
this.appendMessage(content.text.trim(), 'system');
}
});
} else if (msg.content.trim()) {
this.appendMessage(msg.content.trim(), 'system');
}
}
});
} else if (chunk.tools && chunk.tools.messages) {
chunk.tools.messages.forEach(msg => {
if (msg.status === 'error' && msg.content) {
this.appendMessage(msg.content.trim(), 'error');
}
chunk.agent.messages.forEach(message => {
console.log(' Processing agent message:', message);
console.log(' Adding agent message:', message.content);
this.appendOutput(message);
});
}
}
appendMessage(content, type) {
const messageDiv = document.createElement('div');
messageDiv.className = `message ${type}-message`;
messageDiv.textContent = content;
this.chatMessages.appendChild(messageDiv);
this.chatMessages.scrollTop = this.chatMessages.scrollHeight;
clearStreamOutput() {
console.log('Clearing stream output');
while (this.streamOutput.firstChild) {
this.streamOutput.removeChild(this.streamOutput.firstChild);
}
console.log('Stream output cleared');
}
addToHistory(request) {
const historyItem = document.createElement('div');
historyItem.className = 'history-item';
historyItem.textContent = request.slice(0, 50) + (request.length > 50 ? '...' : '');
historyItem.title = request;
historyItem.addEventListener('click', () => {
this.userInput.value = request;
this.userInput.focus();
clearConversation() {
this.clearStreamOutput();
this.appendOutput({
content: 'Conversation cleared.',
status: 'info'
});
}
appendOutput(message) {
console.log(' Appending output:', message);
const messageDiv = document.createElement('div');
messageDiv.className = `message ${message.status || 'info'}`;
const contentSpan = document.createElement('span');
// Convert ANSI escape codes to HTML
let content = message.content;
content = this.convertAnsiToHtml(content);
// Check for code blocks and apply syntax highlighting
if (content.includes('```')) {
content = this.highlightCodeBlocks(content);
}
contentSpan.innerHTML = content;
messageDiv.appendChild(contentSpan);
this.streamOutput.appendChild(messageDiv);
messageDiv.scrollIntoView({ behavior: 'smooth', block: 'end' });
}
convertAnsiToHtml(text) {
// ANSI color codes to CSS classes
const ansiToClass = {
'[94m': '<span class="text-blue-400">',
'[1;32m': '<span class="text-green-400 font-bold">',
'[0m': '</span>'
};
// Replace ANSI codes with HTML
let html = text;
for (const [ansi, htmlClass] of Object.entries(ansiToClass)) {
html = html.replaceAll('\u001b' + ansi, htmlClass);
}
return html;
}
highlightCodeBlocks(content) {
const codeBlockRegex = /```(\w+)?\n([\s\S]*?)```/g;
return content.replace(codeBlockRegex, (match, lang, code) => {
const language = lang || 'plaintext';
const highlighted = hljs.highlight(code.trim(), { language }).value;
return `<pre><code class="language-${language}">${highlighted}</code></pre>`;
});
this.historyList.insertBefore(historyItem, this.historyList.firstChild);
this.messageHistory.push(request);
}
sendMessage() {
console.log('Send button clicked');
const message = this.userInput.value.trim();
console.log('Message content:', message);
if (!message) {
console.log('Message is empty, not sending');
return;
}
const message = this.messageInput.value.trim();
if (!message) return;
if (!this.ws || this.ws.readyState !== WebSocket.OPEN) {
console.error('WebSocket is not connected');
this.appendMessage('Error: Not connected to server. Please wait...', 'error');
return;
}
try {
console.log('Sending message to server');
this.appendMessage(message, 'user');
const payload = { type: 'request', content: message };
console.log('Payload:', payload);
this.ws.send(JSON.stringify(payload));
console.log('Message sent successfully');
this.userInput.value = '';
this.sendButton.disabled = true;
} catch (error) {
console.error('Error sending message:', error);
this.appendMessage(`Error sending message: ${error.message}`, 'error');
this.sendButton.disabled = false;
}
console.log('Sending message:', message);
this.ws.send(JSON.stringify({
type: "request",
content: message
}));
this.messageInput.value = '';
this.messageInput.disabled = true;
this.sendButton.disabled = true;
}
}
// Initialize the UI when the page loads
document.addEventListener('DOMContentLoaded', () => {
window.raWebUI = new RAWebUI();
// Initialize WebSocket handler when the page loads
window.addEventListener('load', () => {
new WebSocketHandler();
});

View File

@ -0,0 +1,66 @@
"""Tests for Windows-specific functionality."""
import os
import sys
import subprocess
import pytest
from unittest.mock import patch, MagicMock
from ra_aid.proc.interactive import get_terminal_size, create_process, run_interactive_command
@pytest.mark.skipif(sys.platform != "win32", reason="Windows-specific tests")
class TestWindowsCompatibility:
"""Test suite for Windows-specific functionality."""
def test_get_terminal_size(self):
"""Test terminal size detection on Windows."""
with patch('shutil.get_terminal_size') as mock_get_size:
mock_get_size.return_value = MagicMock(columns=120, lines=30)
cols, rows = get_terminal_size()
assert cols == 120
assert rows == 30
mock_get_size.assert_called_once()
def test_create_process(self):
"""Test process creation on Windows."""
with patch('subprocess.Popen') as mock_popen:
mock_process = MagicMock()
mock_process.returncode = 0
mock_popen.return_value = mock_process
proc, _ = create_process(['echo', 'test'])
assert mock_popen.called
args, kwargs = mock_popen.call_args
assert kwargs['stdin'] == subprocess.PIPE
assert kwargs['stdout'] == subprocess.PIPE
assert kwargs['stderr'] == subprocess.PIPE
assert 'startupinfo' in kwargs
assert kwargs['startupinfo'].dwFlags & subprocess.STARTF_USESHOWWINDOW
def test_run_interactive_command(self):
"""Test running an interactive command on Windows."""
test_output = "Test output\n"
with patch('subprocess.Popen') as mock_popen:
mock_process = MagicMock()
mock_process.stdout = MagicMock()
mock_process.stdout.read.return_value = test_output.encode()
mock_process.wait.return_value = 0
mock_popen.return_value = mock_process
output, return_code = run_interactive_command(['echo', 'test'])
assert return_code == 0
assert "Test output" in output.decode()
def test_windows_dependencies(self):
"""Test that required Windows dependencies are available."""
if sys.platform == "win32":
import msvcrt
import win32pipe
import win32file
import win32con
import win32process
# If we get here without ImportError, the test passes
assert True