Add Windows compatibility improvements (#105)
* Add Windows compatibility improvements 1. Add error handling for Windows-specific modules 2. Update README with Windows installation instructions 3. Add Windows-specific tests 4. Improve cross-platform support in interactive.py * Fix: Add missing subprocess import in Windows compatibility tests * Improve Windows compatibility: 1. Add detailed error handling for Windows I/O operations 2. Enhance timeout messages with more descriptive information 3. Add comprehensive comments explaining Windows-specific code * Fix WebUI: Improve message display, add syntax highlighting, animations, and fix WebSocket communication
This commit is contained in:
parent
2f132bdeb5
commit
d163a74c47
25
README.md
25
README.md
|
|
@ -97,6 +97,31 @@ What sets RA.Aid apart is its ability to handle complex programming tasks that e
|
|||
|
||||
## Installation
|
||||
|
||||
### Windows Installation
|
||||
1. Install Python 3.8 or higher from [python.org](https://www.python.org/downloads/)
|
||||
2. Install required system dependencies:
|
||||
```powershell
|
||||
# Install Chocolatey if not already installed (run in admin PowerShell)
|
||||
Set-ExecutionPolicy Bypass -Scope Process -Force; [System.Net.ServicePointManager]::SecurityProtocol = [System.Net.ServicePointManager]::SecurityProtocol -bor 3072; iex ((New-Object System.Net.WebClient).DownloadString('https://community.chocolatey.org/install.ps1'))
|
||||
|
||||
# Install ripgrep using Chocolatey
|
||||
choco install ripgrep
|
||||
```
|
||||
3. Install RA.Aid:
|
||||
```powershell
|
||||
pip install ra-aid
|
||||
```
|
||||
4. Install Windows-specific dependencies:
|
||||
```powershell
|
||||
pip install pywin32
|
||||
```
|
||||
5. Set up your API keys in a `.env` file:
|
||||
```env
|
||||
ANTHROPIC_API_KEY=your_anthropic_key
|
||||
OPENAI_API_KEY=your_openai_key
|
||||
```
|
||||
|
||||
### Unix/Linux Installation
|
||||
RA.Aid can be installed directly using pip:
|
||||
|
||||
```bash
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from ra_aid import print_error
|
||||
|
|
@ -21,8 +22,13 @@ class RipGrepDependency(Dependency):
|
|||
|
||||
def check(self):
|
||||
"""Check if ripgrep is installed."""
|
||||
result = os.system("rg --version > /dev/null 2>&1")
|
||||
if result != 0:
|
||||
try:
|
||||
result = subprocess.run(['rg', '--version'],
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL)
|
||||
if result.returncode != 0:
|
||||
raise FileNotFoundError()
|
||||
except (FileNotFoundError, subprocess.SubprocessError):
|
||||
print_error("Required dependency 'ripgrep' is not installed.")
|
||||
print("Please install ripgrep:")
|
||||
print(" - Ubuntu/Debian: sudo apt-get install ripgrep")
|
||||
|
|
|
|||
|
|
@ -3,51 +3,123 @@
|
|||
Module for running interactive subprocesses with output capture,
|
||||
with full raw input passthrough for interactive commands.
|
||||
|
||||
It uses a pseudo-tty and integrates pyte's HistoryScreen to simulate
|
||||
It uses a pseudo-tty on Unix systems and direct pipes on Windows to simulate
|
||||
a terminal and capture the final scrollback history (non-blank lines).
|
||||
The interface remains compatible with external callers expecting a tuple (output, return_code),
|
||||
where output is a bytes object (UTF-8 encoded).
|
||||
"""
|
||||
|
||||
import errno
|
||||
import io
|
||||
import os
|
||||
import select
|
||||
import shutil
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import termios
|
||||
import time
|
||||
import tty
|
||||
from typing import List, Tuple
|
||||
import shutil
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
import pyte
|
||||
from pyte.screens import HistoryScreen
|
||||
|
||||
# Windows-specific imports
|
||||
if sys.platform == "win32":
|
||||
try:
|
||||
# msvcrt: Provides Windows console I/O functionality
|
||||
import msvcrt
|
||||
# win32pipe, win32file: For low-level pipe operations
|
||||
import win32pipe
|
||||
import win32file
|
||||
# win32con: Windows API constants
|
||||
import win32con
|
||||
# win32process: Process management on Windows
|
||||
import win32process
|
||||
except ImportError as e:
|
||||
print("Error: Required Windows dependencies not found.")
|
||||
print("Please install the required packages using:")
|
||||
print(" pip install pywin32")
|
||||
sys.exit(1)
|
||||
else:
|
||||
# Unix-specific imports for terminal handling
|
||||
import termios
|
||||
import fcntl
|
||||
import pty
|
||||
|
||||
def render_line(line, columns: int) -> str:
|
||||
"""Render a single screen line from the pyte buffer (a mapping of column to Char)."""
|
||||
return "".join(line[x].data for x in range(columns))
|
||||
def get_terminal_size():
|
||||
"""Get the current terminal size."""
|
||||
if sys.platform == "win32":
|
||||
import shutil
|
||||
size = shutil.get_terminal_size()
|
||||
return size.columns, size.lines
|
||||
else:
|
||||
import struct
|
||||
try:
|
||||
with open(sys.stdout.fileno(), 'wb', buffering=0) as fd:
|
||||
size = struct.unpack('hh', fcntl.ioctl(fd, termios.TIOCGWINSZ, '1234'))
|
||||
return size[1], size[0]
|
||||
except (IOError, AttributeError):
|
||||
return 80, 24
|
||||
|
||||
def create_process(cmd: List[str]) -> Tuple[subprocess.Popen, Optional[int]]:
|
||||
"""Create a subprocess with appropriate handling for the platform.
|
||||
|
||||
On Windows:
|
||||
- Uses STARTUPINFO to hide the console window
|
||||
- Creates a new process group for proper signal handling
|
||||
- Returns direct pipe handles for I/O
|
||||
|
||||
On Unix:
|
||||
- Creates a pseudo-terminal (PTY) for proper terminal emulation
|
||||
- Sets up process group for signal handling
|
||||
- Returns master PTY file descriptor for I/O
|
||||
"""
|
||||
if sys.platform == "win32":
|
||||
# Windows process creation with hidden console
|
||||
startupinfo = subprocess.STARTUPINFO()
|
||||
startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW # Hide the console window
|
||||
|
||||
# Create process with proper pipe handling
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
stdin=subprocess.PIPE, # Allow writing to stdin
|
||||
stdout=subprocess.PIPE, # Capture stdout
|
||||
stderr=subprocess.PIPE, # Capture stderr
|
||||
startupinfo=startupinfo,
|
||||
# CREATE_NEW_PROCESS_GROUP allows proper Ctrl+C handling
|
||||
creationflags=subprocess.CREATE_NEW_PROCESS_GROUP
|
||||
)
|
||||
return proc, None # No PTY master_fd needed on Windows
|
||||
else:
|
||||
# Unix process creation with PTY
|
||||
master_fd, slave_fd = pty.openpty()
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
stdin=slave_fd,
|
||||
stdout=slave_fd,
|
||||
stderr=slave_fd,
|
||||
preexec_fn=os.setsid
|
||||
)
|
||||
os.close(slave_fd)
|
||||
return proc, master_fd
|
||||
|
||||
def run_interactive_command(
|
||||
cmd: List[str], expected_runtime_seconds: int = 30
|
||||
cmd: List[str],
|
||||
expected_runtime_seconds: int = 1800,
|
||||
ratio: float = 0.5
|
||||
) -> Tuple[bytes, int]:
|
||||
"""
|
||||
Runs an interactive command with a pseudo-tty, capturing final scrollback history.
|
||||
|
||||
Assumptions and constraints:
|
||||
- Running on a Linux system.
|
||||
- Running on a Linux system or Windows.
|
||||
- `cmd` is a non-empty list where cmd[0] is the executable.
|
||||
- The executable is on PATH.
|
||||
|
||||
Args:
|
||||
cmd: A list containing the command and its arguments.
|
||||
expected_runtime_seconds: Expected runtime in seconds, defaults to 30.
|
||||
expected_runtime_seconds: Expected runtime in seconds, defaults to 1800.
|
||||
If process exceeds 2x this value, it will be terminated gracefully.
|
||||
If process exceeds 3x this value, it will be killed forcefully.
|
||||
Must be between 1 and 1800 seconds (30 minutes).
|
||||
ratio: Ratio of history to keep from top vs bottom (default: 0.5)
|
||||
|
||||
Returns:
|
||||
A tuple of (captured_output, return_code), where captured_output is a UTF-8 encoded
|
||||
|
|
@ -69,154 +141,164 @@ def run_interactive_command(
|
|||
)
|
||||
|
||||
try:
|
||||
term_size = os.get_terminal_size()
|
||||
term_size = get_terminal_size()
|
||||
cols, rows = term_size.columns, term_size.lines
|
||||
except OSError:
|
||||
cols, rows = 80, 24
|
||||
|
||||
# Set up pyte screen and stream to capture terminal output.
|
||||
screen = HistoryScreen(cols, rows, history=2000, ratio=0.5)
|
||||
screen = pyte.HistoryScreen(cols, rows, history=2000, ratio=ratio)
|
||||
stream = pyte.Stream(screen)
|
||||
|
||||
# Open a new pseudo-tty.
|
||||
master_fd, slave_fd = os.openpty()
|
||||
# Set master_fd to non-blocking to avoid indefinite blocking.
|
||||
os.set_blocking(master_fd, False)
|
||||
|
||||
try:
|
||||
stdin_fd = sys.stdin.fileno()
|
||||
except (AttributeError, io.UnsupportedOperation):
|
||||
stdin_fd = None
|
||||
|
||||
# Set up environment variables for the subprocess using detected terminal size.
|
||||
env = os.environ.copy()
|
||||
env.update(
|
||||
{
|
||||
"DEBIAN_FRONTEND": "noninteractive",
|
||||
"GIT_PAGER": "",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"CI": "true",
|
||||
"LANG": "C.UTF-8",
|
||||
"LC_ALL": "C.UTF-8",
|
||||
"COLUMNS": str(cols),
|
||||
"LINES": str(rows),
|
||||
"FORCE_COLOR": "1",
|
||||
"GIT_TERMINAL_PROMPT": "0",
|
||||
"PYTHONDONTWRITEBYTECODE": "1",
|
||||
"NODE_OPTIONS": "--unhandled-rejections=strict",
|
||||
}
|
||||
)
|
||||
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
stdin=slave_fd,
|
||||
stdout=slave_fd,
|
||||
stderr=slave_fd,
|
||||
bufsize=0,
|
||||
close_fds=True,
|
||||
env=env,
|
||||
preexec_fn=os.setsid, # Create new process group for proper signal handling.
|
||||
)
|
||||
os.close(slave_fd) # Close slave end in the parent process.
|
||||
proc, master_fd = create_process(cmd)
|
||||
|
||||
captured_data = []
|
||||
start_time = time.time()
|
||||
was_terminated = False
|
||||
timeout_type = None
|
||||
|
||||
def check_timeout():
|
||||
nonlocal timeout_type
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed > 3 * expected_runtime_seconds:
|
||||
os.killpg(os.getpgid(proc.pid), signal.SIGKILL)
|
||||
if sys.platform == "win32":
|
||||
print("\nProcess exceeded hard timeout limit, forcefully terminating...")
|
||||
proc.terminate()
|
||||
time.sleep(0.5)
|
||||
if proc.poll() is None:
|
||||
print("Process did not respond to termination, killing...")
|
||||
proc.kill()
|
||||
else:
|
||||
print("\nProcess exceeded hard timeout limit, sending SIGKILL...")
|
||||
os.killpg(os.getpgid(proc.pid), signal.SIGKILL)
|
||||
timeout_type = "hard_timeout"
|
||||
return True
|
||||
elif elapsed > 2 * expected_runtime_seconds:
|
||||
os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
|
||||
if sys.platform == "win32":
|
||||
print("\nProcess exceeded soft timeout limit, attempting graceful termination...")
|
||||
proc.terminate()
|
||||
else:
|
||||
print("\nProcess exceeded soft timeout limit, sending SIGTERM...")
|
||||
os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
|
||||
timeout_type = "soft_timeout"
|
||||
return True
|
||||
return False
|
||||
|
||||
# Interactive mode: forward input if running in a TTY.
|
||||
if stdin_fd is not None and sys.stdin.isatty():
|
||||
old_settings = termios.tcgetattr(stdin_fd)
|
||||
tty.setraw(stdin_fd)
|
||||
if sys.platform == "win32":
|
||||
# Windows handling
|
||||
try:
|
||||
while True:
|
||||
if check_timeout():
|
||||
was_terminated = True
|
||||
break
|
||||
# Use a finite timeout to avoid indefinite blocking.
|
||||
rlist, _, _ = select.select([master_fd, stdin_fd], [], [], 1.0)
|
||||
if master_fd in rlist:
|
||||
try:
|
||||
data = os.read(master_fd, 1024)
|
||||
except OSError as e:
|
||||
if e.errno == errno.EIO:
|
||||
|
||||
try:
|
||||
# Check stdout with proper error handling
|
||||
stdout_data = proc.stdout.read1(1024)
|
||||
if stdout_data:
|
||||
captured_data.append(stdout_data)
|
||||
try:
|
||||
stream.feed(stdout_data.decode(errors='ignore'))
|
||||
except Exception as e:
|
||||
print(f"Warning: Error processing stdout: {e}")
|
||||
|
||||
# Check stderr with proper error handling
|
||||
stderr_data = proc.stderr.read1(1024)
|
||||
if stderr_data:
|
||||
captured_data.append(stderr_data)
|
||||
try:
|
||||
stream.feed(stderr_data.decode(errors='ignore'))
|
||||
except Exception as e:
|
||||
print(f"Warning: Error processing stderr: {e}")
|
||||
|
||||
# Check for input with proper error handling
|
||||
if msvcrt.kbhit():
|
||||
try:
|
||||
char = msvcrt.getch()
|
||||
proc.stdin.write(char)
|
||||
proc.stdin.flush()
|
||||
except (IOError, OSError) as e:
|
||||
print(f"Warning: Error handling keyboard input: {e}")
|
||||
break
|
||||
else:
|
||||
raise
|
||||
if not data: # EOF detected.
|
||||
|
||||
except (IOError, OSError) as e:
|
||||
if isinstance(e, OSError) and e.winerror == 6: # Invalid handle
|
||||
break
|
||||
captured_data.append(data)
|
||||
decoded = data.decode("utf-8", errors="ignore")
|
||||
stream.feed(decoded)
|
||||
os.write(1, data)
|
||||
if stdin_fd in rlist:
|
||||
print(f"Warning: I/O error during process communication: {e}")
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in Windows process handling: {e}")
|
||||
proc.terminate()
|
||||
else:
|
||||
# Unix handling
|
||||
import tty
|
||||
try:
|
||||
old_settings = termios.tcgetattr(sys.stdin)
|
||||
tty.setraw(sys.stdin)
|
||||
while True:
|
||||
if check_timeout():
|
||||
was_terminated = True
|
||||
break
|
||||
rlist, _, _ = select.select([master_fd, sys.stdin], [], [], 0.1)
|
||||
|
||||
for fd in rlist:
|
||||
try:
|
||||
input_data = os.read(stdin_fd, 1024)
|
||||
except OSError:
|
||||
input_data = b""
|
||||
if input_data:
|
||||
os.write(master_fd, input_data)
|
||||
if fd == master_fd:
|
||||
data = os.read(master_fd, 1024)
|
||||
if not data:
|
||||
break
|
||||
captured_data.append(data)
|
||||
stream.feed(data.decode(errors='ignore'))
|
||||
else:
|
||||
data = os.read(fd, 1024)
|
||||
os.write(master_fd, data)
|
||||
except (IOError, OSError):
|
||||
break
|
||||
|
||||
except KeyboardInterrupt:
|
||||
proc.terminate()
|
||||
finally:
|
||||
termios.tcsetattr(stdin_fd, termios.TCSADRAIN, old_settings)
|
||||
else:
|
||||
# Non-interactive mode.
|
||||
try:
|
||||
while True:
|
||||
if check_timeout():
|
||||
was_terminated = True
|
||||
break
|
||||
rlist, _, _ = select.select([master_fd], [], [], 1.0)
|
||||
if not rlist:
|
||||
continue
|
||||
try:
|
||||
data = os.read(master_fd, 1024)
|
||||
except OSError as e:
|
||||
if e.errno == errno.EIO:
|
||||
break
|
||||
else:
|
||||
raise
|
||||
if not data: # EOF detected.
|
||||
break
|
||||
captured_data.append(data)
|
||||
decoded = data.decode("utf-8", errors="ignore")
|
||||
stream.feed(decoded)
|
||||
os.write(1, data)
|
||||
except KeyboardInterrupt:
|
||||
proc.terminate()
|
||||
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings)
|
||||
|
||||
if not sys.platform == "win32" and master_fd is not None:
|
||||
os.close(master_fd)
|
||||
|
||||
os.close(master_fd)
|
||||
proc.wait()
|
||||
|
||||
# Assemble full scrollback: combine history.top, the current display, and history.bottom.
|
||||
top_lines = [render_line(line, cols) for line in screen.history.top]
|
||||
def render_line(line, width):
|
||||
return ''.join(char.data for char in line[:width]).rstrip()
|
||||
|
||||
# Combine history and current screen content
|
||||
final_output = []
|
||||
|
||||
# Add lines from history
|
||||
history_lines = [render_line(line, cols) for line in screen.history.top]
|
||||
final_output.extend(line for line in history_lines if line.strip())
|
||||
|
||||
# Add current screen content
|
||||
screen_lines = [render_line(line, cols) for line in screen.display]
|
||||
final_output.extend(line for line in screen_lines if line.strip())
|
||||
|
||||
# Add bottom history
|
||||
bottom_lines = [render_line(line, cols) for line in screen.history.bottom]
|
||||
display_lines = screen.display # List of strings representing the current screen.
|
||||
all_lines = top_lines + display_lines + bottom_lines
|
||||
final_output.extend(line for line in bottom_lines if line.strip())
|
||||
|
||||
# Trim out empty lines to get only meaningful "history" lines.
|
||||
trimmed_lines = [line for line in all_lines if line.strip()]
|
||||
final_output = "\n".join(trimmed_lines)
|
||||
|
||||
# Add timeout message if process was terminated due to timeout.
|
||||
# Add timeout message if process was terminated
|
||||
if was_terminated:
|
||||
timeout_msg = f"\n[Process exceeded timeout ({expected_runtime_seconds} seconds expected)]"
|
||||
final_output += timeout_msg
|
||||
if timeout_type == "hard_timeout":
|
||||
timeout_msg = f"\n[Process forcefully terminated after exceeding {3 * expected_runtime_seconds:.1f} seconds (expected: {expected_runtime_seconds} seconds)]"
|
||||
else:
|
||||
timeout_msg = f"\n[Process gracefully terminated after exceeding {2 * expected_runtime_seconds:.1f} seconds (expected: {expected_runtime_seconds} seconds)]"
|
||||
final_output.append(timeout_msg)
|
||||
|
||||
# Limit output to the last 8000 bytes.
|
||||
# Limit output size
|
||||
final_output = final_output[-8000:]
|
||||
final_output = '\n'.join(final_output)
|
||||
|
||||
return final_output.encode("utf-8"), proc.returncode
|
||||
return final_output.encode('utf-8'), proc.returncode
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -1,29 +1,37 @@
|
|||
"""Web interface server implementation for RA.Aid."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import shutil
|
||||
#!/usr/bin/env python3
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
import asyncio
|
||||
from typing import List
|
||||
import json
|
||||
import threading
|
||||
import queue
|
||||
import traceback
|
||||
import shutil
|
||||
import logging
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.StreamHandler(sys.__stderr__) # Use the real stderr
|
||||
]
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Add project root to Python path
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
if project_root not in sys.path:
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
from fastapi import FastAPI, WebSocket, Request, WebSocketDisconnect
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.templating import Jinja2Templates
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.DEBUG) # Set to DEBUG for more info
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Verify ra-aid is available
|
||||
if not shutil.which("ra-aid"):
|
||||
logger.error(
|
||||
"ra-aid command not found. Please ensure it's installed and in your PATH"
|
||||
)
|
||||
sys.exit(1)
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import uvicorn
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
|
@ -36,67 +44,105 @@ app.add_middleware(
|
|||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Get the directory containing static files
|
||||
STATIC_DIR = Path(__file__).parent / "static"
|
||||
if not STATIC_DIR.exists():
|
||||
logger.error(f"Static directory not found at {STATIC_DIR}")
|
||||
sys.exit(1)
|
||||
# Setup templates and static files directories
|
||||
CURRENT_DIR = Path(__file__).parent
|
||||
templates = Jinja2Templates(directory=CURRENT_DIR)
|
||||
|
||||
logger.info(f"Using static directory: {STATIC_DIR}")
|
||||
# Mount static files for js and other assets
|
||||
static_dir = CURRENT_DIR / "static"
|
||||
app.mount("/static", StaticFiles(directory=str(static_dir)), name="static")
|
||||
|
||||
# Setup templates
|
||||
templates = Jinja2Templates(directory=str(STATIC_DIR))
|
||||
# Store active WebSocket connections
|
||||
active_connections: List[WebSocket] = []
|
||||
|
||||
def run_ra_aid(message_content, output_queue):
|
||||
"""Run ra-aid in a separate thread"""
|
||||
try:
|
||||
import ra_aid.__main__
|
||||
logger.info("Successfully imported ra_aid.__main__")
|
||||
|
||||
class ConnectionManager:
|
||||
def __init__(self):
|
||||
self.active_connections: List[WebSocket] = []
|
||||
# Override sys.argv
|
||||
sys.argv = [sys.argv[0], "-m", message_content, "--cowboy-mode"]
|
||||
logger.info(f"Set sys.argv to: {sys.argv}")
|
||||
|
||||
# Create custom output handler
|
||||
class QueueHandler:
|
||||
def __init__(self, queue):
|
||||
self.queue = queue
|
||||
self.buffer = []
|
||||
self.box_start = False
|
||||
self._real_stderr = sys.__stderr__
|
||||
|
||||
def write(self, text):
|
||||
# Always log raw output for debugging
|
||||
logger.debug(f"Raw output: {repr(text)}")
|
||||
|
||||
# Check if this is a box drawing character
|
||||
if any(c in text for c in '╭╮╰╯│─'):
|
||||
self.box_start = True
|
||||
self.buffer.append(text)
|
||||
elif self.box_start and text.strip():
|
||||
self.buffer.append(text)
|
||||
if '╯' in text: # End of box
|
||||
full_text = ''.join(self.buffer)
|
||||
# Extract content from inside the box
|
||||
lines = full_text.split('\n')
|
||||
content_lines = []
|
||||
for line in lines:
|
||||
# Remove box characters and leading/trailing spaces
|
||||
clean_line = line.strip('╭╮╰╯│─ ')
|
||||
if clean_line:
|
||||
content_lines.append(clean_line)
|
||||
if content_lines:
|
||||
self.queue.put('\n'.join(content_lines))
|
||||
self.buffer = []
|
||||
self.box_start = False
|
||||
elif not self.box_start and text.strip():
|
||||
self.queue.put(text.strip())
|
||||
|
||||
def flush(self):
|
||||
if self.buffer:
|
||||
full_text = ''.join(self.buffer)
|
||||
# Extract content from partial box
|
||||
lines = full_text.split('\n')
|
||||
content_lines = []
|
||||
for line in lines:
|
||||
# Remove box characters and leading/trailing spaces
|
||||
clean_line = line.strip('╭╮╰╯│─ ')
|
||||
if clean_line:
|
||||
content_lines.append(clean_line)
|
||||
if content_lines:
|
||||
self.queue.put('\n'.join(content_lines))
|
||||
self.buffer = []
|
||||
self.box_start = False
|
||||
|
||||
# Replace stdout and stderr
|
||||
old_stdout = sys.stdout
|
||||
old_stderr = sys.stderr
|
||||
queue_handler = QueueHandler(output_queue)
|
||||
sys.stdout = queue_handler
|
||||
sys.stderr = queue_handler
|
||||
|
||||
async def connect(self, websocket: WebSocket) -> bool:
|
||||
try:
|
||||
logger.debug("Accepting WebSocket connection...")
|
||||
await websocket.accept()
|
||||
logger.debug("WebSocket connection accepted")
|
||||
self.active_connections.append(websocket)
|
||||
return True
|
||||
logger.info("Starting ra_aid.main()")
|
||||
ra_aid.__main__.main()
|
||||
logger.info("Finished ra_aid.main()")
|
||||
except SystemExit:
|
||||
logger.info("Caught SystemExit - this is normal")
|
||||
except Exception as e:
|
||||
logger.error(f"Error accepting WebSocket connection: {e}")
|
||||
return False
|
||||
|
||||
def disconnect(self, websocket: WebSocket):
|
||||
if websocket in self.active_connections:
|
||||
self.active_connections.remove(websocket)
|
||||
|
||||
async def send_message(self, websocket: WebSocket, message: Dict[str, Any]):
|
||||
try:
|
||||
await websocket.send_json(message)
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending message: {e}")
|
||||
await self.handle_error(websocket, str(e))
|
||||
|
||||
async def handle_error(self, websocket: WebSocket, error_message: str):
|
||||
try:
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "chunk",
|
||||
"chunk": {
|
||||
"tools": {
|
||||
"messages": [
|
||||
{
|
||||
"content": f"Error: {error_message}",
|
||||
"status": "error",
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending error message: {e}")
|
||||
|
||||
|
||||
manager = ConnectionManager()
|
||||
logger.error(f"Error in main: {str(e)}")
|
||||
traceback.print_exc(file=sys.__stderr__)
|
||||
finally:
|
||||
# Flush any remaining output
|
||||
queue_handler.flush()
|
||||
# Restore stdout and stderr
|
||||
sys.stdout = old_stdout
|
||||
sys.stderr = old_stderr
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in thread: {str(e)}")
|
||||
traceback.print_exc(file=sys.__stderr__)
|
||||
output_queue.put(f"Error: {str(e)}")
|
||||
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
async def get_root(request: Request):
|
||||
|
|
@ -105,135 +151,108 @@ async def get_root(request: Request):
|
|||
"index.html", {"request": request, "server_port": request.url.port or 8080}
|
||||
)
|
||||
|
||||
|
||||
# Mount static files for js and other assets
|
||||
app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
|
||||
|
||||
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
client_id = id(websocket)
|
||||
logger.info(f"New WebSocket connection attempt from client {client_id}")
|
||||
|
||||
if not await manager.connect(websocket):
|
||||
logger.error(f"Failed to accept WebSocket connection for client {client_id}")
|
||||
return
|
||||
|
||||
logger.info(f"WebSocket connection accepted for client {client_id}")
|
||||
await websocket.accept()
|
||||
logger.info("WebSocket connection established")
|
||||
active_connections.append(websocket)
|
||||
|
||||
try:
|
||||
# Send initial connection success message
|
||||
await manager.send_message(
|
||||
websocket,
|
||||
{
|
||||
"type": "chunk",
|
||||
"chunk": {
|
||||
"agent": {
|
||||
"messages": [
|
||||
{"content": "Connected to RA.Aid server", "status": "info"}
|
||||
]
|
||||
}
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
while True:
|
||||
try:
|
||||
message = await websocket.receive_json()
|
||||
logger.debug(f"Received message from client {client_id}: {message}")
|
||||
message = await websocket.receive_json()
|
||||
logger.info(f"Received message: {message}")
|
||||
|
||||
if message["type"] == "request":
|
||||
await manager.send_message(websocket, {"type": "stream_start"})
|
||||
if message["type"] == "request":
|
||||
content = message["content"]
|
||||
logger.info(f"Processing request: {content}")
|
||||
|
||||
try:
|
||||
# Run ra-aid with the request
|
||||
cmd = ["ra-aid", "-m", message["content"], "--cowboy-mode"]
|
||||
logger.info(f"Executing command: {' '.join(cmd)}")
|
||||
# Create queue for output
|
||||
output_queue = queue.Queue()
|
||||
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
logger.info(f"Process started with PID: {process.pid}")
|
||||
# Create and start thread
|
||||
thread = threading.Thread(target=run_ra_aid, args=(content, output_queue))
|
||||
thread.start()
|
||||
|
||||
async def read_stream(stream, is_error=False):
|
||||
while True:
|
||||
line = await stream.readline()
|
||||
if not line:
|
||||
break
|
||||
try:
|
||||
# Send stream start
|
||||
await websocket.send_json({"type": "stream_start"})
|
||||
|
||||
try:
|
||||
decoded_line = line.decode().strip()
|
||||
if decoded_line:
|
||||
await manager.send_message(
|
||||
websocket,
|
||||
{
|
||||
"type": "chunk",
|
||||
"chunk": {
|
||||
"tools" if is_error else "agent": {
|
||||
"messages": [
|
||||
{
|
||||
"content": decoded_line,
|
||||
"status": "error"
|
||||
if is_error
|
||||
else "info",
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing output: {e}")
|
||||
while thread.is_alive() or not output_queue.empty():
|
||||
try:
|
||||
# Get output with timeout to allow checking thread status
|
||||
line = output_queue.get(timeout=0.1)
|
||||
if line and line.strip(): # Only send non-empty messages
|
||||
logger.debug(f"WebSocket sending: {repr(line)}")
|
||||
await websocket.send_json({
|
||||
"type": "chunk",
|
||||
"chunk": {
|
||||
"agent": {
|
||||
"messages": [{
|
||||
"content": line.strip(),
|
||||
"status": "info"
|
||||
}]
|
||||
}
|
||||
}
|
||||
})
|
||||
except queue.Empty:
|
||||
await asyncio.sleep(0.1)
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket error: {e}")
|
||||
traceback.print_exc(file=sys.__stderr__)
|
||||
|
||||
# Create tasks for reading stdout and stderr
|
||||
stdout_task = asyncio.create_task(read_stream(process.stdout))
|
||||
stderr_task = asyncio.create_task(
|
||||
read_stream(process.stderr, True)
|
||||
)
|
||||
# Wait for thread to finish
|
||||
thread.join()
|
||||
logger.info("Thread finished")
|
||||
|
||||
# Wait for both streams to complete
|
||||
await asyncio.gather(stdout_task, stderr_task)
|
||||
# Send stream end
|
||||
await websocket.send_json({"type": "stream_end"})
|
||||
logger.info("Sent stream_end message")
|
||||
|
||||
# Wait for process to complete
|
||||
return_code = await process.wait()
|
||||
except Exception as e:
|
||||
error_msg = f"Error running ra-aid: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"message": error_msg
|
||||
})
|
||||
|
||||
if return_code != 0:
|
||||
await manager.handle_error(
|
||||
websocket, f"Process exited with code {return_code}"
|
||||
)
|
||||
|
||||
await manager.send_message(
|
||||
websocket,
|
||||
{"type": "stream_end", "request": message["content"]},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing ra-aid: {e}")
|
||||
await manager.handle_error(websocket, str(e))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing message: {e}")
|
||||
await manager.handle_error(websocket, str(e))
|
||||
logger.info("Waiting for message...")
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"WebSocket client {client_id} disconnected")
|
||||
logger.info("WebSocket client disconnected")
|
||||
active_connections.remove(websocket)
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket error for client {client_id}: {e}")
|
||||
logger.error(f"WebSocket error: {e}")
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
manager.disconnect(websocket)
|
||||
logger.info(f"WebSocket connection cleaned up for client {client_id}")
|
||||
if websocket in active_connections:
|
||||
active_connections.remove(websocket)
|
||||
logger.info("WebSocket connection closed")
|
||||
|
||||
|
||||
@app.get("/config")
|
||||
async def get_config(request: Request):
|
||||
"""Return server configuration including host and port."""
|
||||
return {"host": request.client.host, "port": request.scope.get("server")[1]}
|
||||
|
||||
|
||||
def run_server(host: str = "0.0.0.0", port: int = 8080):
|
||||
"""Run the FastAPI server."""
|
||||
logger.info(f"Starting server on {host}:{port}")
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=host,
|
||||
port=port,
|
||||
log_level="debug",
|
||||
ws_max_size=16777216, # 16MB
|
||||
timeout_keep_alive=0, # Disable keep-alive timeout
|
||||
uvicorn.run(app, host=host, port=port)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(description="RA.Aid Web Interface Server")
|
||||
parser.add_argument(
|
||||
"--port", type=int, default=8080, help="Port to listen on (default: 8080)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default="0.0.0.0",
|
||||
help="Host to listen on (default: 0.0.0.0)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
run_server(host=args.host, port=args.port)
|
||||
|
|
|
|||
|
|
@ -1,90 +1,214 @@
|
|||
<!DOCTYPE html>
|
||||
<html lang="en" class="h-full bg-gray-900">
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<meta name="server-port" content="{{ server_port }}">
|
||||
<title>RA.Aid Web Interface</title>
|
||||
<title>RA.Aid Web UI</title>
|
||||
<script src="https://cdn.tailwindcss.com"></script>
|
||||
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/styles/tokyo-night-dark.min.css">
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/highlight.min.js"></script>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/languages/python.min.js"></script>
|
||||
<script>
|
||||
tailwind.config = {
|
||||
darkMode: 'class',
|
||||
theme: {
|
||||
extend: {
|
||||
colors: {
|
||||
'dark-primary': '#1a1b26',
|
||||
'dark-secondary': '#24283b',
|
||||
'dark-accent': '#7aa2f7',
|
||||
'dark-text': '#c0caf5'
|
||||
'dark-accent': '#414868',
|
||||
'dark-text': '#a9b1d6',
|
||||
'dark-background': '#16161e'
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
</script>
|
||||
<style>
|
||||
/* Core styles */
|
||||
body {
|
||||
background-color: #16161e;
|
||||
color: #a9b1d6;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
height: 100vh;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
/* Output area */
|
||||
#stream-output {
|
||||
flex: 1;
|
||||
overflow-y: auto;
|
||||
padding: 1rem;
|
||||
background-color: #1a1b26;
|
||||
margin: 0;
|
||||
display: block !important;
|
||||
min-height: 300px;
|
||||
font-family: monospace;
|
||||
white-space: pre-wrap;
|
||||
word-break: break-word;
|
||||
}
|
||||
|
||||
/* Message styling */
|
||||
.message {
|
||||
margin-bottom: 1.5rem;
|
||||
padding: 0.75rem 1rem;
|
||||
border-radius: 0.375rem;
|
||||
background-color: #24283b;
|
||||
color: #a9b1d6;
|
||||
display: block !important;
|
||||
border: 1px solid #414868;
|
||||
opacity: 0;
|
||||
transform: translateY(20px);
|
||||
animation: slideIn 0.3s ease forwards;
|
||||
}
|
||||
|
||||
@keyframes slideIn {
|
||||
to {
|
||||
opacity: 1;
|
||||
transform: translateY(0);
|
||||
}
|
||||
}
|
||||
|
||||
/* Message types */
|
||||
.message.info {
|
||||
border-color: #7aa2f7;
|
||||
}
|
||||
|
||||
.message.success {
|
||||
border-color: #9ece6a;
|
||||
background-color: rgba(158, 206, 106, 0.1);
|
||||
}
|
||||
|
||||
.message.error {
|
||||
border-color: #f7768e;
|
||||
background-color: rgba(247, 118, 142, 0.1);
|
||||
}
|
||||
|
||||
.message.warning {
|
||||
border-color: #e0af68;
|
||||
background-color: rgba(224, 175, 104, 0.1);
|
||||
}
|
||||
|
||||
/* Section spacing */
|
||||
.message.section {
|
||||
margin-top: 2rem;
|
||||
margin-bottom: 2rem;
|
||||
padding: 1rem;
|
||||
background-color: rgba(122, 162, 247, 0.1);
|
||||
}
|
||||
|
||||
/* Code blocks */
|
||||
.message pre {
|
||||
margin: 0.5rem 0;
|
||||
padding: 0.75rem;
|
||||
border-radius: 0.25rem;
|
||||
background-color: #1a1b26 !important;
|
||||
border: 1px solid #414868;
|
||||
overflow-x: auto;
|
||||
}
|
||||
|
||||
.message code {
|
||||
font-family: 'JetBrains Mono', monospace;
|
||||
font-size: 0.9em;
|
||||
}
|
||||
|
||||
/* Input area */
|
||||
.input-area {
|
||||
padding: 1rem;
|
||||
background-color: #24283b;
|
||||
border-top: 1px solid #414868;
|
||||
}
|
||||
|
||||
.input-container {
|
||||
display: flex;
|
||||
gap: 1rem;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
#user-input {
|
||||
flex: 1;
|
||||
padding: 0.75rem;
|
||||
background-color: #1a1b26;
|
||||
border: 1px solid #414868;
|
||||
border-radius: 0.375rem;
|
||||
color: #a9b1d6;
|
||||
font-family: 'JetBrains Mono', monospace;
|
||||
transition: border-color 0.2s;
|
||||
}
|
||||
|
||||
#user-input:focus {
|
||||
outline: none;
|
||||
border-color: #7aa2f7;
|
||||
}
|
||||
|
||||
/* Buttons */
|
||||
.button {
|
||||
padding: 0.75rem 1rem;
|
||||
background-color: #7aa2f7;
|
||||
color: white;
|
||||
border: none;
|
||||
border-radius: 0.375rem;
|
||||
cursor: pointer;
|
||||
min-width: 80px;
|
||||
transition: background-color 0.2s;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
gap: 0.5rem;
|
||||
}
|
||||
|
||||
.button:hover {
|
||||
background-color: #5d87e6;
|
||||
}
|
||||
|
||||
.button:disabled {
|
||||
opacity: 0.5;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.button.clear {
|
||||
background-color: #414868;
|
||||
}
|
||||
|
||||
.button.clear:hover {
|
||||
background-color: #363b54;
|
||||
}
|
||||
|
||||
/* Icons */
|
||||
.icon {
|
||||
font-size: 1.2em;
|
||||
line-height: 1;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body class="h-full bg-dark-primary text-dark-text">
|
||||
<div class="flex h-full">
|
||||
<!-- Sidebar -->
|
||||
<div class="w-64 bg-dark-secondary border-r border-gray-700 flex flex-col">
|
||||
<div class="p-4 border-b border-gray-700">
|
||||
<h2 class="text-xl font-semibold text-dark-accent">History</h2>
|
||||
</div>
|
||||
<div id="history-list" class="flex-1 overflow-y-auto p-4 space-y-2"></div>
|
||||
</div>
|
||||
<body>
|
||||
<!-- Main Container -->
|
||||
<div id="stream-output">
|
||||
<!-- Messages will appear here -->
|
||||
</div>
|
||||
|
||||
<!-- Main Content -->
|
||||
<div class="flex-1 flex flex-col min-w-0">
|
||||
<!-- Chat Container -->
|
||||
<div class="flex-1 overflow-y-auto p-4 space-y-4" id="chat-container">
|
||||
<div id="chat-messages"></div>
|
||||
<div id="stream-output" class="hidden font-mono bg-dark-secondary rounded-lg p-4 text-sm"></div>
|
||||
</div>
|
||||
|
||||
<!-- Input Area -->
|
||||
<div class="border-t border-gray-700 p-4 bg-dark-secondary">
|
||||
<div class="flex space-x-4">
|
||||
<textarea
|
||||
id="user-input"
|
||||
class="flex-1 bg-dark-primary border border-gray-700 rounded-lg p-3 text-dark-text placeholder-gray-500 focus:outline-none focus:ring-2 focus:ring-dark-accent resize-none"
|
||||
placeholder="Type your request here..."
|
||||
rows="3"
|
||||
></textarea>
|
||||
<button
|
||||
id="send-button"
|
||||
class="px-6 py-2 bg-dark-accent text-white rounded-lg hover:bg-opacity-90 focus:outline-none focus:ring-2 focus:ring-offset-2 focus:ring-dark-accent disabled:opacity-50 disabled:cursor-not-allowed h-fit"
|
||||
>
|
||||
Send
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
<!-- Input Area -->
|
||||
<div class="input-area">
|
||||
<div class="input-container">
|
||||
<input
|
||||
type="text"
|
||||
id="user-input"
|
||||
placeholder="Type your message..."
|
||||
>
|
||||
<button id="clear-button" class="button clear" title="Clear conversation">
|
||||
<span class="icon">🗑️</span>
|
||||
</button>
|
||||
<button id="send-button" class="button">
|
||||
<span class="icon">📤</span>
|
||||
Send
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
// Add dynamic styles for messages
|
||||
const style = document.createElement('style');
|
||||
style.textContent = `
|
||||
.message {
|
||||
@apply mb-4 p-4 rounded-lg max-w-3xl;
|
||||
}
|
||||
.user-message {
|
||||
@apply bg-dark-accent text-white ml-auto;
|
||||
}
|
||||
.system-message {
|
||||
@apply bg-dark-secondary mr-auto;
|
||||
}
|
||||
.error-message {
|
||||
@apply bg-red-900 text-red-100 mr-auto;
|
||||
}
|
||||
.history-item {
|
||||
@apply p-3 rounded-lg hover:bg-dark-primary cursor-pointer transition-colors duration-200 text-sm;
|
||||
}
|
||||
#stream-output:not(:empty) {
|
||||
@apply block;
|
||||
}
|
||||
`;
|
||||
document.head.appendChild(style);
|
||||
</script>
|
||||
<script src="/static/script.js"></script>
|
||||
<script>
|
||||
hljs.highlightAll();
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
|
|
@ -1,206 +1,230 @@
|
|||
class RAWebUI {
|
||||
class WebSocketHandler {
|
||||
constructor() {
|
||||
this.messageHistory = [];
|
||||
this.connectionAttempts = 0;
|
||||
this.maxReconnectAttempts = 5;
|
||||
this.setupElements();
|
||||
this.setupEventListeners();
|
||||
this.connectWebSocket();
|
||||
// Wait for DOM to be ready
|
||||
if (document.readyState === 'loading') {
|
||||
document.addEventListener('DOMContentLoaded', () => this.initialize());
|
||||
} else {
|
||||
this.initialize();
|
||||
}
|
||||
}
|
||||
|
||||
setupElements() {
|
||||
this.userInput = document.getElementById('user-input');
|
||||
initialize() {
|
||||
// Store DOM elements as instance variables
|
||||
this.messageInput = document.getElementById('user-input');
|
||||
this.sendButton = document.getElementById('send-button');
|
||||
this.chatMessages = document.getElementById('chat-messages');
|
||||
this.clearButton = document.getElementById('clear-button');
|
||||
this.streamOutput = document.getElementById('stream-output');
|
||||
this.historyList = document.getElementById('history-list');
|
||||
|
||||
// Disable send button initially
|
||||
this.sendButton.disabled = true;
|
||||
}
|
||||
|
||||
setupEventListeners() {
|
||||
// Validate required elements exist
|
||||
if (!this.messageInput || !this.sendButton || !this.streamOutput) {
|
||||
console.error('Required elements not found:', {
|
||||
messageInput: !!this.messageInput,
|
||||
sendButton: !!this.sendButton,
|
||||
streamOutput: !!this.streamOutput
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// Remove hidden class if present
|
||||
this.streamOutput.classList.remove('hidden');
|
||||
|
||||
// Initialize WebSocket
|
||||
this.connectWebSocket();
|
||||
|
||||
// Add event listeners
|
||||
this.sendButton.addEventListener('click', () => this.sendMessage());
|
||||
this.userInput.addEventListener('keypress', (e) => {
|
||||
this.clearButton?.addEventListener('click', () => this.clearConversation());
|
||||
this.messageInput.addEventListener('keypress', (e) => {
|
||||
if (e.key === 'Enter' && !e.shiftKey) {
|
||||
e.preventDefault();
|
||||
this.sendMessage();
|
||||
}
|
||||
});
|
||||
|
||||
console.log('WebSocketHandler initialized with elements:', {
|
||||
messageInput: this.messageInput,
|
||||
sendButton: this.sendButton,
|
||||
streamOutput: this.streamOutput
|
||||
});
|
||||
}
|
||||
|
||||
async connectWebSocket() {
|
||||
// Don't try to reconnect if we've exceeded the maximum attempts
|
||||
if (this.connectionAttempts >= this.maxReconnectAttempts) {
|
||||
this.appendMessage(
|
||||
'Maximum reconnection attempts reached. Please refresh the page.',
|
||||
'error'
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
connectWebSocket() {
|
||||
try {
|
||||
// Get the server port from the meta tag
|
||||
const serverPort = document.querySelector('meta[name="server-port"]')?.content || '8080';
|
||||
|
||||
// Construct WebSocket URL
|
||||
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
|
||||
const wsUrl = `${protocol}//${window.location.hostname}:${serverPort}/ws`;
|
||||
const wsUrl = `ws://${window.location.host}/ws`;
|
||||
console.log('Attempting to connect to WebSocket URL:', wsUrl);
|
||||
|
||||
// Close existing connection if any
|
||||
if (this.ws) {
|
||||
this.ws.close();
|
||||
}
|
||||
|
||||
// Create new WebSocket connection
|
||||
console.log('Creating new WebSocket connection...');
|
||||
this.ws = new WebSocket(wsUrl);
|
||||
this.connectionAttempts++;
|
||||
|
||||
// Setup WebSocket event handlers
|
||||
|
||||
this.ws.onopen = () => {
|
||||
console.log('WebSocket connection established successfully');
|
||||
this.connectionAttempts = 0; // Reset counter on successful connection
|
||||
this.sendButton.disabled = false;
|
||||
};
|
||||
|
||||
this.ws.onclose = (event) => {
|
||||
console.log('WebSocket closed:', event.code, event.reason);
|
||||
this.ws.onclose = () => {
|
||||
console.log('WebSocket connection closed');
|
||||
this.sendButton.disabled = true;
|
||||
// Try to reconnect after a delay
|
||||
setTimeout(() => this.connectWebSocket(), 2000);
|
||||
};
|
||||
|
||||
// Only attempt reconnect if not a normal closure and within retry limits
|
||||
if (event.code !== 1000 && this.connectionAttempts < this.maxReconnectAttempts) {
|
||||
const delay = Math.min(1000 * Math.pow(2, this.connectionAttempts), 10000);
|
||||
this.appendMessage(
|
||||
`Connection lost. Reconnecting in ${delay/1000} seconds...`,
|
||||
'error'
|
||||
);
|
||||
setTimeout(() => this.connectWebSocket(), delay);
|
||||
this.ws.onmessage = (event) => {
|
||||
try {
|
||||
console.log('Raw WebSocket message:', event.data);
|
||||
const message = JSON.parse(event.data);
|
||||
console.log('Parsed WebSocket message:', message);
|
||||
this.handleMessage(message);
|
||||
} catch (error) {
|
||||
console.error('Error handling message:', error);
|
||||
this.appendOutput({
|
||||
content: `Error: ${error.message}`,
|
||||
status: 'error'
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
this.ws.onerror = (error) => {
|
||||
console.error('WebSocket error:', error);
|
||||
this.sendButton.disabled = true;
|
||||
this.appendOutput({
|
||||
content: 'Connection error. Retrying...',
|
||||
status: 'error'
|
||||
});
|
||||
};
|
||||
|
||||
this.ws.onmessage = (event) => {
|
||||
try {
|
||||
const data = JSON.parse(event.data);
|
||||
this.handleServerMessage(data);
|
||||
} catch (error) {
|
||||
console.error('Error parsing message:', error);
|
||||
this.appendMessage('Error processing server message', 'error');
|
||||
}
|
||||
};
|
||||
|
||||
} catch (error) {
|
||||
console.error('Failed to connect to WebSocket:', error);
|
||||
this.appendMessage(
|
||||
`Connection error: ${error.message}. Retrying...`,
|
||||
'error'
|
||||
);
|
||||
|
||||
// Attempt to reconnect with exponential backoff
|
||||
const delay = Math.min(1000 * Math.pow(2, this.connectionAttempts), 10000);
|
||||
setTimeout(() => this.connectWebSocket(), delay);
|
||||
console.error('Error connecting to WebSocket:', error);
|
||||
this.appendOutput({
|
||||
content: `Connection error: ${error.message}`,
|
||||
status: 'error'
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
handleServerMessage(data) {
|
||||
if (data.type === 'stream_start') {
|
||||
this.streamOutput.textContent = '';
|
||||
this.streamOutput.style.display = 'block';
|
||||
} else if (data.type === 'stream_end') {
|
||||
this.streamOutput.style.display = 'none';
|
||||
this.addToHistory(data.request);
|
||||
this.sendButton.disabled = false;
|
||||
} else if (data.type === 'chunk') {
|
||||
this.handleChunk(data.chunk);
|
||||
handleMessage(message) {
|
||||
switch (message.type) {
|
||||
case 'stream_start':
|
||||
this.handleStreamStart();
|
||||
break;
|
||||
case 'chunk':
|
||||
this.handleChunk(message.chunk);
|
||||
break;
|
||||
case 'stream_end':
|
||||
this.handleStreamEnd();
|
||||
break;
|
||||
default:
|
||||
console.warn('Unknown message type:', message.type);
|
||||
}
|
||||
}
|
||||
|
||||
handleStreamStart() {
|
||||
console.log('Stream starting');
|
||||
this.clearStreamOutput();
|
||||
this.appendOutput({
|
||||
content: 'Starting new conversation...',
|
||||
status: 'info'
|
||||
});
|
||||
}
|
||||
|
||||
handleStreamEnd() {
|
||||
console.log('Stream ending');
|
||||
this.appendOutput({
|
||||
content: 'Conversation complete.',
|
||||
status: 'success'
|
||||
});
|
||||
this.messageInput.disabled = false;
|
||||
this.sendButton.disabled = false;
|
||||
}
|
||||
|
||||
handleChunk(chunk) {
|
||||
console.log(' Processing chunk:', chunk);
|
||||
if (chunk.agent && chunk.agent.messages) {
|
||||
chunk.agent.messages.forEach(msg => {
|
||||
if (msg.content) {
|
||||
if (Array.isArray(msg.content)) {
|
||||
msg.content.forEach(content => {
|
||||
if (content.type === 'text' && content.text.trim()) {
|
||||
this.appendMessage(content.text.trim(), 'system');
|
||||
}
|
||||
});
|
||||
} else if (msg.content.trim()) {
|
||||
this.appendMessage(msg.content.trim(), 'system');
|
||||
}
|
||||
}
|
||||
});
|
||||
} else if (chunk.tools && chunk.tools.messages) {
|
||||
chunk.tools.messages.forEach(msg => {
|
||||
if (msg.status === 'error' && msg.content) {
|
||||
this.appendMessage(msg.content.trim(), 'error');
|
||||
}
|
||||
chunk.agent.messages.forEach(message => {
|
||||
console.log(' Processing agent message:', message);
|
||||
console.log(' Adding agent message:', message.content);
|
||||
this.appendOutput(message);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
appendMessage(content, type) {
|
||||
const messageDiv = document.createElement('div');
|
||||
messageDiv.className = `message ${type}-message`;
|
||||
messageDiv.textContent = content;
|
||||
this.chatMessages.appendChild(messageDiv);
|
||||
this.chatMessages.scrollTop = this.chatMessages.scrollHeight;
|
||||
clearStreamOutput() {
|
||||
console.log('Clearing stream output');
|
||||
while (this.streamOutput.firstChild) {
|
||||
this.streamOutput.removeChild(this.streamOutput.firstChild);
|
||||
}
|
||||
console.log('Stream output cleared');
|
||||
}
|
||||
|
||||
addToHistory(request) {
|
||||
const historyItem = document.createElement('div');
|
||||
historyItem.className = 'history-item';
|
||||
historyItem.textContent = request.slice(0, 50) + (request.length > 50 ? '...' : '');
|
||||
historyItem.title = request;
|
||||
historyItem.addEventListener('click', () => {
|
||||
this.userInput.value = request;
|
||||
this.userInput.focus();
|
||||
clearConversation() {
|
||||
this.clearStreamOutput();
|
||||
this.appendOutput({
|
||||
content: 'Conversation cleared.',
|
||||
status: 'info'
|
||||
});
|
||||
}
|
||||
|
||||
appendOutput(message) {
|
||||
console.log(' Appending output:', message);
|
||||
const messageDiv = document.createElement('div');
|
||||
messageDiv.className = `message ${message.status || 'info'}`;
|
||||
|
||||
const contentSpan = document.createElement('span');
|
||||
|
||||
// Convert ANSI escape codes to HTML
|
||||
let content = message.content;
|
||||
content = this.convertAnsiToHtml(content);
|
||||
|
||||
// Check for code blocks and apply syntax highlighting
|
||||
if (content.includes('```')) {
|
||||
content = this.highlightCodeBlocks(content);
|
||||
}
|
||||
|
||||
contentSpan.innerHTML = content;
|
||||
messageDiv.appendChild(contentSpan);
|
||||
|
||||
this.streamOutput.appendChild(messageDiv);
|
||||
messageDiv.scrollIntoView({ behavior: 'smooth', block: 'end' });
|
||||
}
|
||||
|
||||
convertAnsiToHtml(text) {
|
||||
// ANSI color codes to CSS classes
|
||||
const ansiToClass = {
|
||||
'[94m': '<span class="text-blue-400">',
|
||||
'[1;32m': '<span class="text-green-400 font-bold">',
|
||||
'[0m': '</span>'
|
||||
};
|
||||
|
||||
// Replace ANSI codes with HTML
|
||||
let html = text;
|
||||
for (const [ansi, htmlClass] of Object.entries(ansiToClass)) {
|
||||
html = html.replaceAll('\u001b' + ansi, htmlClass);
|
||||
}
|
||||
return html;
|
||||
}
|
||||
|
||||
highlightCodeBlocks(content) {
|
||||
const codeBlockRegex = /```(\w+)?\n([\s\S]*?)```/g;
|
||||
return content.replace(codeBlockRegex, (match, lang, code) => {
|
||||
const language = lang || 'plaintext';
|
||||
const highlighted = hljs.highlight(code.trim(), { language }).value;
|
||||
return `<pre><code class="language-${language}">${highlighted}</code></pre>`;
|
||||
});
|
||||
this.historyList.insertBefore(historyItem, this.historyList.firstChild);
|
||||
this.messageHistory.push(request);
|
||||
}
|
||||
|
||||
sendMessage() {
|
||||
console.log('Send button clicked');
|
||||
const message = this.userInput.value.trim();
|
||||
console.log('Message content:', message);
|
||||
|
||||
if (!message) {
|
||||
console.log('Message is empty, not sending');
|
||||
return;
|
||||
}
|
||||
const message = this.messageInput.value.trim();
|
||||
if (!message) return;
|
||||
|
||||
if (!this.ws || this.ws.readyState !== WebSocket.OPEN) {
|
||||
console.error('WebSocket is not connected');
|
||||
this.appendMessage('Error: Not connected to server. Please wait...', 'error');
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
console.log('Sending message to server');
|
||||
this.appendMessage(message, 'user');
|
||||
const payload = { type: 'request', content: message };
|
||||
console.log('Payload:', payload);
|
||||
|
||||
this.ws.send(JSON.stringify(payload));
|
||||
console.log('Message sent successfully');
|
||||
|
||||
this.userInput.value = '';
|
||||
this.sendButton.disabled = true;
|
||||
} catch (error) {
|
||||
console.error('Error sending message:', error);
|
||||
this.appendMessage(`Error sending message: ${error.message}`, 'error');
|
||||
this.sendButton.disabled = false;
|
||||
}
|
||||
console.log('Sending message:', message);
|
||||
this.ws.send(JSON.stringify({
|
||||
type: "request",
|
||||
content: message
|
||||
}));
|
||||
this.messageInput.value = '';
|
||||
this.messageInput.disabled = true;
|
||||
this.sendButton.disabled = true;
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize the UI when the page loads
|
||||
document.addEventListener('DOMContentLoaded', () => {
|
||||
window.raWebUI = new RAWebUI();
|
||||
// Initialize WebSocket handler when the page loads
|
||||
window.addEventListener('load', () => {
|
||||
new WebSocketHandler();
|
||||
});
|
||||
|
|
@ -0,0 +1,66 @@
|
|||
"""Tests for Windows-specific functionality."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from ra_aid.proc.interactive import get_terminal_size, create_process, run_interactive_command
|
||||
|
||||
@pytest.mark.skipif(sys.platform != "win32", reason="Windows-specific tests")
|
||||
class TestWindowsCompatibility:
|
||||
"""Test suite for Windows-specific functionality."""
|
||||
|
||||
def test_get_terminal_size(self):
|
||||
"""Test terminal size detection on Windows."""
|
||||
with patch('shutil.get_terminal_size') as mock_get_size:
|
||||
mock_get_size.return_value = MagicMock(columns=120, lines=30)
|
||||
cols, rows = get_terminal_size()
|
||||
assert cols == 120
|
||||
assert rows == 30
|
||||
mock_get_size.assert_called_once()
|
||||
|
||||
def test_create_process(self):
|
||||
"""Test process creation on Windows."""
|
||||
with patch('subprocess.Popen') as mock_popen:
|
||||
mock_process = MagicMock()
|
||||
mock_process.returncode = 0
|
||||
mock_popen.return_value = mock_process
|
||||
|
||||
proc, _ = create_process(['echo', 'test'])
|
||||
|
||||
assert mock_popen.called
|
||||
args, kwargs = mock_popen.call_args
|
||||
assert kwargs['stdin'] == subprocess.PIPE
|
||||
assert kwargs['stdout'] == subprocess.PIPE
|
||||
assert kwargs['stderr'] == subprocess.PIPE
|
||||
assert 'startupinfo' in kwargs
|
||||
assert kwargs['startupinfo'].dwFlags & subprocess.STARTF_USESHOWWINDOW
|
||||
|
||||
def test_run_interactive_command(self):
|
||||
"""Test running an interactive command on Windows."""
|
||||
test_output = "Test output\n"
|
||||
|
||||
with patch('subprocess.Popen') as mock_popen:
|
||||
mock_process = MagicMock()
|
||||
mock_process.stdout = MagicMock()
|
||||
mock_process.stdout.read.return_value = test_output.encode()
|
||||
mock_process.wait.return_value = 0
|
||||
mock_popen.return_value = mock_process
|
||||
|
||||
output, return_code = run_interactive_command(['echo', 'test'])
|
||||
assert return_code == 0
|
||||
assert "Test output" in output.decode()
|
||||
|
||||
def test_windows_dependencies(self):
|
||||
"""Test that required Windows dependencies are available."""
|
||||
if sys.platform == "win32":
|
||||
import msvcrt
|
||||
import win32pipe
|
||||
import win32file
|
||||
import win32con
|
||||
import win32process
|
||||
|
||||
# If we get here without ImportError, the test passes
|
||||
assert True
|
||||
Loading…
Reference in New Issue