# RA.Aid/webui/server.py
#
# 211 lines
# 7.9 KiB
# Python
# Executable File

#!/usr/bin/env python3
import argparse
import asyncio
import shutil
import sys
from pathlib import Path
from typing import List
import uvicorn
from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
# Fail fast at import time if the ra-aid CLI is missing: the /ws endpoint
# runs it for every request. The message goes to stderr because it is an error.
if not shutil.which("ra-aid"):
    print(
        "Error: ra-aid command not found. Please ensure it's installed and in your PATH",
        file=sys.stderr,
    )
    sys.exit(1)
app = FastAPI()

# Accept cross-origin requests from any origin, with any method and header,
# and with credentials allowed.
# NOTE(review): narrow allow_origins to the UI's real origin(s) before exposing
# this server beyond a trusted machine.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Jinja2 templates (index.html) are looked up in the directory holding this file.
templates = Jinja2Templates(directory=Path(__file__).parents[0])
@app.get("/", response_class=HTMLResponse)
async def get_root(request: Request):
    """Serve index.html, passing the page the port this server was reached on.

    The template context receives ``server_port``. This is the explicit port
    from the request URL. When the URL has none, it is the default port of
    the request's scheme (443 for https, 80 otherwise), i.e. the port the
    browser actually connected to.
    """
    port = request.url.port
    if port is None:
        # No explicit port in the URL means the browser used the scheme's
        # default port. A fixed 8080 fallback would send the page to the wrong
        # port, e.g. when the server itself runs with --port 80.
        port = 443 if request.url.scheme == "https" else 80
    return templates.TemplateResponse(
        "index.html", {"request": request, "server_port": port}
    )
# Serve JS and other front-end assets that live next to this file under /static.
app.mount(
    "/static",
    StaticFiles(directory=Path(__file__).parents[0]),
    name="static",
)

# Currently open WebSocket connections: /ws appends each accepted socket and
# removes it again when the connection ends.
active_connections: List[WebSocket] = list()
def _chunk_message(section: str, content: str, status: str) -> dict:
    """Build a ``chunk`` frame carrying one message for the web client.

    Args:
        section: Key under ``chunk`` that groups the message. This server uses
            ``"agent"`` for ra-aid stdout and ``"tools"`` for stderr/errors.
        content: Text to display.
        status: Message status. This server sends ``"info"`` or ``"error"``.

    Returns:
        A JSON-serialisable dict in the shape sent for every output line.
    """
    return {
        "type": "chunk",
        "chunk": {section: {"messages": [{"content": content, "status": status}]}},
    }


async def _forward_stream(
    websocket: WebSocket, stream: asyncio.StreamReader, is_error: bool = False
) -> None:
    """Relay each non-blank line of a subprocess pipe to the client until EOF.

    stdout lines are sent under ``"agent"`` with status ``"info"``. stderr
    lines are sent under ``"tools"`` with status ``"error"``. A failed send
    is logged and reading continues, so the pipe keeps being drained.
    """
    stream_type = "stderr" if is_error else "stdout"
    section = "tools" if is_error else "agent"
    status = "error" if is_error else "info"
    print(f"Starting to read from {stream_type}")
    while True:
        line = await stream.readline()
        if not line:
            print(f"End of {stream_type} stream")
            break
        # errors="replace": a strict decode would raise on non-UTF-8 output
        # and the whole line would be lost.
        decoded_line = line.decode(errors="replace").strip()
        print(f"{stream_type} line: {decoded_line}")
        if not decoded_line:
            continue
        try:
            await websocket.send_json(_chunk_message(section, decoded_line, status))
        except Exception as e:
            # Best effort: keep draining so the subprocess can run to completion.
            print(f"Error sending output: {e}")


async def _run_ra_aid(websocket: WebSocket, content: str) -> None:
    """Run ``ra-aid -m <content> --cowboy-mode`` and stream its output.

    When the process exits non-zero, an error chunk with the exit code is sent.

    Raises:
        Whatever launching the process, reading its pipes, or sending a frame
        raises. If that happens (or the task is cancelled) while the process
        is still running, it is killed and reaped before the error propagates.
    """
    # Argument list rather than a shell string: the request text is passed to
    # ra-aid verbatim and never parsed by a shell.
    cmd = ["ra-aid", "-m", content, "--cowboy-mode"]
    print(f"Executing command: {' '.join(cmd)}")
    process = await asyncio.create_subprocess_exec(
        *cmd,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    print(f"Process started with PID: {process.pid}")
    readers = [
        asyncio.create_task(_forward_stream(websocket, process.stdout)),
        asyncio.create_task(_forward_stream(websocket, process.stderr, True)),
    ]
    try:
        # Read both pipes concurrently so neither can fill up and stall the child.
        await asyncio.gather(*readers)
        return_code = await process.wait()
    finally:
        # On failure/cancellation, don't leave a reader task or the process behind.
        for task in readers:
            task.cancel()
        if process.returncode is None:
            print(f"Killing ra-aid process {process.pid}")
            try:
                process.kill()
            except ProcessLookupError:
                pass  # exited between the returncode check and kill()
            await process.wait()
    if return_code != 0:
        await websocket.send_json(
            _chunk_message("tools", f"Process exited with code {return_code}", "error")
        )


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Accept a WebSocket client and run ra-aid for each request frame it sends.

    Client -> server frames:
      ``{"type": "request", "content": <str>}``; frames of any other type are
      ignored.

    Server -> client frames per request:
      ``{"type": "stream_start"}``, then ``chunk`` frames with the output,
      then ``{"type": "stream_end", "request": <content>}``. The end frame is
      sent even when running ra-aid failed.

    The socket is listed in ``active_connections`` while the handler runs.

    NOTE(review): nothing here checks the Origin header or authenticates the
    client, and ra-aid runs with --cowboy-mode. Consider validating Origin
    before accepting when the server is reachable from untrusted pages or
    networks.
    """
    print(f"New WebSocket connection from {websocket.client}")
    await websocket.accept()
    print("WebSocket connection accepted")
    active_connections.append(websocket)
    try:
        while True:
            print("Waiting for message...")
            message = await websocket.receive_json()
            print(f"Received message: {message}")
            # Skip anything that is not a request frame, without letting a
            # malformed frame (non-object, missing "type") kill the connection.
            if not isinstance(message, dict) or message.get("type") != "request":
                continue
            content = message.get("content")
            if not isinstance(content, str):
                await websocket.send_json(
                    _chunk_message(
                        "tools", "Invalid request: 'content' must be a string", "error"
                    )
                )
                continue
            print(f"Processing request: {content}")
            # Notify client that streaming is starting
            await websocket.send_json({"type": "stream_start"})
            try:
                await _run_ra_aid(websocket, content)
            except Exception as e:
                error_msg = f"Error executing ra-aid: {str(e)}"
                print(error_msg)
                await websocket.send_json(_chunk_message("tools", error_msg, "error"))
            # Always pair stream_start with stream_end, even after an error,
            # so the client is not left waiting for output that never comes.
            await websocket.send_json({"type": "stream_end", "request": content})
    except WebSocketDisconnect:
        print("WebSocket client disconnected")
    except Exception as e:
        print(f"WebSocket error: {e}")
        try:
            await websocket.send_json({"type": "error", "error": str(e)})
        except Exception:
            pass  # the connection is already unusable; nothing left to report to
    finally:
        if websocket in active_connections:
            active_connections.remove(websocket)
        print("WebSocket connection cleaned up")
@app.get("/config")
async def get_config(request: Request):
    """Return the host and port of the server that accepted this request.

    The values come from the ASGI ``server`` scope entry, a ``(host, port)``
    pair. If the server does not report it (the entry may be missing or
    ``None``), the host and port from the request URL are used instead.
    The client's own address is not returned.
    """
    server = request.scope.get("server")
    if server:
        host, port = server[0], server[1]
    else:
        host, port = request.url.hostname, request.url.port
    return {"host": host, "port": port}
def run_server(host: str = "0.0.0.0", port: int = 8080):
    """Serve this module's FastAPI ``app`` with uvicorn.

    Args:
        host: Address to bind the listening socket to.
        port: TCP port to listen on.
    """
    server_options = {"host": host, "port": port}
    uvicorn.run(app, **server_options)
if __name__ == "__main__":
    # Command-line entry point: parse --port/--host and start the server.
    cli = argparse.ArgumentParser(description="RA.Aid Web Interface Server")
    cli.add_argument(
        "--port",
        type=int,
        default=8080,
        help="Port to listen on (default: 8080)",
    )
    cli.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="Host to listen on (default: 0.0.0.0)",
    )
    options = cli.parse_args()
    run_server(port=options.port, host=options.host)