Compare commits
1 Commits
| Author | SHA1 | Date |
|---|---|---|
|
|
9ef8d1157c |
|
|
@ -16,6 +16,3 @@ appmap.log
|
|||
*.swp
|
||||
/vsc/node_modules
|
||||
/vsc/dist
|
||||
node_modules/
|
||||
/frontend/common/dist
|
||||
/frontend/web/dist/
|
||||
|
|
|
|||
|
|
@ -5,13 +5,6 @@ All notable changes to this project will be documented in this file.
|
|||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [0.17.1] 2025-03-13
|
||||
|
||||
### Fixed
|
||||
- Fixed bug with `process_thinking_content` function by moving it from `agent_utils` to `ra_aid.text.processing` module
|
||||
- Fixed config parameter handling in research request functions
|
||||
- Updated development setup instructions in README to use `pip install -e ".[dev]"` instead of `pip install -r requirements-dev.txt`
|
||||
|
||||
## [0.17.0] 2025-03-12
|
||||
|
||||
### Added
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
include LICENSE
|
||||
include README.md
|
||||
include CHANGELOG.md
|
||||
recursive-include ra_aid/server/static *
|
||||
recursive-include ra_aid/webui/static *
|
||||
|
|
|
|||
30
README.md
30
README.md
|
|
@ -226,9 +226,9 @@ More information is available in our [Usage Examples](https://docs.ra-aid.ai/cat
|
|||
- `--max-test-cmd-retries`: Maximum number of test command retry attempts (default: 3)
|
||||
- `--test-cmd-timeout`: Timeout in seconds for test command execution (default: 300)
|
||||
- `--version`: Show program version number and exit
|
||||
- `--server`: Launch the server with web interface (alpha feature)
|
||||
- `--server-host`: Host to listen on for server (default: 0.0.0.0) (alpha feature)
|
||||
- `--server-port`: Port to listen on for server (default: 1818) (alpha feature)
|
||||
- `--webui`: Launch the web interface (alpha feature)
|
||||
- `--webui-host`: Host to listen on for web interface (default: 0.0.0.0) (alpha feature)
|
||||
- `--webui-port`: Port to listen on for web interface (default: 8080) (alpha feature)
|
||||
|
||||
### Example Tasks
|
||||
|
||||
|
|
@ -305,30 +305,30 @@ Make sure to set your TAVILY_API_KEY environment variable to enable this feature
|
|||
|
||||
Enable with `--chat` to transform ra-aid into an interactive assistant that guides you through research and implementation tasks. Have a natural conversation about what you want to build, explore options together, and dispatch work - all while maintaining context of your discussion. Perfect for when you want to think through problems collaboratively rather than just executing commands.
|
||||
|
||||
### Server with Web Interface
|
||||
### Web Interface
|
||||
|
||||
RA.Aid includes a modern server with web interface that provides:
|
||||
RA.Aid includes a modern web interface that provides:
|
||||
- Beautiful dark-themed chat interface
|
||||
- Real-time streaming of command output
|
||||
- Request history with quick resubmission
|
||||
- Responsive design that works on all devices
|
||||
|
||||
To launch the server with web interface:
|
||||
To launch the web interface:
|
||||
|
||||
```bash
|
||||
# Start with default settings (0.0.0.0:1818)
|
||||
ra-aid --server
|
||||
# Start with default settings (0.0.0.0:8080)
|
||||
ra-aid --webui
|
||||
|
||||
# Specify custom host and port
|
||||
ra-aid --server --server-host 127.0.0.1 --server-port 3000
|
||||
ra-aid --webui --webui-host 127.0.0.1 --webui-port 3000
|
||||
```
|
||||
|
||||
Command line options for server with web interface:
|
||||
- `--server`: Launch the server with web interface
|
||||
- `--server-host`: Host to listen on (default: 0.0.0.0)
|
||||
- `--server-port`: Port to listen on (default: 1818)
|
||||
Command line options for web interface:
|
||||
- `--webui`: Launch the web interface
|
||||
- `--webui-host`: Host to listen on (default: 0.0.0.0)
|
||||
- `--webui-port`: Port to listen on (default: 8080)
|
||||
|
||||
After starting the server, open your web browser to the displayed URL (e.g., http://localhost:1818). The interface provides:
|
||||
After starting the server, open your web browser to the displayed URL (e.g., http://localhost:8080). The interface provides:
|
||||
- Left sidebar showing request history
|
||||
- Main chat area with real-time output
|
||||
- Input box for typing requests
|
||||
|
|
@ -541,7 +541,7 @@ source venv/bin/activate # On Windows use `venv\Scripts\activate`
|
|||
|
||||
3. Install development dependencies:
|
||||
```bash
|
||||
pip install -e ".[dev]"
|
||||
pip install -r requirements-dev.txt
|
||||
```
|
||||
|
||||
4. Run tests:
|
||||
|
|
|
|||
|
|
@ -1,16 +0,0 @@
|
|||
{
|
||||
"$schema": "https://ui.shadcn.com/schema.json",
|
||||
"style": "new-york",
|
||||
"rsc": false,
|
||||
"tsx": true,
|
||||
"tailwind": {
|
||||
"config": "frontend/common/tailwind.config.js",
|
||||
"css": "frontend/common/src/styles/global.css",
|
||||
"baseColor": "zinc",
|
||||
"cssVariables": true
|
||||
},
|
||||
"aliases": {
|
||||
"components": "@ra-aid/common/components",
|
||||
"utils": "@ra-aid/common/utils"
|
||||
}
|
||||
}
|
||||
|
|
@ -1,11 +0,0 @@
|
|||
import * as React from "react";
|
||||
import { type VariantProps } from "class-variance-authority";
|
||||
declare const buttonVariants: (props?: ({
|
||||
variant?: "default" | "destructive" | "outline" | "secondary" | "ghost" | "link" | null | undefined;
|
||||
size?: "default" | "sm" | "lg" | "icon" | null | undefined;
|
||||
} & import("class-variance-authority/dist/types").ClassProp) | undefined) => string;
|
||||
export interface ButtonProps extends React.ButtonHTMLAttributes<HTMLButtonElement>, VariantProps<typeof buttonVariants> {
|
||||
asChild?: boolean;
|
||||
}
|
||||
declare const Button: React.ForwardRefExoticComponent<ButtonProps & React.RefAttributes<HTMLButtonElement>>;
|
||||
export { Button, buttonVariants };
|
||||
|
|
@ -1,44 +0,0 @@
|
|||
var __rest = (this && this.__rest) || function (s, e) {
|
||||
var t = {};
|
||||
for (var p in s) if (Object.prototype.hasOwnProperty.call(s, p) && e.indexOf(p) < 0)
|
||||
t[p] = s[p];
|
||||
if (s != null && typeof Object.getOwnPropertySymbols === "function")
|
||||
for (var i = 0, p = Object.getOwnPropertySymbols(s); i < p.length; i++) {
|
||||
if (e.indexOf(p[i]) < 0 && Object.prototype.propertyIsEnumerable.call(s, p[i]))
|
||||
t[p[i]] = s[p[i]];
|
||||
}
|
||||
return t;
|
||||
};
|
||||
import * as React from "react";
|
||||
import { Slot } from "@radix-ui/react-slot";
|
||||
import { cva } from "class-variance-authority";
|
||||
import { cn } from "../../utils";
|
||||
const buttonVariants = cva("inline-flex items-center justify-center whitespace-nowrap rounded-md text-sm font-medium transition-colors focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring disabled:pointer-events-none disabled:opacity-50", {
|
||||
variants: {
|
||||
variant: {
|
||||
default: "bg-primary text-primary-foreground shadow hover:bg-primary/90",
|
||||
destructive: "bg-destructive text-destructive-foreground shadow-sm hover:bg-destructive/90",
|
||||
outline: "border border-input bg-background shadow-sm hover:bg-accent hover:text-accent-foreground",
|
||||
secondary: "bg-secondary text-secondary-foreground shadow-sm hover:bg-secondary/80",
|
||||
ghost: "hover:bg-accent hover:text-accent-foreground",
|
||||
link: "text-primary underline-offset-4 hover:underline",
|
||||
},
|
||||
size: {
|
||||
default: "h-9 px-4 py-2",
|
||||
sm: "h-8 rounded-md px-3 text-xs",
|
||||
lg: "h-10 rounded-md px-8",
|
||||
icon: "h-9 w-9",
|
||||
},
|
||||
},
|
||||
defaultVariants: {
|
||||
variant: "default",
|
||||
size: "default",
|
||||
},
|
||||
});
|
||||
const Button = React.forwardRef((_a, ref) => {
|
||||
var { className, variant, size, asChild = false } = _a, props = __rest(_a, ["className", "variant", "size", "asChild"]);
|
||||
const Comp = asChild ? Slot : "button";
|
||||
return (React.createElement(Comp, Object.assign({ className: cn(buttonVariants({ variant, size, className })), ref: ref }, props)));
|
||||
});
|
||||
Button.displayName = "Button";
|
||||
export { Button, buttonVariants };
|
||||
|
|
@ -1,8 +0,0 @@
|
|||
import * as React from "react";
|
||||
declare const Card: React.ForwardRefExoticComponent<React.HTMLAttributes<HTMLDivElement> & React.RefAttributes<HTMLDivElement>>;
|
||||
declare const CardHeader: React.ForwardRefExoticComponent<React.HTMLAttributes<HTMLDivElement> & React.RefAttributes<HTMLDivElement>>;
|
||||
declare const CardTitle: React.ForwardRefExoticComponent<React.HTMLAttributes<HTMLHeadingElement> & React.RefAttributes<HTMLParagraphElement>>;
|
||||
declare const CardDescription: React.ForwardRefExoticComponent<React.HTMLAttributes<HTMLParagraphElement> & React.RefAttributes<HTMLParagraphElement>>;
|
||||
declare const CardContent: React.ForwardRefExoticComponent<React.HTMLAttributes<HTMLDivElement> & React.RefAttributes<HTMLDivElement>>;
|
||||
declare const CardFooter: React.ForwardRefExoticComponent<React.HTMLAttributes<HTMLDivElement> & React.RefAttributes<HTMLDivElement>>;
|
||||
export { Card, CardHeader, CardFooter, CardTitle, CardDescription, CardContent };
|
||||
|
|
@ -1,44 +0,0 @@
|
|||
var __rest = (this && this.__rest) || function (s, e) {
|
||||
var t = {};
|
||||
for (var p in s) if (Object.prototype.hasOwnProperty.call(s, p) && e.indexOf(p) < 0)
|
||||
t[p] = s[p];
|
||||
if (s != null && typeof Object.getOwnPropertySymbols === "function")
|
||||
for (var i = 0, p = Object.getOwnPropertySymbols(s); i < p.length; i++) {
|
||||
if (e.indexOf(p[i]) < 0 && Object.prototype.propertyIsEnumerable.call(s, p[i]))
|
||||
t[p[i]] = s[p[i]];
|
||||
}
|
||||
return t;
|
||||
};
|
||||
import * as React from "react";
|
||||
import { cn } from "../../utils";
|
||||
const Card = React.forwardRef((_a, ref) => {
|
||||
var { className } = _a, props = __rest(_a, ["className"]);
|
||||
return (React.createElement("div", Object.assign({ ref: ref, className: cn("rounded-xl border bg-card text-card-foreground shadow", className) }, props)));
|
||||
});
|
||||
Card.displayName = "Card";
|
||||
const CardHeader = React.forwardRef((_a, ref) => {
|
||||
var { className } = _a, props = __rest(_a, ["className"]);
|
||||
return (React.createElement("div", Object.assign({ ref: ref, className: cn("flex flex-col space-y-1.5 p-6", className) }, props)));
|
||||
});
|
||||
CardHeader.displayName = "CardHeader";
|
||||
const CardTitle = React.forwardRef((_a, ref) => {
|
||||
var { className } = _a, props = __rest(_a, ["className"]);
|
||||
return (React.createElement("h3", Object.assign({ ref: ref, className: cn("font-semibold leading-none tracking-tight", className) }, props)));
|
||||
});
|
||||
CardTitle.displayName = "CardTitle";
|
||||
const CardDescription = React.forwardRef((_a, ref) => {
|
||||
var { className } = _a, props = __rest(_a, ["className"]);
|
||||
return (React.createElement("p", Object.assign({ ref: ref, className: cn("text-sm text-muted-foreground", className) }, props)));
|
||||
});
|
||||
CardDescription.displayName = "CardDescription";
|
||||
const CardContent = React.forwardRef((_a, ref) => {
|
||||
var { className } = _a, props = __rest(_a, ["className"]);
|
||||
return (React.createElement("div", Object.assign({ ref: ref, className: cn("p-6 pt-0", className) }, props)));
|
||||
});
|
||||
CardContent.displayName = "CardContent";
|
||||
const CardFooter = React.forwardRef((_a, ref) => {
|
||||
var { className } = _a, props = __rest(_a, ["className"]);
|
||||
return (React.createElement("div", Object.assign({ ref: ref, className: cn("flex items-center p-6 pt-0", className) }, props)));
|
||||
});
|
||||
CardFooter.displayName = "CardFooter";
|
||||
export { Card, CardHeader, CardFooter, CardTitle, CardDescription, CardContent };
|
||||
|
|
@ -1,9 +0,0 @@
|
|||
export * from './button';
|
||||
export * from './card';
|
||||
export * from './collapsible';
|
||||
export * from './floating-action-button';
|
||||
export * from './input';
|
||||
export * from './layout';
|
||||
export * from './sheet';
|
||||
export * from './switch';
|
||||
export * from './scroll-area';
|
||||
|
|
@ -1,9 +0,0 @@
|
|||
export * from './button';
|
||||
export * from './card';
|
||||
export * from './collapsible';
|
||||
export * from './floating-action-button';
|
||||
export * from './input';
|
||||
export * from './layout';
|
||||
export * from './sheet';
|
||||
export * from './switch';
|
||||
export * from './scroll-area';
|
||||
|
|
@ -1,5 +0,0 @@
|
|||
import * as React from "react";
|
||||
export interface InputProps extends React.InputHTMLAttributes<HTMLInputElement> {
|
||||
}
|
||||
declare const Input: React.ForwardRefExoticComponent<InputProps & React.RefAttributes<HTMLInputElement>>;
|
||||
export { Input };
|
||||
|
|
@ -1,19 +0,0 @@
|
|||
var __rest = (this && this.__rest) || function (s, e) {
|
||||
var t = {};
|
||||
for (var p in s) if (Object.prototype.hasOwnProperty.call(s, p) && e.indexOf(p) < 0)
|
||||
t[p] = s[p];
|
||||
if (s != null && typeof Object.getOwnPropertySymbols === "function")
|
||||
for (var i = 0, p = Object.getOwnPropertySymbols(s); i < p.length; i++) {
|
||||
if (e.indexOf(p[i]) < 0 && Object.prototype.propertyIsEnumerable.call(s, p[i]))
|
||||
t[p[i]] = s[p[i]];
|
||||
}
|
||||
return t;
|
||||
};
|
||||
import * as React from "react";
|
||||
import { cn } from "../../utils";
|
||||
const Input = React.forwardRef((_a, ref) => {
|
||||
var { className, type } = _a, props = __rest(_a, ["className", "type"]);
|
||||
return (React.createElement("input", Object.assign({ type: type, className: cn("flex h-9 w-full rounded-md border border-input bg-background px-3 py-1 text-sm shadow-sm transition-colors file:border-0 file:bg-transparent file:text-sm file:font-medium placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring disabled:cursor-not-allowed disabled:opacity-50", className), ref: ref }, props)));
|
||||
});
|
||||
Input.displayName = "Input";
|
||||
export { Input };
|
||||
|
|
@ -1,4 +0,0 @@
|
|||
import * as React from "react";
|
||||
import * as SwitchPrimitives from "@radix-ui/react-switch";
|
||||
declare const Switch: React.ForwardRefExoticComponent<Omit<SwitchPrimitives.SwitchProps & React.RefAttributes<HTMLButtonElement>, "ref"> & React.RefAttributes<HTMLButtonElement>>;
|
||||
export { Switch };
|
||||
|
|
@ -1,21 +0,0 @@
|
|||
var __rest = (this && this.__rest) || function (s, e) {
|
||||
var t = {};
|
||||
for (var p in s) if (Object.prototype.hasOwnProperty.call(s, p) && e.indexOf(p) < 0)
|
||||
t[p] = s[p];
|
||||
if (s != null && typeof Object.getOwnPropertySymbols === "function")
|
||||
for (var i = 0, p = Object.getOwnPropertySymbols(s); i < p.length; i++) {
|
||||
if (e.indexOf(p[i]) < 0 && Object.prototype.propertyIsEnumerable.call(s, p[i]))
|
||||
t[p[i]] = s[p[i]];
|
||||
}
|
||||
return t;
|
||||
};
|
||||
import * as React from "react";
|
||||
import * as SwitchPrimitives from "@radix-ui/react-switch";
|
||||
import { cn } from "../../utils";
|
||||
const Switch = React.forwardRef((_a, ref) => {
|
||||
var { className } = _a, props = __rest(_a, ["className"]);
|
||||
return (React.createElement(SwitchPrimitives.Root, Object.assign({ className: cn("peer inline-flex h-5 w-9 shrink-0 cursor-pointer items-center rounded-full border-2 border-transparent shadow-sm transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 focus-visible:ring-offset-background disabled:cursor-not-allowed disabled:opacity-50 data-[state=checked]:bg-primary data-[state=unchecked]:bg-input", className) }, props, { ref: ref }),
|
||||
React.createElement(SwitchPrimitives.Thumb, { className: cn("pointer-events-none block h-4 w-4 rounded-full bg-background shadow-lg ring-0 transition-transform data-[state=checked]:translate-x-4 data-[state=unchecked]:translate-x-0") })));
|
||||
});
|
||||
Switch.displayName = SwitchPrimitives.Root.displayName;
|
||||
export { Switch };
|
||||
|
|
@ -1,11 +0,0 @@
|
|||
import './styles/global.css';
|
||||
export * from './utils/types';
|
||||
export * from './utils';
|
||||
export * from './components/ui';
|
||||
export * from './components/TimelineStep';
|
||||
export * from './components/TimelineFeed';
|
||||
export * from './components/SessionDrawer';
|
||||
export * from './components/SessionSidebar';
|
||||
export * from './components/DefaultAgentScreen';
|
||||
export declare const hello: () => void;
|
||||
export { getSampleAgentSteps, getSampleAgentSessions } from './utils/sample-data';
|
||||
|
|
@ -1,22 +0,0 @@
|
|||
// Entry point for @ra-aid/common package
|
||||
import './styles/global.css';
|
||||
// Export types first to avoid circular references
|
||||
export * from './utils/types';
|
||||
// Export utility functions
|
||||
export * from './utils';
|
||||
// Export UI components
|
||||
export * from './components/ui';
|
||||
// Export timeline components
|
||||
export * from './components/TimelineStep';
|
||||
export * from './components/TimelineFeed';
|
||||
// Export session navigation components
|
||||
export * from './components/SessionDrawer';
|
||||
export * from './components/SessionSidebar';
|
||||
// Export main screens
|
||||
export * from './components/DefaultAgentScreen';
|
||||
// Export the hello function (temporary example)
|
||||
export const hello = () => {
|
||||
console.log("Hello from @ra-aid/common");
|
||||
};
|
||||
// Directly export sample data functions
|
||||
export { getSampleAgentSteps, getSampleAgentSessions } from './utils/sample-data';
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,7 +0,0 @@
|
|||
import { type ClassValue } from "clsx";
|
||||
/**
|
||||
* Merges class names with Tailwind CSS classes
|
||||
* Combines clsx for conditional logic and tailwind-merge for handling conflicting tailwind classes
|
||||
*/
|
||||
export declare function cn(...inputs: ClassValue[]): string;
|
||||
export * from './utils';
|
||||
|
|
@ -1,11 +0,0 @@
|
|||
import { clsx } from "clsx";
|
||||
import { twMerge } from "tailwind-merge";
|
||||
/**
|
||||
* Merges class names with Tailwind CSS classes
|
||||
* Combines clsx for conditional logic and tailwind-merge for handling conflicting tailwind classes
|
||||
*/
|
||||
export function cn(...inputs) {
|
||||
return twMerge(clsx(inputs));
|
||||
}
|
||||
// Re-export everything from utils directory
|
||||
export * from './utils';
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,43 +0,0 @@
|
|||
{
|
||||
"name": "@ra-aid/common",
|
||||
"version": "1.0.0",
|
||||
"private": true,
|
||||
"main": "src/index.ts",
|
||||
"types": "src/index.ts",
|
||||
"scripts": {
|
||||
"build": "tsc && postcss src/styles/global.css -o dist/styles/global.css",
|
||||
"dev": "tsc --watch",
|
||||
"watch:css": "postcss src/styles/global.css -o dist/styles/global.css --watch",
|
||||
"watch": "concurrently \"npm run dev\" \"npm run watch:css\"",
|
||||
"prepare": "npm run build"
|
||||
},
|
||||
"dependencies": {
|
||||
"@radix-ui/react-collapsible": "^1.1.3",
|
||||
"@radix-ui/react-dialog": "^1.0.5",
|
||||
"@radix-ui/react-label": "^2.0.2",
|
||||
"@radix-ui/react-popover": "^1.0.7",
|
||||
"@radix-ui/react-scroll-area": "^1.2.3",
|
||||
"@radix-ui/react-select": "^2.0.0",
|
||||
"@radix-ui/react-slot": "^1.0.2",
|
||||
"@radix-ui/react-switch": "^1.1.3",
|
||||
"class-variance-authority": "^0.7.0",
|
||||
"clsx": "^2.1.0",
|
||||
"lucide-react": "^0.363.0",
|
||||
"tailwind-merge": "^2.2.0",
|
||||
"tailwindcss-animate": "^1.0.7"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/react": "^18.2.64",
|
||||
"@types/react-dom": "^18.2.21",
|
||||
"autoprefixer": "^10.4.17",
|
||||
"concurrently": "^8.2.2",
|
||||
"postcss": "^8.4.35",
|
||||
"postcss-cli": "^10.1.0",
|
||||
"tailwindcss": "^3.4.1",
|
||||
"typescript": "^5.0.0"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"react": ">=18.0.0",
|
||||
"react-dom": ">=18.0.0"
|
||||
}
|
||||
}
|
||||
|
|
@ -1,6 +0,0 @@
|
|||
module.exports = {
|
||||
plugins: {
|
||||
tailwindcss: {},
|
||||
autoprefixer: {},
|
||||
},
|
||||
}
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 23 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 25 KiB |
|
|
@ -1,258 +0,0 @@
|
|||
import React, { useState, useEffect } from 'react';
|
||||
import { createPortal } from 'react-dom';
|
||||
import { PanelLeft } from 'lucide-react';
|
||||
import {
|
||||
Button,
|
||||
Layout
|
||||
} from './ui';
|
||||
import { SessionDrawer } from './SessionDrawer';
|
||||
import { SessionList } from './SessionList';
|
||||
import { TimelineFeed } from './TimelineFeed';
|
||||
import { getSampleAgentSessions, getSampleAgentSteps } from '../utils/sample-data';
|
||||
import logoBlack from '../assets/logo-black-transparent.png';
|
||||
import logoWhite from '../assets/logo-white-transparent.gif';
|
||||
|
||||
/**
|
||||
* DefaultAgentScreen component
|
||||
*
|
||||
* Main application screen for displaying agent sessions and their steps.
|
||||
* Handles state management, responsive design, and UI interactions.
|
||||
*/
|
||||
export const DefaultAgentScreen: React.FC = () => {
|
||||
// State for drawer open/close
|
||||
const [isDrawerOpen, setIsDrawerOpen] = useState(false);
|
||||
|
||||
// State for selected session
|
||||
const [selectedSessionId, setSelectedSessionId] = useState<string | null>(null);
|
||||
|
||||
// State for theme (dark is default)
|
||||
const [isDarkTheme, setIsDarkTheme] = useState(true);
|
||||
|
||||
// Get sample data
|
||||
const sessions = getSampleAgentSessions();
|
||||
const allSteps = getSampleAgentSteps();
|
||||
|
||||
// Set up theme on component mount
|
||||
useEffect(() => {
|
||||
const isDark = setupTheme();
|
||||
setIsDarkTheme(isDark);
|
||||
}, []);
|
||||
|
||||
// Set initial selected session if none selected
|
||||
useEffect(() => {
|
||||
if (!selectedSessionId && sessions.length > 0) {
|
||||
setSelectedSessionId(sessions[0].id);
|
||||
}
|
||||
}, [sessions, selectedSessionId]);
|
||||
|
||||
// Close drawer when window resizes to desktop width
|
||||
useEffect(() => {
|
||||
const handleResize = () => {
|
||||
// Check if we're at desktop size (corresponds to md: breakpoint in Tailwind)
|
||||
if (window.innerWidth >= 768 && isDrawerOpen) {
|
||||
setIsDrawerOpen(false);
|
||||
}
|
||||
};
|
||||
|
||||
// Add event listener
|
||||
window.addEventListener('resize', handleResize);
|
||||
|
||||
// Clean up event listener on component unmount
|
||||
return () => window.removeEventListener('resize', handleResize);
|
||||
}, [isDrawerOpen]);
|
||||
|
||||
// Filter steps for selected session
|
||||
const selectedSessionSteps = selectedSessionId
|
||||
? allSteps.filter(step => sessions.find(s => s.id === selectedSessionId)?.steps.some(s => s.id === step.id))
|
||||
: [];
|
||||
|
||||
// Handle session selection
|
||||
const handleSessionSelect = (sessionId: string) => {
|
||||
setSelectedSessionId(sessionId);
|
||||
setIsDrawerOpen(false); // Close drawer on selection (mobile)
|
||||
};
|
||||
|
||||
// Toggle theme function
|
||||
const toggleTheme = () => {
|
||||
const newIsDark = !isDarkTheme;
|
||||
setIsDarkTheme(newIsDark);
|
||||
|
||||
// Update document element class
|
||||
if (newIsDark) {
|
||||
document.documentElement.classList.add('dark');
|
||||
} else {
|
||||
document.documentElement.classList.remove('dark');
|
||||
}
|
||||
|
||||
// Save to localStorage
|
||||
localStorage.setItem('theme', newIsDark ? 'dark' : 'light');
|
||||
};
|
||||
|
||||
// Render header content
|
||||
const headerContent = (
|
||||
<div className="w-full flex items-center justify-between h-full px-4">
|
||||
<div className="flex-initial">
|
||||
{/* Use the appropriate logo based on theme */}
|
||||
<img
|
||||
src={isDarkTheme ? logoWhite : logoBlack}
|
||||
alt="RA.Aid Logo"
|
||||
className="h-8"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="flex-initial ml-auto">
|
||||
{/* Theme toggle button */}
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
onClick={toggleTheme}
|
||||
aria-label={isDarkTheme ? "Switch to light mode" : "Switch to dark mode"}
|
||||
>
|
||||
{isDarkTheme ? (
|
||||
// Sun icon for light mode toggle
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
width="20"
|
||||
height="20"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
>
|
||||
<circle cx="12" cy="12" r="5" />
|
||||
<line x1="12" y1="1" x2="12" y2="3" />
|
||||
<line x1="12" y1="21" x2="12" y2="23" />
|
||||
<line x1="4.22" y1="4.22" x2="5.64" y2="5.64" />
|
||||
<line x1="18.36" y1="18.36" x2="19.78" y2="19.78" />
|
||||
<line x1="1" y1="12" x2="3" y2="12" />
|
||||
<line x1="21" y1="12" x2="23" y2="12" />
|
||||
<line x1="4.22" y1="19.78" x2="5.64" y2="18.36" />
|
||||
<line x1="18.36" y1="5.64" x2="19.78" y2="4.22" />
|
||||
</svg>
|
||||
) : (
|
||||
// Moon icon for dark mode toggle
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
width="20"
|
||||
height="20"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
>
|
||||
<path d="M21 12.79A9 9 0 1 1 11.21 3 7 7 0 0 0 21 12.79z" />
|
||||
</svg>
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
// Sidebar content with sessions list
|
||||
const sidebarContent = (
|
||||
<div className="h-full flex flex-col px-4 py-3">
|
||||
<SessionList
|
||||
sessions={sessions}
|
||||
onSelectSession={handleSessionSelect}
|
||||
currentSessionId={selectedSessionId || undefined}
|
||||
className="flex-1 pr-1 -mr-1"
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
|
||||
// Render drawer
|
||||
const drawerContent = (
|
||||
<SessionDrawer
|
||||
sessions={sessions}
|
||||
currentSessionId={selectedSessionId || undefined}
|
||||
onSelectSession={handleSessionSelect}
|
||||
isOpen={isDrawerOpen}
|
||||
onClose={() => setIsDrawerOpen(false)}
|
||||
/>
|
||||
);
|
||||
|
||||
// Render main content
|
||||
const mainContent = (
|
||||
selectedSessionId ? (
|
||||
<>
|
||||
<h2 className="text-xl font-semibold mb-4">
|
||||
Session: {sessions.find(s => s.id === selectedSessionId)?.name || 'Unknown'}
|
||||
</h2>
|
||||
<TimelineFeed
|
||||
steps={selectedSessionSteps}
|
||||
/>
|
||||
</>
|
||||
) : (
|
||||
<div className="flex items-center justify-center h-full">
|
||||
<p className="text-muted-foreground">Select a session to view details</p>
|
||||
</div>
|
||||
)
|
||||
);
|
||||
|
||||
// Floating action button component that uses Portal to render at document body level
|
||||
const FloatingActionButton = ({ onClick }: { onClick: () => void }) => {
|
||||
// Only render the portal on the client side, not during SSR
|
||||
const [mounted, setMounted] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
setMounted(true);
|
||||
return () => setMounted(false);
|
||||
}, []);
|
||||
|
||||
const button = (
|
||||
<Button
|
||||
variant="default"
|
||||
size="icon"
|
||||
onClick={onClick}
|
||||
aria-label="Toggle sessions panel"
|
||||
className="h-14 w-14 rounded-full shadow-xl bg-zinc-800 hover:bg-zinc-700 text-zinc-100 flex items-center justify-center border-2 border-zinc-700 dark:border-zinc-600"
|
||||
>
|
||||
<PanelLeft className="h-6 w-6" />
|
||||
</Button>
|
||||
);
|
||||
|
||||
const container = (
|
||||
<div className="fixed bottom-6 right-6 z-[9999] md:hidden" style={{ pointerEvents: 'auto' }}>
|
||||
{button}
|
||||
</div>
|
||||
);
|
||||
|
||||
// Return null during SSR, or the portal on the client
|
||||
return mounted ? createPortal(container, document.body) : null;
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
<Layout
|
||||
header={headerContent}
|
||||
sidebar={sidebarContent}
|
||||
drawer={drawerContent}
|
||||
>
|
||||
{mainContent}
|
||||
</Layout>
|
||||
<FloatingActionButton onClick={() => setIsDrawerOpen(true)} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
// Helper function for theme setup
|
||||
const setupTheme = () => {
|
||||
// Check if theme preference is stored in localStorage
|
||||
const storedTheme = localStorage.getItem('theme');
|
||||
|
||||
// Default to dark mode unless explicitly set to light
|
||||
const isDark = storedTheme ? storedTheme === 'dark' : true;
|
||||
|
||||
// Apply theme to document
|
||||
if (isDark) {
|
||||
document.documentElement.classList.add('dark');
|
||||
} else {
|
||||
document.documentElement.classList.remove('dark');
|
||||
}
|
||||
|
||||
return isDark;
|
||||
};
|
||||
|
|
@ -1,47 +0,0 @@
|
|||
import React from 'react';
|
||||
import {
|
||||
Sheet,
|
||||
SheetContent,
|
||||
SheetHeader,
|
||||
SheetTitle,
|
||||
SheetClose
|
||||
} from './ui/sheet';
|
||||
import { AgentSession } from '../utils/types';
|
||||
import { getSampleAgentSessions } from '../utils/sample-data';
|
||||
import { SessionList } from './SessionList';
|
||||
|
||||
interface SessionDrawerProps {
|
||||
onSelectSession?: (sessionId: string) => void;
|
||||
currentSessionId?: string;
|
||||
sessions?: AgentSession[];
|
||||
isOpen?: boolean;
|
||||
onClose?: () => void;
|
||||
}
|
||||
|
||||
export const SessionDrawer: React.FC<SessionDrawerProps> = ({
|
||||
onSelectSession,
|
||||
currentSessionId,
|
||||
sessions = getSampleAgentSessions(),
|
||||
isOpen = false,
|
||||
onClose
|
||||
}) => {
|
||||
return (
|
||||
<Sheet open={isOpen} onOpenChange={onClose}>
|
||||
<SheetContent
|
||||
side="left"
|
||||
className="w-full sm:max-w-md border-r border-border p-4"
|
||||
>
|
||||
<SheetHeader className="px-2">
|
||||
<SheetTitle>Sessions</SheetTitle>
|
||||
</SheetHeader>
|
||||
<SessionList
|
||||
sessions={sessions}
|
||||
currentSessionId={currentSessionId}
|
||||
onSelectSession={onSelectSession}
|
||||
className="h-[calc(100vh-9rem)] mt-4"
|
||||
wrapperComponent={SheetClose}
|
||||
/>
|
||||
</SheetContent>
|
||||
</Sheet>
|
||||
);
|
||||
};
|
||||
|
|
@ -1,93 +0,0 @@
|
|||
import React from 'react';
|
||||
import { ScrollArea } from './ui/scroll-area';
|
||||
import { AgentSession } from '../utils/types';
|
||||
import { getSampleAgentSessions } from '../utils/sample-data';
|
||||
|
||||
interface SessionListProps {
|
||||
onSelectSession?: (sessionId: string) => void;
|
||||
currentSessionId?: string;
|
||||
sessions?: AgentSession[];
|
||||
className?: string;
|
||||
wrapperComponent?: React.ElementType;
|
||||
closeAction?: React.ReactNode;
|
||||
}
|
||||
|
||||
export const SessionList: React.FC<SessionListProps> = ({
|
||||
onSelectSession,
|
||||
currentSessionId,
|
||||
sessions = getSampleAgentSessions(),
|
||||
className = '',
|
||||
wrapperComponent: WrapperComponent = 'button',
|
||||
closeAction
|
||||
}) => {
|
||||
// Get status color
|
||||
const getStatusColor = (status: string) => {
|
||||
switch (status) {
|
||||
case 'active':
|
||||
return 'bg-blue-500';
|
||||
case 'completed':
|
||||
return 'bg-green-500';
|
||||
case 'error':
|
||||
return 'bg-red-500';
|
||||
default:
|
||||
return 'bg-gray-500';
|
||||
}
|
||||
};
|
||||
|
||||
// Format timestamp
|
||||
const formatDate = (date: Date) => {
|
||||
return date.toLocaleDateString([], {
|
||||
month: 'short',
|
||||
day: 'numeric',
|
||||
hour: '2-digit',
|
||||
minute: '2-digit'
|
||||
});
|
||||
};
|
||||
|
||||
return (
|
||||
<ScrollArea className={className}>
|
||||
<div className="space-y-1.5 pt-1.5 pb-2">
|
||||
{sessions.map((session) => {
|
||||
const buttonContent = (
|
||||
<>
|
||||
<div className={`w-2.5 h-2.5 rounded-full ${getStatusColor(session.status)} mt-1.5 mr-3 flex-shrink-0`} />
|
||||
<div className="flex-1 min-w-0 pr-1">
|
||||
<div className="font-medium text-sm+ break-words">{session.name}</div>
|
||||
<div className="text-xs text-muted-foreground mt-1 break-words">
|
||||
{session.steps.length} steps • {formatDate(session.updated)}
|
||||
</div>
|
||||
<div className="text-xs text-muted-foreground mt-0.5 break-words">
|
||||
<span className="capitalize">{session.status}</span>
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
|
||||
return React.createElement(
|
||||
WrapperComponent,
|
||||
{
|
||||
key: session.id,
|
||||
onClick: () => onSelectSession?.(session.id),
|
||||
className: `w-full flex items-start px-3 py-2.5 text-left rounded-md transition-colors hover:bg-accent/50 ${
|
||||
currentSessionId === session.id ? 'bg-accent' : ''
|
||||
}`
|
||||
},
|
||||
closeAction ? (
|
||||
<>
|
||||
{buttonContent}
|
||||
<div className="ml-2 flex-shrink-0 self-center">
|
||||
{React.cloneElement(closeAction as React.ReactElement, {
|
||||
onClick: (e: React.MouseEvent) => {
|
||||
e.stopPropagation();
|
||||
onSelectSession?.(session.id);
|
||||
}
|
||||
})}
|
||||
</div>
|
||||
</>
|
||||
) : buttonContent
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</ScrollArea>
|
||||
);
|
||||
};
|
||||
|
|
@ -1,32 +0,0 @@
|
|||
import React from 'react';
|
||||
import { AgentSession } from '../utils/types';
|
||||
import { getSampleAgentSessions } from '../utils/sample-data';
|
||||
import { SessionList } from './SessionList';
|
||||
|
||||
interface SessionSidebarProps {
|
||||
onSelectSession?: (sessionId: string) => void;
|
||||
currentSessionId?: string;
|
||||
sessions?: AgentSession[];
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export const SessionSidebar: React.FC<SessionSidebarProps> = ({
|
||||
onSelectSession,
|
||||
currentSessionId,
|
||||
sessions = getSampleAgentSessions(),
|
||||
className = ''
|
||||
}) => {
|
||||
return (
|
||||
<div className={`flex flex-col h-full ${className}`}>
|
||||
<div className="p-4 border-b border-border">
|
||||
<h3 className="font-medium text-lg">Sessions</h3>
|
||||
</div>
|
||||
<SessionList
|
||||
sessions={sessions}
|
||||
currentSessionId={currentSessionId}
|
||||
onSelectSession={onSelectSession}
|
||||
className="flex-1"
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
|
@ -1,45 +0,0 @@
|
|||
import React, { useMemo } from 'react';
|
||||
import { TimelineStep } from './TimelineStep';
|
||||
import { AgentStep } from '../utils/types';
|
||||
|
||||
interface TimelineFeedProps {
|
||||
steps: AgentStep[];
|
||||
maxHeight?: string;
|
||||
}
|
||||
|
||||
export const TimelineFeed: React.FC<TimelineFeedProps> = ({
|
||||
steps,
|
||||
maxHeight
|
||||
}) => {
|
||||
// Always use 'desc' (newest first) sort order
|
||||
const sortOrder = 'desc';
|
||||
|
||||
// Sort steps with newest first (desc order)
|
||||
const sortedSteps = useMemo(() => {
|
||||
return [...steps].sort((a, b) => {
|
||||
return b.timestamp.getTime() - a.timestamp.getTime();
|
||||
});
|
||||
}, [steps]);
|
||||
|
||||
return (
|
||||
<div className="w-full rounded-md bg-background">
|
||||
<div
|
||||
className="px-3 py-3 space-y-4 overflow-auto"
|
||||
style={{ maxHeight: maxHeight || undefined }}
|
||||
>
|
||||
{sortedSteps.length > 0 ? (
|
||||
sortedSteps.map((step) => (
|
||||
<TimelineStep key={step.id} step={step} />
|
||||
))
|
||||
) : (
|
||||
<div className="text-center text-muted-foreground py-12 border border-dashed border-border rounded-md">
|
||||
<svg xmlns="http://www.w3.org/2000/svg" className="h-8 w-8 mx-auto mb-2 text-muted-foreground/50" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M12 8v4l3 3m6-3a9 9 0 11-18 0 9 9 0 0118 0z" />
|
||||
</svg>
|
||||
<p>No steps to display</p>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
@ -1,99 +0,0 @@
|
|||
import React from 'react';
|
||||
import { Collapsible, CollapsibleContent, CollapsibleTrigger } from './ui/collapsible';
|
||||
import { AgentStep } from '../utils/types';
|
||||
|
||||
interface TimelineStepProps {
|
||||
step: AgentStep;
|
||||
}
|
||||
|
||||
export const TimelineStep: React.FC<TimelineStepProps> = ({ step }) => {
|
||||
// Get status color
|
||||
const getStatusColor = (status: string) => {
|
||||
switch (status) {
|
||||
case 'completed':
|
||||
return 'bg-green-500';
|
||||
case 'in-progress':
|
||||
return 'bg-blue-500';
|
||||
case 'error':
|
||||
return 'bg-red-500';
|
||||
case 'pending':
|
||||
return 'bg-yellow-500';
|
||||
default:
|
||||
return 'bg-gray-500';
|
||||
}
|
||||
};
|
||||
|
||||
// Get icon based on step type
|
||||
const getTypeIcon = (type: string) => {
|
||||
switch (type) {
|
||||
case 'tool-execution':
|
||||
return '🛠️';
|
||||
case 'thinking':
|
||||
return '💭';
|
||||
case 'planning':
|
||||
return '📝';
|
||||
case 'implementation':
|
||||
return '💻';
|
||||
case 'user-input':
|
||||
return '👤';
|
||||
default:
|
||||
return '▶️';
|
||||
}
|
||||
};
|
||||
|
||||
// Format timestamp
|
||||
const formatTime = (timestamp: Date) => {
|
||||
return timestamp.toLocaleTimeString([], { hour: '2-digit', minute: '2-digit' });
|
||||
};
|
||||
|
||||
return (
|
||||
<Collapsible className="w-full mb-5 border border-border rounded-md overflow-hidden shadow-sm hover:shadow-md transition-all duration-200">
|
||||
<CollapsibleTrigger className="w-full flex items-center justify-between p-4 text-left hover:bg-accent/30 cursor-pointer group">
|
||||
<div className="flex items-center space-x-3 min-w-0 flex-1 pr-3">
|
||||
<div className={`flex-shrink-0 w-3 h-3 rounded-full ${getStatusColor(step.status)} ring-1 ring-ring/20`} />
|
||||
<div className="flex-shrink-0 text-lg group-hover:scale-110 transition-transform">{getTypeIcon(step.type)}</div>
|
||||
<div className="min-w-0 flex-1">
|
||||
<div className="font-medium text-foreground break-words">{step.title}</div>
|
||||
<div className="text-sm text-muted-foreground line-clamp-2">
|
||||
{step.type === 'tool-execution' ? 'Run tool' : step.content.substring(0, 60)}
|
||||
{step.content.length > 60 ? '...' : ''}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div className="text-xs text-muted-foreground flex flex-col items-end flex-shrink-0 min-w-[70px] text-right">
|
||||
<span className="font-medium">{formatTime(step.timestamp)}</span>
|
||||
{step.duration && (
|
||||
<span className="mt-1 px-2 py-0.5 bg-secondary/50 rounded-full">
|
||||
{(step.duration / 1000).toFixed(1)}s
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</CollapsibleTrigger>
|
||||
<CollapsibleContent>
|
||||
<div className="p-5 bg-card/50 border-t border-border">
|
||||
<div className="text-sm break-words text-foreground leading-relaxed">
|
||||
{step.content}
|
||||
</div>
|
||||
{step.duration && (
|
||||
<div className="mt-4 pt-3 border-t border-border/50">
|
||||
<div className="text-xs text-muted-foreground flex items-center">
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
className="h-3.5 w-3.5 mr-1"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
strokeWidth={2}
|
||||
>
|
||||
<circle cx="12" cy="12" r="10" />
|
||||
<polyline points="12 6 12 12 16 14" />
|
||||
</svg>
|
||||
Duration: {(step.duration / 1000).toFixed(1)} seconds
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</CollapsibleContent>
|
||||
</Collapsible>
|
||||
);
|
||||
};
|
||||
|
|
@ -1,57 +0,0 @@
|
|||
import * as React from "react";
|
||||
import { Slot } from "@radix-ui/react-slot";
|
||||
import { cva, type VariantProps } from "class-variance-authority";
|
||||
|
||||
import { cn } from "../../utils";
|
||||
|
||||
const buttonVariants = cva(
|
||||
"inline-flex items-center justify-center whitespace-nowrap rounded-md text-sm font-medium transition-colors focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring disabled:pointer-events-none disabled:opacity-50",
|
||||
{
|
||||
variants: {
|
||||
variant: {
|
||||
default:
|
||||
"bg-primary text-primary-foreground shadow hover:bg-primary/90",
|
||||
destructive:
|
||||
"bg-destructive text-destructive-foreground shadow-sm hover:bg-destructive/90",
|
||||
outline:
|
||||
"border border-input bg-background shadow-sm hover:bg-accent hover:text-accent-foreground",
|
||||
secondary:
|
||||
"bg-secondary text-secondary-foreground shadow-sm hover:bg-secondary/80",
|
||||
ghost: "hover:bg-accent hover:text-accent-foreground",
|
||||
link: "text-primary underline-offset-4 hover:underline",
|
||||
},
|
||||
size: {
|
||||
default: "h-9 px-4 py-2",
|
||||
sm: "h-8 rounded-md px-3 text-xs",
|
||||
lg: "h-10 rounded-md px-8",
|
||||
icon: "h-9 w-9",
|
||||
},
|
||||
},
|
||||
defaultVariants: {
|
||||
variant: "default",
|
||||
size: "default",
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
export interface ButtonProps
|
||||
extends React.ButtonHTMLAttributes<HTMLButtonElement>,
|
||||
VariantProps<typeof buttonVariants> {
|
||||
asChild?: boolean;
|
||||
}
|
||||
|
||||
const Button = React.forwardRef<HTMLButtonElement, ButtonProps>(
|
||||
({ className, variant, size, asChild = false, ...props }, ref) => {
|
||||
const Comp = asChild ? Slot : "button";
|
||||
return (
|
||||
<Comp
|
||||
className={cn(buttonVariants({ variant, size, className }))}
|
||||
ref={ref}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
}
|
||||
);
|
||||
Button.displayName = "Button";
|
||||
|
||||
export { Button, buttonVariants };
|
||||
|
|
@ -1,76 +0,0 @@
|
|||
import * as React from "react";
|
||||
|
||||
import { cn } from "../../utils";
|
||||
|
||||
const Card = React.forwardRef<
|
||||
HTMLDivElement,
|
||||
React.HTMLAttributes<HTMLDivElement>
|
||||
>(({ className, ...props }, ref) => (
|
||||
<div
|
||||
ref={ref}
|
||||
className={cn(
|
||||
"rounded-xl border bg-card text-card-foreground shadow",
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
));
|
||||
Card.displayName = "Card";
|
||||
|
||||
const CardHeader = React.forwardRef<
|
||||
HTMLDivElement,
|
||||
React.HTMLAttributes<HTMLDivElement>
|
||||
>(({ className, ...props }, ref) => (
|
||||
<div
|
||||
ref={ref}
|
||||
className={cn("flex flex-col space-y-1.5 p-6", className)}
|
||||
{...props}
|
||||
/>
|
||||
));
|
||||
CardHeader.displayName = "CardHeader";
|
||||
|
||||
const CardTitle = React.forwardRef<
|
||||
HTMLParagraphElement,
|
||||
React.HTMLAttributes<HTMLHeadingElement>
|
||||
>(({ className, ...props }, ref) => (
|
||||
<h3
|
||||
ref={ref}
|
||||
className={cn("font-semibold leading-none tracking-tight", className)}
|
||||
{...props}
|
||||
/>
|
||||
));
|
||||
CardTitle.displayName = "CardTitle";
|
||||
|
||||
const CardDescription = React.forwardRef<
|
||||
HTMLParagraphElement,
|
||||
React.HTMLAttributes<HTMLParagraphElement>
|
||||
>(({ className, ...props }, ref) => (
|
||||
<p
|
||||
ref={ref}
|
||||
className={cn("text-sm text-muted-foreground", className)}
|
||||
{...props}
|
||||
/>
|
||||
));
|
||||
CardDescription.displayName = "CardDescription";
|
||||
|
||||
const CardContent = React.forwardRef<
|
||||
HTMLDivElement,
|
||||
React.HTMLAttributes<HTMLDivElement>
|
||||
>(({ className, ...props }, ref) => (
|
||||
<div ref={ref} className={cn("p-6 pt-0", className)} {...props} />
|
||||
));
|
||||
CardContent.displayName = "CardContent";
|
||||
|
||||
const CardFooter = React.forwardRef<
|
||||
HTMLDivElement,
|
||||
React.HTMLAttributes<HTMLDivElement>
|
||||
>(({ className, ...props }, ref) => (
|
||||
<div
|
||||
ref={ref}
|
||||
className={cn("flex items-center p-6 pt-0", className)}
|
||||
{...props}
|
||||
/>
|
||||
));
|
||||
CardFooter.displayName = "CardFooter";
|
||||
|
||||
export { Card, CardHeader, CardFooter, CardTitle, CardDescription, CardContent };
|
||||
|
|
@ -1,27 +0,0 @@
|
|||
import * as React from "react"
|
||||
import * as CollapsiblePrimitive from "@radix-ui/react-collapsible"
|
||||
|
||||
import { cn } from "../../utils"
|
||||
|
||||
const Collapsible = CollapsiblePrimitive.Root
|
||||
|
||||
const CollapsibleTrigger = CollapsiblePrimitive.Trigger
|
||||
|
||||
const CollapsibleContent = React.forwardRef<
|
||||
React.ElementRef<typeof CollapsiblePrimitive.Content>,
|
||||
React.ComponentPropsWithoutRef<typeof CollapsiblePrimitive.Content>
|
||||
>(({ className, children, ...props }, ref) => (
|
||||
<CollapsiblePrimitive.Content
|
||||
ref={ref}
|
||||
className={cn(
|
||||
"overflow-hidden data-[state=closed]:animate-accordion-up data-[state=open]:animate-accordion-down",
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
>
|
||||
{children}
|
||||
</CollapsiblePrimitive.Content>
|
||||
))
|
||||
CollapsibleContent.displayName = "CollapsibleContent"
|
||||
|
||||
export { Collapsible, CollapsibleTrigger, CollapsibleContent }
|
||||
|
|
@ -1,36 +0,0 @@
|
|||
import React, { ReactNode } from 'react';
|
||||
import { Button } from './button';
|
||||
|
||||
export interface FloatingActionButtonProps {
|
||||
icon: ReactNode;
|
||||
onClick: () => void;
|
||||
ariaLabel?: string;
|
||||
className?: string;
|
||||
variant?: 'default' | 'destructive' | 'outline' | 'secondary' | 'ghost' | 'link';
|
||||
}
|
||||
|
||||
/**
|
||||
* FloatingActionButton component
|
||||
*
|
||||
* A button typically used for primary actions on mobile layouts
|
||||
* Designed to be used with the Layout component's floatingAction prop
|
||||
*/
|
||||
export const FloatingActionButton: React.FC<FloatingActionButtonProps> = ({
|
||||
icon,
|
||||
onClick,
|
||||
ariaLabel = 'Action button',
|
||||
className = '',
|
||||
variant = 'default'
|
||||
}) => {
|
||||
return (
|
||||
<Button
|
||||
variant={variant}
|
||||
size="icon"
|
||||
onClick={onClick}
|
||||
aria-label={ariaLabel}
|
||||
className={`h-14 w-14 rounded-full shadow-xl bg-blue-600 hover:bg-blue-700 text-white flex items-center justify-center border-2 border-white dark:border-gray-800 ${className}`}
|
||||
>
|
||||
{icon}
|
||||
</Button>
|
||||
);
|
||||
};
|
||||
|
|
@ -1,9 +0,0 @@
|
|||
export * from './button';
|
||||
export * from './card';
|
||||
export * from './collapsible';
|
||||
export * from './floating-action-button';
|
||||
export * from './input';
|
||||
export * from './layout';
|
||||
export * from './sheet';
|
||||
export * from './switch';
|
||||
export * from './scroll-area';
|
||||
|
|
@ -1,25 +0,0 @@
|
|||
import * as React from "react";
|
||||
|
||||
import { cn } from "../../utils";
|
||||
|
||||
export interface InputProps
|
||||
extends React.InputHTMLAttributes<HTMLInputElement> {}
|
||||
|
||||
const Input = React.forwardRef<HTMLInputElement, InputProps>(
|
||||
({ className, type, ...props }, ref) => {
|
||||
return (
|
||||
<input
|
||||
type={type}
|
||||
className={cn(
|
||||
"flex h-9 w-full rounded-md border border-input bg-background px-3 py-1 text-sm shadow-sm transition-colors file:border-0 file:bg-transparent file:text-sm file:font-medium placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring disabled:cursor-not-allowed disabled:opacity-50",
|
||||
className
|
||||
)}
|
||||
ref={ref}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
}
|
||||
);
|
||||
Input.displayName = "Input";
|
||||
|
||||
export { Input };
|
||||
|
|
@ -1,56 +0,0 @@
|
|||
import React from 'react';
|
||||
|
||||
/**
|
||||
* Layout component using Tailwind Grid utilities
|
||||
* This component creates a responsive layout with:
|
||||
* - Sticky header at the top (z-index 30)
|
||||
* - Sidebar on desktop (hidden on mobile)
|
||||
* - Main content area with proper positioning
|
||||
* - Optional floating action button for mobile navigation
|
||||
*/
|
||||
export interface LayoutProps {
|
||||
header: React.ReactNode;
|
||||
sidebar?: React.ReactNode;
|
||||
drawer?: React.ReactNode;
|
||||
children: React.ReactNode;
|
||||
floatingAction?: React.ReactNode;
|
||||
}
|
||||
|
||||
export const Layout: React.FC<LayoutProps> = ({
|
||||
header,
|
||||
sidebar,
|
||||
drawer,
|
||||
children,
|
||||
floatingAction
|
||||
}) => {
|
||||
return (
|
||||
<div className="grid min-h-screen grid-cols-1 grid-rows-[64px_1fr] md:grid-cols-[280px_1fr] lg:grid-cols-[320px_1fr] xl:grid-cols-[350px_1fr] bg-background text-foreground relative">
|
||||
{/* Header - always visible, spans full width */}
|
||||
<header className="sticky top-0 z-30 h-16 flex items-center bg-background border-b border-border col-span-full">
|
||||
{header}
|
||||
</header>
|
||||
|
||||
{/* Sidebar - hidden on mobile, visible on tablet/desktop */}
|
||||
{sidebar && (
|
||||
<aside className="hidden md:block fixed top-16 bottom-0 w-[280px] lg:w-[320px] xl:w-[350px] overflow-y-auto z-20 bg-background border-r border-border">
|
||||
{sidebar}
|
||||
</aside>
|
||||
)}
|
||||
|
||||
{/* Main content area */}
|
||||
<main className="overflow-y-auto p-4 row-start-2 col-start-1 md:col-start-2 md:h-[calc(100vh-64px)]">
|
||||
{children}
|
||||
</main>
|
||||
|
||||
{/* Mobile drawer - rendered outside grid */}
|
||||
{drawer}
|
||||
|
||||
{/* Floating action button for mobile */}
|
||||
{floatingAction && (
|
||||
<div className="fixed bottom-6 right-6 z-50 md:hidden">
|
||||
{floatingAction}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
|
@ -1,47 +0,0 @@
|
|||
import * as React from "react"
|
||||
import * as ScrollAreaPrimitive from "@radix-ui/react-scroll-area"
|
||||
|
||||
import { cn } from "../../utils"
|
||||
|
||||
const ScrollArea = React.forwardRef<
|
||||
React.ElementRef<typeof ScrollAreaPrimitive.Root>,
|
||||
React.ComponentPropsWithoutRef<typeof ScrollAreaPrimitive.Root>
|
||||
>(({ className, children, ...props }, ref) => (
|
||||
<ScrollAreaPrimitive.Root
|
||||
ref={ref}
|
||||
className={cn("relative overflow-hidden", className)}
|
||||
{...props}
|
||||
>
|
||||
<ScrollAreaPrimitive.Viewport className="h-full w-full rounded-[inherit]">
|
||||
{children}
|
||||
</ScrollAreaPrimitive.Viewport>
|
||||
<ScrollBar />
|
||||
<ScrollBar orientation="horizontal" />
|
||||
<ScrollAreaPrimitive.Corner />
|
||||
</ScrollAreaPrimitive.Root>
|
||||
))
|
||||
ScrollArea.displayName = ScrollAreaPrimitive.Root.displayName
|
||||
|
||||
const ScrollBar = React.forwardRef<
|
||||
React.ElementRef<typeof ScrollAreaPrimitive.ScrollAreaScrollbar>,
|
||||
React.ComponentPropsWithoutRef<typeof ScrollAreaPrimitive.ScrollAreaScrollbar>
|
||||
>(({ className, orientation = "vertical", ...props }, ref) => (
|
||||
<ScrollAreaPrimitive.ScrollAreaScrollbar
|
||||
ref={ref}
|
||||
orientation={orientation}
|
||||
className={cn(
|
||||
"flex touch-none select-none transition-colors",
|
||||
orientation === "vertical" &&
|
||||
"h-full w-2.5 border-l border-l-transparent p-[1px]",
|
||||
orientation === "horizontal" &&
|
||||
"h-2.5 border-t border-t-transparent p-[1px]",
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
>
|
||||
<ScrollAreaPrimitive.ScrollAreaThumb className="relative flex-1 rounded-full bg-border" />
|
||||
</ScrollAreaPrimitive.ScrollAreaScrollbar>
|
||||
))
|
||||
ScrollBar.displayName = ScrollAreaPrimitive.ScrollAreaScrollbar.displayName
|
||||
|
||||
export { ScrollArea, ScrollBar }
|
||||
|
|
@ -1,134 +0,0 @@
|
|||
import * as React from "react"
|
||||
import * as SheetPrimitive from "@radix-ui/react-dialog"
|
||||
import { cva, type VariantProps } from "class-variance-authority"
|
||||
import { X } from "lucide-react"
|
||||
|
||||
import { cn } from "../../utils"
|
||||
|
||||
const Sheet = SheetPrimitive.Root
|
||||
|
||||
const SheetTrigger = SheetPrimitive.Trigger
|
||||
|
||||
const SheetClose = SheetPrimitive.Close
|
||||
|
||||
const SheetPortal = SheetPrimitive.Portal
|
||||
|
||||
const SheetOverlay = React.forwardRef<
|
||||
React.ElementRef<typeof SheetPrimitive.Overlay>,
|
||||
React.ComponentPropsWithoutRef<typeof SheetPrimitive.Overlay>
|
||||
>(({ className, ...props }, ref) => (
|
||||
<SheetPrimitive.Overlay
|
||||
className={cn(
|
||||
"fixed inset-0 z-70 bg-background/80 backdrop-blur-sm data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0",
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
ref={ref}
|
||||
/>
|
||||
))
|
||||
SheetOverlay.displayName = SheetPrimitive.Overlay.displayName
|
||||
|
||||
const sheetVariants = cva(
|
||||
"fixed z-70 gap-4 bg-background p-6 shadow-lg transition ease-in-out data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:duration-300 data-[state=open]:duration-500",
|
||||
{
|
||||
variants: {
|
||||
side: {
|
||||
top: "inset-x-0 top-0 border-b data-[state=closed]:slide-out-to-top data-[state=open]:slide-in-from-top",
|
||||
right: "inset-y-0 right-0 h-full w-3/4 border-l data-[state=closed]:slide-out-to-right data-[state=open]:slide-in-from-right sm:max-w-sm",
|
||||
bottom: "inset-x-0 bottom-0 border-t data-[state=closed]:slide-out-to-bottom data-[state=open]:slide-in-from-bottom",
|
||||
left: "inset-y-0 left-0 h-full w-full border-r data-[state=closed]:slide-out-to-left data-[state=open]:slide-in-from-left sm:max-w-sm",
|
||||
},
|
||||
},
|
||||
defaultVariants: {
|
||||
side: "right",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
interface SheetContentProps
|
||||
extends React.ComponentPropsWithoutRef<typeof SheetPrimitive.Content>,
|
||||
VariantProps<typeof sheetVariants> {}
|
||||
|
||||
const SheetContent = React.forwardRef<
|
||||
React.ElementRef<typeof SheetPrimitive.Content>,
|
||||
SheetContentProps
|
||||
>(({ side = "right", className, children, ...props }, ref) => (
|
||||
<SheetPortal>
|
||||
<SheetOverlay />
|
||||
<SheetPrimitive.Content
|
||||
ref={ref}
|
||||
className={cn(sheetVariants({ side }), className)}
|
||||
{...props}
|
||||
>
|
||||
{children}
|
||||
<SheetPrimitive.Close className="absolute right-4 top-4 rounded-sm opacity-70 ring-offset-background transition-opacity hover:opacity-100 focus:outline-none focus:ring-2 focus:ring-ring focus:ring-offset-2 disabled:pointer-events-none data-[state=open]:bg-secondary">
|
||||
<X className="h-4 w-4" />
|
||||
<span className="sr-only">Close</span>
|
||||
</SheetPrimitive.Close>
|
||||
</SheetPrimitive.Content>
|
||||
</SheetPortal>
|
||||
))
|
||||
SheetContent.displayName = SheetPrimitive.Content.displayName
|
||||
|
||||
const SheetHeader = ({
|
||||
className,
|
||||
...props
|
||||
}: React.HTMLAttributes<HTMLDivElement>) => (
|
||||
<div
|
||||
className={cn(
|
||||
"flex flex-col space-y-2 text-center sm:text-left",
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
)
|
||||
SheetHeader.displayName = "SheetHeader"
|
||||
|
||||
const SheetFooter = ({
|
||||
className,
|
||||
...props
|
||||
}: React.HTMLAttributes<HTMLDivElement>) => (
|
||||
<div
|
||||
className={cn(
|
||||
"flex flex-col-reverse sm:flex-row sm:justify-end sm:space-x-2",
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
)
|
||||
SheetFooter.displayName = "SheetFooter"
|
||||
|
||||
const SheetTitle = React.forwardRef<
|
||||
React.ElementRef<typeof SheetPrimitive.Title>,
|
||||
React.ComponentPropsWithoutRef<typeof SheetPrimitive.Title>
|
||||
>(({ className, ...props }, ref) => (
|
||||
<SheetPrimitive.Title
|
||||
ref={ref}
|
||||
className={cn("text-lg font-semibold text-foreground", className)}
|
||||
{...props}
|
||||
/>
|
||||
))
|
||||
SheetTitle.displayName = SheetPrimitive.Title.displayName
|
||||
|
||||
const SheetDescription = React.forwardRef<
|
||||
React.ElementRef<typeof SheetPrimitive.Description>,
|
||||
React.ComponentPropsWithoutRef<typeof SheetPrimitive.Description>
|
||||
>(({ className, ...props }, ref) => (
|
||||
<SheetPrimitive.Description
|
||||
ref={ref}
|
||||
className={cn("text-sm text-muted-foreground", className)}
|
||||
{...props}
|
||||
/>
|
||||
))
|
||||
SheetDescription.displayName = SheetPrimitive.Description.displayName
|
||||
|
||||
export {
|
||||
Sheet,
|
||||
SheetTrigger,
|
||||
SheetClose,
|
||||
SheetContent,
|
||||
SheetHeader,
|
||||
SheetFooter,
|
||||
SheetTitle,
|
||||
SheetDescription,
|
||||
}
|
||||
|
|
@ -1,27 +0,0 @@
|
|||
import * as React from "react";
|
||||
import * as SwitchPrimitives from "@radix-ui/react-switch";
|
||||
|
||||
import { cn } from "../../utils";
|
||||
|
||||
const Switch = React.forwardRef<
|
||||
React.ElementRef<typeof SwitchPrimitives.Root>,
|
||||
React.ComponentPropsWithoutRef<typeof SwitchPrimitives.Root>
|
||||
>(({ className, ...props }, ref) => (
|
||||
<SwitchPrimitives.Root
|
||||
className={cn(
|
||||
"peer inline-flex h-5 w-9 shrink-0 cursor-pointer items-center rounded-full border-2 border-transparent shadow-sm transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 focus-visible:ring-offset-background disabled:cursor-not-allowed disabled:opacity-50 data-[state=checked]:bg-primary data-[state=unchecked]:bg-input",
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
ref={ref}
|
||||
>
|
||||
<SwitchPrimitives.Thumb
|
||||
className={cn(
|
||||
"pointer-events-none block h-4 w-4 rounded-full bg-background shadow-lg ring-0 transition-transform data-[state=checked]:translate-x-4 data-[state=unchecked]:translate-x-0"
|
||||
)}
|
||||
/>
|
||||
</SwitchPrimitives.Root>
|
||||
));
|
||||
Switch.displayName = SwitchPrimitives.Root.displayName;
|
||||
|
||||
export { Switch };
|
||||
|
|
@ -1,33 +0,0 @@
|
|||
// Entry point for @ra-aid/common package
|
||||
import './styles/global.css';
|
||||
|
||||
// Export types first to avoid circular references
|
||||
export * from './utils/types';
|
||||
|
||||
// Export utility functions
|
||||
export * from './utils';
|
||||
|
||||
// Export UI components
|
||||
export * from './components/ui';
|
||||
|
||||
// Export timeline components
|
||||
export * from './components/TimelineStep';
|
||||
export * from './components/TimelineFeed';
|
||||
|
||||
// Export session navigation components
|
||||
export * from './components/SessionDrawer';
|
||||
export * from './components/SessionSidebar';
|
||||
|
||||
// Export main screens
|
||||
export * from './components/DefaultAgentScreen';
|
||||
|
||||
// Export the hello function (temporary example)
|
||||
export const hello = (): void => {
|
||||
console.log("Hello from @ra-aid/common");
|
||||
};
|
||||
|
||||
// Directly export sample data functions
|
||||
export {
|
||||
getSampleAgentSteps,
|
||||
getSampleAgentSessions
|
||||
} from './utils/sample-data';
|
||||
|
|
@ -1,80 +0,0 @@
|
|||
@tailwind base;
|
||||
@tailwind components;
|
||||
@tailwind utilities;
|
||||
|
||||
@layer base {
|
||||
:root {
|
||||
--background: 0 0% 100%;
|
||||
--foreground: 222.2 47.4% 11.2%;
|
||||
|
||||
--muted: 210 40% 96.1%;
|
||||
--muted-foreground: 215.4 16.3% 46.9%;
|
||||
|
||||
--popover: 0 0% 100%;
|
||||
--popover-foreground: 222.2 47.4% 11.2%;
|
||||
|
||||
--card: 0 0% 100%;
|
||||
--card-foreground: 222.2 47.4% 11.2%;
|
||||
|
||||
--border: 214.3 31.8% 91.4%;
|
||||
--input: 214.3 31.8% 91.4%;
|
||||
|
||||
--primary: 222.2 47.4% 11.2%;
|
||||
--primary-foreground: 210 40% 98%;
|
||||
|
||||
--secondary: 210 40% 96.1%;
|
||||
--secondary-foreground: 222.2 47.4% 11.2%;
|
||||
|
||||
--accent: 210 40% 96.1%;
|
||||
--accent-foreground: 222.2 47.4% 11.2%;
|
||||
|
||||
--destructive: 0 100% 50%;
|
||||
--destructive-foreground: 210 40% 98%;
|
||||
|
||||
--ring: 215 20.2% 65.1%;
|
||||
|
||||
--radius: 0.5rem;
|
||||
}
|
||||
|
||||
.dark {
|
||||
--background: 240 10% 3.9%; /* zinc-950 */
|
||||
--foreground: 240 5% 96%; /* zinc-50 */
|
||||
|
||||
--card: 240 10% 3.9%; /* zinc-950 */
|
||||
--card-foreground: 240 5% 96%; /* zinc-50 */
|
||||
|
||||
--popover: 240 10% 3.9%; /* zinc-950 */
|
||||
--popover-foreground: 240 5% 96%; /* zinc-50 */
|
||||
|
||||
--primary: 240 5% 96%; /* zinc-50 */
|
||||
--primary-foreground: 240 6% 10%; /* zinc-900 */
|
||||
|
||||
--secondary: 240 4% 16%; /* zinc-800 */
|
||||
--secondary-foreground: 240 5% 96%; /* zinc-50 */
|
||||
|
||||
--muted: 240 4% 16%; /* zinc-800 */
|
||||
--muted-foreground: 240 5% 65%; /* zinc-400 */
|
||||
|
||||
--accent: 240 4% 16%; /* zinc-800 */
|
||||
--accent-foreground: 240 5% 96%; /* zinc-50 */
|
||||
|
||||
--destructive: 0 63% 31%; /* red-900 */
|
||||
--destructive-foreground: 240 5% 96%; /* zinc-50 */
|
||||
|
||||
--border: 240 4% 16%; /* zinc-800 */
|
||||
--input: 240 4% 16%; /* zinc-800 */
|
||||
--ring: 240 5% 84%; /* zinc-300 */
|
||||
|
||||
--radius: 0.5rem;
|
||||
}
|
||||
}
|
||||
|
||||
@layer base {
|
||||
* {
|
||||
@apply border-border;
|
||||
}
|
||||
body {
|
||||
@apply bg-background text-foreground;
|
||||
font-feature-settings: "rlig" 1, "calt" 1;
|
||||
}
|
||||
}
|
||||
|
|
@ -1,24 +0,0 @@
|
|||
declare module '*.png' {
|
||||
const content: string;
|
||||
export default content;
|
||||
}
|
||||
|
||||
declare module '*.gif' {
|
||||
const content: string;
|
||||
export default content;
|
||||
}
|
||||
|
||||
declare module '*.jpg' {
|
||||
const content: string;
|
||||
export default content;
|
||||
}
|
||||
|
||||
declare module '*.jpeg' {
|
||||
const content: string;
|
||||
export default content;
|
||||
}
|
||||
|
||||
declare module '*.svg' {
|
||||
const content: string;
|
||||
export default content;
|
||||
}
|
||||
|
|
@ -1,13 +0,0 @@
|
|||
import { clsx, type ClassValue } from "clsx";
|
||||
import { twMerge } from "tailwind-merge";
|
||||
|
||||
/**
|
||||
* Merges class names with Tailwind CSS classes
|
||||
* Combines clsx for conditional logic and tailwind-merge for handling conflicting tailwind classes
|
||||
*/
|
||||
export function cn(...inputs: ClassValue[]) {
|
||||
return twMerge(clsx(inputs));
|
||||
}
|
||||
|
||||
// Re-export everything from utils directory
|
||||
export * from './utils';
|
||||
|
|
@ -1,13 +0,0 @@
|
|||
import { clsx, type ClassValue } from "clsx";
|
||||
import { twMerge } from "tailwind-merge";
|
||||
|
||||
/**
|
||||
* Merges class names with Tailwind CSS classes
|
||||
* Combines clsx for conditional logic and tailwind-merge for handling conflicting tailwind classes
|
||||
*/
|
||||
export function cn(...inputs: ClassValue[]) {
|
||||
return twMerge(clsx(inputs));
|
||||
}
|
||||
|
||||
// Note: Sample data functions and types are now exported directly from the root index.ts
|
||||
// to avoid circular references
|
||||
|
|
@ -1,164 +0,0 @@
|
|||
/**
|
||||
* Sample data utility for agent UI components demonstration
|
||||
*/
|
||||
|
||||
import { AgentStep, AgentSession } from './types';
|
||||
|
||||
/**
|
||||
* Returns an array of sample agent steps
|
||||
*/
|
||||
export function getSampleAgentSteps(): AgentStep[] {
|
||||
return [
|
||||
{
|
||||
id: "step-1",
|
||||
timestamp: new Date(Date.now() - 30 * 60000), // 30 minutes ago
|
||||
status: 'completed',
|
||||
type: 'planning',
|
||||
title: 'Initial Planning',
|
||||
content: 'I need to analyze the codebase structure to understand the existing components and their relationships.',
|
||||
duration: 5200
|
||||
},
|
||||
{
|
||||
id: "step-2",
|
||||
timestamp: new Date(Date.now() - 25 * 60000), // 25 minutes ago
|
||||
status: 'completed',
|
||||
type: 'tool-execution',
|
||||
title: 'List Directory Structure',
|
||||
content: 'Executing: list_directory_tree(path="src/", max_depth=2)\n\n📁 /project/src/\n├── 📁 components/\n│ ├── 📁 ui/\n│ └── App.tsx\n├── 📁 utils/\n└── index.tsx',
|
||||
duration: 1800
|
||||
},
|
||||
{
|
||||
id: "step-3",
|
||||
timestamp: new Date(Date.now() - 20 * 60000), // 20 minutes ago
|
||||
status: 'completed',
|
||||
type: 'thinking',
|
||||
title: 'Component Analysis',
|
||||
content: 'Based on the directory structure, I see that the UI components are organized in a dedicated folder. I should examine the existing component patterns before implementing new ones.',
|
||||
duration: 3500
|
||||
},
|
||||
{
|
||||
id: "step-4",
|
||||
timestamp: new Date(Date.now() - 15 * 60000), // 15 minutes ago
|
||||
status: 'completed',
|
||||
type: 'tool-execution',
|
||||
title: 'Read Component Code',
|
||||
content: 'Executing: read_file_tool(filepath="src/components/ui/Button.tsx")\n\n```tsx\nimport { cn } from "../../utils";\n\nexport interface ButtonProps {\n // Component props...\n}\n\nexport function Button({ children, ...props }: ButtonProps) {\n // Component implementation...\n}\n```',
|
||||
duration: 2100
|
||||
},
|
||||
{
|
||||
id: "step-5",
|
||||
timestamp: new Date(Date.now() - 10 * 60000), // 10 minutes ago
|
||||
status: 'completed',
|
||||
type: 'implementation',
|
||||
title: 'Creating NavBar Component',
|
||||
content: 'I\'m creating a NavBar component following the design system patterns:\n\n```tsx\nimport { cn } from "../../utils";\n\nexport interface NavBarProps {\n // New component props...\n}\n\nexport function NavBar({ ...props }: NavBarProps) {\n // New component implementation...\n}\n```',
|
||||
duration: 6800
|
||||
},
|
||||
{
|
||||
id: "step-6",
|
||||
timestamp: new Date(Date.now() - 5 * 60000), // 5 minutes ago
|
||||
status: 'in-progress',
|
||||
type: 'implementation',
|
||||
title: 'Styling Timeline Component',
|
||||
content: 'Currently working on styling the Timeline component to match the design system:\n\n```tsx\n// Work in progress...\nexport function Timeline({ steps, ...props }: TimelineProps) {\n // Current implementation...\n}\n```',
|
||||
},
|
||||
{
|
||||
id: "step-7",
|
||||
timestamp: new Date(Date.now() - 2 * 60000), // 2 minutes ago
|
||||
status: 'error',
|
||||
type: 'tool-execution',
|
||||
title: 'Running Tests',
|
||||
content: 'Error executing: run_shell_command(command="npm test")\n\nTest failed: TypeError: Cannot read property \'steps\' of undefined',
|
||||
duration: 3200
|
||||
},
|
||||
{
|
||||
id: "step-8",
|
||||
timestamp: new Date(), // Now
|
||||
status: 'pending',
|
||||
type: 'planning',
|
||||
title: 'Next Steps',
|
||||
content: 'Need to plan the implementation of the SessionDrawer component...',
|
||||
}
|
||||
];
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an array of sample agent sessions
|
||||
*/
|
||||
export function getSampleAgentSessions(): AgentSession[] {
|
||||
const steps = getSampleAgentSteps();
|
||||
|
||||
return [
|
||||
{
|
||||
id: "session-1",
|
||||
name: "UI Component Implementation",
|
||||
created: new Date(Date.now() - 35 * 60000), // 35 minutes ago
|
||||
updated: new Date(), // Now
|
||||
status: 'active',
|
||||
steps: steps
|
||||
},
|
||||
{
|
||||
id: "session-2",
|
||||
name: "API Integration",
|
||||
created: new Date(Date.now() - 2 * 3600000), // 2 hours ago
|
||||
updated: new Date(Date.now() - 30 * 60000), // 30 minutes ago
|
||||
status: 'completed',
|
||||
steps: [
|
||||
{
|
||||
id: "other-step-1",
|
||||
timestamp: new Date(Date.now() - 2 * 3600000), // 2 hours ago
|
||||
status: 'completed',
|
||||
type: 'planning',
|
||||
title: 'API Integration Planning',
|
||||
content: 'Planning the integration with the backend API...',
|
||||
duration: 4500
|
||||
},
|
||||
{
|
||||
id: "other-step-2",
|
||||
timestamp: new Date(Date.now() - 1.5 * 3600000), // 1.5 hours ago
|
||||
status: 'completed',
|
||||
type: 'implementation',
|
||||
title: 'Implementing API Client',
|
||||
content: 'Creating API client with fetch utilities...',
|
||||
duration: 7200
|
||||
},
|
||||
{
|
||||
id: "other-step-3",
|
||||
timestamp: new Date(Date.now() - 1 * 3600000), // 1 hour ago
|
||||
status: 'completed',
|
||||
type: 'tool-execution',
|
||||
title: 'Testing API Endpoints',
|
||||
content: 'Running tests against API endpoints...',
|
||||
duration: 5000
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
id: "session-3",
|
||||
name: "Bug Fixes",
|
||||
created: new Date(Date.now() - 5 * 3600000), // 5 hours ago
|
||||
updated: new Date(Date.now() - 4 * 3600000), // 4 hours ago
|
||||
status: 'error',
|
||||
steps: [
|
||||
{
|
||||
id: "bug-step-1",
|
||||
timestamp: new Date(Date.now() - 5 * 3600000), // 5 hours ago
|
||||
status: 'completed',
|
||||
type: 'planning',
|
||||
title: 'Bug Analysis',
|
||||
content: 'Analyzing reported bugs from issue tracker...',
|
||||
duration: 3600
|
||||
},
|
||||
{
|
||||
id: "bug-step-2",
|
||||
timestamp: new Date(Date.now() - 4.5 * 3600000), // 4.5 hours ago
|
||||
status: 'error',
|
||||
type: 'implementation',
|
||||
title: 'Fixing Authentication Bug',
|
||||
content: 'Error: Unable to resolve dependency conflict with auth package',
|
||||
duration: 2500
|
||||
}
|
||||
]
|
||||
}
|
||||
];
|
||||
}
|
||||
|
|
@ -1,28 +0,0 @@
|
|||
/**
|
||||
* Common types for agent UI components
|
||||
*/
|
||||
|
||||
/**
|
||||
* Represents a single step in the agent process
|
||||
*/
|
||||
export interface AgentStep {
|
||||
id: string;
|
||||
timestamp: Date;
|
||||
status: 'completed' | 'in-progress' | 'error' | 'pending';
|
||||
type: 'tool-execution' | 'thinking' | 'planning' | 'implementation' | 'user-input';
|
||||
title: string;
|
||||
content: string;
|
||||
duration?: number; // in milliseconds
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents a session with multiple steps
|
||||
*/
|
||||
export interface AgentSession {
|
||||
id: string;
|
||||
name: string;
|
||||
created: Date;
|
||||
updated: Date;
|
||||
status: 'active' | 'completed' | 'error';
|
||||
steps: AgentStep[];
|
||||
}
|
||||
|
|
@ -1,18 +0,0 @@
|
|||
/** @type {import('tailwindcss').Config} */
|
||||
module.exports = {
|
||||
presets: [require('./tailwind.preset')],
|
||||
content: [
|
||||
'./src/**/*.{js,jsx,ts,tsx}',
|
||||
],
|
||||
safelist: [
|
||||
'dark',
|
||||
{
|
||||
pattern: /^dark:/,
|
||||
variants: ['hover', 'focus', 'active']
|
||||
}
|
||||
],
|
||||
theme: {
|
||||
extend: {},
|
||||
},
|
||||
plugins: [],
|
||||
}
|
||||
|
|
@ -1,70 +0,0 @@
|
|||
/** @type {import('tailwindcss').Config} */
|
||||
module.exports = {
|
||||
darkMode: ["class"],
|
||||
theme: {
|
||||
container: {
|
||||
center: true,
|
||||
padding: "2rem",
|
||||
screens: {
|
||||
"2xl": "1400px",
|
||||
},
|
||||
},
|
||||
extend: {
|
||||
colors: {
|
||||
border: "hsl(var(--border))",
|
||||
input: "hsl(var(--input))",
|
||||
ring: "hsl(var(--ring))",
|
||||
background: "hsl(var(--background))",
|
||||
foreground: "hsl(var(--foreground))",
|
||||
primary: {
|
||||
DEFAULT: "hsl(var(--primary))",
|
||||
foreground: "hsl(var(--primary-foreground))",
|
||||
},
|
||||
secondary: {
|
||||
DEFAULT: "hsl(var(--secondary))",
|
||||
foreground: "hsl(var(--secondary-foreground))",
|
||||
},
|
||||
destructive: {
|
||||
DEFAULT: "hsl(var(--destructive))",
|
||||
foreground: "hsl(var(--destructive-foreground))",
|
||||
},
|
||||
muted: {
|
||||
DEFAULT: "hsl(var(--muted))",
|
||||
foreground: "hsl(var(--muted-foreground))",
|
||||
},
|
||||
accent: {
|
||||
DEFAULT: "hsl(var(--accent))",
|
||||
foreground: "hsl(var(--accent-foreground))",
|
||||
},
|
||||
popover: {
|
||||
DEFAULT: "hsl(var(--popover))",
|
||||
foreground: "hsl(var(--popover-foreground))",
|
||||
},
|
||||
card: {
|
||||
DEFAULT: "hsl(var(--card))",
|
||||
foreground: "hsl(var(--card-foreground))",
|
||||
},
|
||||
},
|
||||
borderRadius: {
|
||||
lg: "var(--radius)",
|
||||
md: "calc(var(--radius) - 2px)",
|
||||
sm: "calc(var(--radius) - 4px)",
|
||||
},
|
||||
keyframes: {
|
||||
"accordion-down": {
|
||||
from: { height: "0" },
|
||||
to: { height: "var(--radix-accordion-content-height)" },
|
||||
},
|
||||
"accordion-up": {
|
||||
from: { height: "var(--radix-accordion-content-height)" },
|
||||
to: { height: "0" },
|
||||
},
|
||||
},
|
||||
animation: {
|
||||
"accordion-down": "accordion-down 0.2s ease-out",
|
||||
"accordion-up": "accordion-up 0.2s ease-out",
|
||||
},
|
||||
},
|
||||
},
|
||||
plugins: [require("tailwindcss-animate")],
|
||||
}
|
||||
|
|
@ -1,17 +0,0 @@
|
|||
{
|
||||
"compilerOptions": {
|
||||
"target": "ES6",
|
||||
"module": "ESNext",
|
||||
"moduleResolution": "node",
|
||||
"declaration": true,
|
||||
"jsx": "react",
|
||||
"strict": true,
|
||||
"esModuleInterop": true,
|
||||
"skipLibCheck": true,
|
||||
"forceConsistentCasingInFileNames": true,
|
||||
"outDir": "dist",
|
||||
"rootDir": "src",
|
||||
"lib": ["DOM", "DOM.Iterable", "ESNext", "ES2016"]
|
||||
},
|
||||
"include": ["src"]
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,13 +0,0 @@
|
|||
{
|
||||
"name": "frontend-monorepo",
|
||||
"private": true,
|
||||
"workspaces": [
|
||||
"common",
|
||||
"web",
|
||||
"vsc"
|
||||
],
|
||||
"scripts": {
|
||||
"install-all": "npm install",
|
||||
"dev:web": "npm --workspace @ra-aid/web run dev"
|
||||
}
|
||||
}
|
||||
|
|
@ -1,140 +0,0 @@
|
|||
"use strict";
|
||||
var __create = Object.create;
|
||||
var __defProp = Object.defineProperty;
|
||||
var __getOwnPropDesc = Object.getOwnPropertyDescriptor;
|
||||
var __getOwnPropNames = Object.getOwnPropertyNames;
|
||||
var __getProtoOf = Object.getPrototypeOf;
|
||||
var __hasOwnProp = Object.prototype.hasOwnProperty;
|
||||
var __export = (target, all) => {
|
||||
for (var name in all)
|
||||
__defProp(target, name, { get: all[name], enumerable: true });
|
||||
};
|
||||
var __copyProps = (to, from, except, desc) => {
|
||||
if (from && typeof from === "object" || typeof from === "function") {
|
||||
for (let key of __getOwnPropNames(from))
|
||||
if (!__hasOwnProp.call(to, key) && key !== except)
|
||||
__defProp(to, key, { get: () => from[key], enumerable: !(desc = __getOwnPropDesc(from, key)) || desc.enumerable });
|
||||
}
|
||||
return to;
|
||||
};
|
||||
var __toESM = (mod, isNodeMode, target) => (target = mod != null ? __create(__getProtoOf(mod)) : {}, __copyProps(
|
||||
// If the importer is in node compatibility mode or this is not an ESM
|
||||
// file that has been converted to a CommonJS file using a Babel-
|
||||
// compatible transform (i.e. "__esModule" has not been set), then set
|
||||
// "default" to the CommonJS "module.exports" for node compatibility.
|
||||
isNodeMode || !mod || !mod.__esModule ? __defProp(target, "default", { value: mod, enumerable: true }) : target,
|
||||
mod
|
||||
));
|
||||
var __toCommonJS = (mod) => __copyProps(__defProp({}, "__esModule", { value: true }), mod);
|
||||
|
||||
// src/extension.ts
|
||||
var extension_exports = {};
|
||||
__export(extension_exports, {
|
||||
activate: () => activate,
|
||||
deactivate: () => deactivate
|
||||
});
|
||||
module.exports = __toCommonJS(extension_exports);
|
||||
var vscode = __toESM(require("vscode"));
|
||||
var RAWebviewViewProvider = class {
|
||||
constructor(_extensionUri) {
|
||||
this._extensionUri = _extensionUri;
|
||||
}
|
||||
/**
|
||||
* Called when a view is first created to initialize the webview
|
||||
*/
|
||||
resolveWebviewView(webviewView, context, _token) {
|
||||
webviewView.webview.options = {
|
||||
// Enable JavaScript in the webview
|
||||
enableScripts: true,
|
||||
// Restrict the webview to only load resources from the extension's directory
|
||||
localResourceRoots: [this._extensionUri]
|
||||
};
|
||||
webviewView.webview.html = this._getHtmlForWebview(webviewView.webview);
|
||||
}
|
||||
/**
|
||||
* Creates HTML content for the webview with proper security policies
|
||||
*/
|
||||
_getHtmlForWebview(webview) {
|
||||
const logoUri = webview.asWebviewUri(vscode.Uri.joinPath(this._extensionUri, "assets", "RA.png"));
|
||||
const nonce = getNonce();
|
||||
return `<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta http-equiv="Content-Security-Policy" content="default-src 'none'; img-src ${webview.cspSource} https:; style-src ${webview.cspSource} 'unsafe-inline'; script-src 'nonce-${nonce}';">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>RA.Aid</title>
|
||||
<style>
|
||||
body {
|
||||
padding: 0;
|
||||
color: var(--vscode-foreground);
|
||||
font-size: var(--vscode-font-size);
|
||||
font-weight: var(--vscode-font-weight);
|
||||
font-family: var(--vscode-font-family);
|
||||
background-color: var(--vscode-editor-background);
|
||||
}
|
||||
.container {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
padding: 20px;
|
||||
text-align: center;
|
||||
}
|
||||
.logo {
|
||||
width: 100px;
|
||||
height: 100px;
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
h1 {
|
||||
color: var(--vscode-editor-foreground);
|
||||
font-size: 1.3em;
|
||||
margin-bottom: 15px;
|
||||
}
|
||||
p {
|
||||
color: var(--vscode-foreground);
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<img src="${logoUri}" alt="RA.Aid Logo" class="logo">
|
||||
<h1>RA.Aid</h1>
|
||||
<p>Your research and development assistant.</p>
|
||||
<p>More features coming soon!</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>`;
|
||||
}
|
||||
};
|
||||
function getNonce() {
|
||||
let text = "";
|
||||
const possible = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
|
||||
for (let i = 0; i < 32; i++) {
|
||||
text += possible.charAt(Math.floor(Math.random() * possible.length));
|
||||
}
|
||||
return text;
|
||||
}
|
||||
function activate(context) {
|
||||
console.log('Congratulations, your extension "ra-aid" is now active!');
|
||||
const provider = new RAWebviewViewProvider(context.extensionUri);
|
||||
const viewRegistration = vscode.window.registerWebviewViewProvider(
|
||||
"ra-aid.view",
|
||||
// Must match the view id in package.json
|
||||
provider
|
||||
);
|
||||
context.subscriptions.push(viewRegistration);
|
||||
const disposable = vscode.commands.registerCommand("ra-aid.helloWorld", () => {
|
||||
vscode.window.showInformationMessage("Hello World from RA.Aid!");
|
||||
});
|
||||
context.subscriptions.push(disposable);
|
||||
}
|
||||
function deactivate() {
|
||||
}
|
||||
// Annotate the CommonJS export names for ESM import in node:
|
||||
0 && (module.exports = {
|
||||
activate,
|
||||
deactivate
|
||||
});
|
||||
//# sourceMappingURL=extension.js.map
|
||||
|
|
@ -1,6 +0,0 @@
|
|||
{
|
||||
"version": 3,
|
||||
"sources": ["../src/extension.ts"],
|
||||
"mappings": ";;;;;;;;;;;;;;;;;;;;;;;;;;;;;;AAAA;AAAA;AAAA;AAAA;AAAA;AAAA;AACA,aAAwB;AAKxB,IAAM,wBAAN,MAAkE;AAAA,EAChE,YAA6B,eAA2B;AAA3B;AAAA,EAA4B;AAAA;AAAA;AAAA;AAAA,EAKlD,mBACL,aACA,SACA,QACA;AAEA,gBAAY,QAAQ,UAAU;AAAA;AAAA,MAE5B,eAAe;AAAA;AAAA,MAEf,oBAAoB,CAAC,KAAK,aAAa;AAAA,IACzC;AAGA,gBAAY,QAAQ,OAAO,KAAK,mBAAmB,YAAY,OAAO;AAAA,EACxE;AAAA;AAAA;AAAA;AAAA,EAKQ,mBAAmB,SAAiC;AAE1D,UAAM,UAAU,QAAQ,aAAoB,WAAI,SAAS,KAAK,eAAe,UAAU,QAAQ,CAAC;AAMhG,UAAM,QAAQ,SAAS;AAEvB,WAAO;AAAA;AAAA;AAAA;AAAA,0FAI+E,QAAQ,SAAS,sBAAsB,QAAQ,SAAS,uCAAuC,KAAK;AAAA;AAAA;AAAA;AAAA;AAAA;AAAA;AAAA;AAAA;AAAA;AAAA;AAAA;AAAA;AAAA;AAAA;AAAA;AAAA;AAAA;AAAA;AAAA;AAAA;AAAA;AAAA;AAAA;AAAA;AAAA;AAAA;AAAA;AAAA;AAAA;AAAA;AAAA;AAAA;AAAA;AAAA;AAAA;AAAA;AAAA;AAAA,sBAsCxK,OAAO;AAAA;AAAA;AAAA;AAAA;AAAA;AAAA;AAAA,EAO3B;AACF;AAKA,SAAS,WAAW;AAClB,MAAI,OAAO;AACX,QAAM,WAAW;AACjB,WAAS,IAAI,GAAG,IAAI,IAAI,KAAK;AAC3B,YAAQ,SAAS,OAAO,KAAK,MAAM,KAAK,OAAO,IAAI,SAAS,MAAM,CAAC;AAAA,EACrE;AACA,SAAO;AACT;AAGO,SAAS,SAAS,SAAkC;AAEzD,UAAQ,IAAI,yDAAyD;AAGrE,QAAM,WAAW,IAAI,sBAAsB,QAAQ,YAAY;AAC/D,QAAM,mBAA0B,cAAO;AAAA,IACrC;AAAA;AAAA,IACA;AAAA,EACF;AACA,UAAQ,cAAc,KAAK,gBAAgB;AAK3C,QAAM,aAAoB,gBAAS,gBAAgB,qBAAqB,MAAM;AAG5E,IAAO,cAAO,uBAAuB,0BAA0B;AAAA,EACjE,CAAC;AAED,UAAQ,cAAc,KAAK,UAAU;AACvC;AAGO,SAAS,aAAa;AAAC;",
|
||||
"names": []
|
||||
}
|
||||
|
|
@ -1,16 +0,0 @@
|
|||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<meta name="description" content="Demo page showcasing shadcn/ui components from the common package" />
|
||||
<title>RA-Aid UI Components Demo</title>
|
||||
<link rel="preconnect" href="https://fonts.googleapis.com">
|
||||
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
||||
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap" rel="stylesheet">
|
||||
</head>
|
||||
<body>
|
||||
<div id="root"></div>
|
||||
<script type="module" src="/src/index.tsx"></script>
|
||||
</body>
|
||||
</html>
|
||||
|
|
@ -1,26 +0,0 @@
|
|||
{
|
||||
"name": "@ra-aid/web",
|
||||
"version": "1.0.0",
|
||||
"private": true,
|
||||
"main": "dist/index.js",
|
||||
"scripts": {
|
||||
"dev": "vite",
|
||||
"build": "vite build"
|
||||
},
|
||||
"dependencies": {
|
||||
"react": "^18.0.0",
|
||||
"react-dom": "^18.0.0",
|
||||
"@ra-aid/common": "1.0.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"vite": "^4.0.0",
|
||||
"@vitejs/plugin-react": "^3.0.0",
|
||||
"typescript": "^5.0.0",
|
||||
"tailwindcss": "^3.4.1",
|
||||
"postcss": "^8.4.35",
|
||||
"autoprefixer": "^10.4.17"
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"@tailwindcss/forms": "^0.5.7"
|
||||
}
|
||||
}
|
||||
|
|
@ -1,6 +0,0 @@
|
|||
module.exports = {
|
||||
plugins: {
|
||||
tailwindcss: {},
|
||||
autoprefixer: {},
|
||||
},
|
||||
}
|
||||
|
|
@ -1,19 +0,0 @@
|
|||
import React from 'react';
|
||||
import ReactDOM from 'react-dom/client';
|
||||
import { DefaultAgentScreen } from '@ra-aid/common';
|
||||
|
||||
/**
|
||||
* Main application entry point
|
||||
* Simply renders the DefaultAgentScreen component from the common package
|
||||
*/
|
||||
const App = () => {
|
||||
return <DefaultAgentScreen />;
|
||||
};
|
||||
|
||||
// Mount the app to the root element
|
||||
const root = ReactDOM.createRoot(document.getElementById('root')!);
|
||||
root.render(
|
||||
<React.StrictMode>
|
||||
<App />
|
||||
</React.StrictMode>
|
||||
);
|
||||
|
|
@ -1,14 +0,0 @@
|
|||
/** @type {import('tailwindcss').Config} */
|
||||
module.exports = {
|
||||
presets: [require('../common/tailwind.preset')],
|
||||
content: [
|
||||
'./src/**/*.{js,jsx,ts,tsx}',
|
||||
'../common/src/**/*.{js,jsx,ts,tsx}'
|
||||
],
|
||||
theme: {
|
||||
extend: {},
|
||||
},
|
||||
plugins: [
|
||||
require('@tailwindcss/forms')
|
||||
],
|
||||
}
|
||||
|
|
@ -1,15 +0,0 @@
|
|||
{
|
||||
"compilerOptions": {
|
||||
"target": "ES6",
|
||||
"module": "ESNext",
|
||||
"moduleResolution": "node",
|
||||
"jsx": "react-jsx",
|
||||
"strict": true,
|
||||
"esModuleInterop": true,
|
||||
"skipLibCheck": true,
|
||||
"forceConsistentCasingInFileNames": true,
|
||||
"outDir": "dist",
|
||||
"rootDir": "src"
|
||||
},
|
||||
"include": ["src"]
|
||||
}
|
||||
|
|
@ -1,40 +0,0 @@
|
|||
import { defineConfig } from 'vite';
|
||||
import react from '@vitejs/plugin-react';
|
||||
import path from 'path';
|
||||
import fs from 'fs';
|
||||
|
||||
// Get all component files from common package
|
||||
const commonSrcDir = path.resolve(__dirname, '../common/src');
|
||||
|
||||
export default defineConfig({
|
||||
plugins: [react()],
|
||||
resolve: {
|
||||
alias: {
|
||||
// Direct alias to the source directory
|
||||
'@ra-aid/common': path.resolve(__dirname, '../common/src')
|
||||
},
|
||||
preserveSymlinks: true
|
||||
},
|
||||
optimizeDeps: {
|
||||
// Exclude the common package from optimization so it can trigger hot reload
|
||||
exclude: ['@ra-aid/common']
|
||||
},
|
||||
server: {
|
||||
hmr: true,
|
||||
watch: {
|
||||
usePolling: true,
|
||||
interval: 100,
|
||||
// Make sure to explicitly NOT ignore the common package
|
||||
ignored: [
|
||||
'**/node_modules/**',
|
||||
'**/dist/**',
|
||||
'!**/common/src/**'
|
||||
]
|
||||
}
|
||||
},
|
||||
build: {
|
||||
commonjsOptions: {
|
||||
transformMixedEsModules: true
|
||||
}
|
||||
}
|
||||
});
|
||||
|
|
@ -50,7 +50,6 @@ dependencies = [
|
|||
"platformdirs>=3.17.9",
|
||||
"requests",
|
||||
"packaging",
|
||||
"prompt-toolkit"
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
|
|
|||
|
|
@ -5,8 +5,24 @@ import sys
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
# Add litellm import
|
||||
import litellm
|
||||
|
||||
# Configure litellm to suppress debug logs
|
||||
os.environ["LITELLM_LOG"] = "ERROR"
|
||||
litellm.suppress_debug_info = True
|
||||
litellm.set_verbose = False
|
||||
|
||||
# Explicitly configure LiteLLM's loggers
|
||||
for logger_name in ["litellm", "LiteLLM"]:
|
||||
litellm_logger = logging.getLogger(logger_name)
|
||||
litellm_logger.setLevel(logging.WARNING)
|
||||
litellm_logger.propagate = True
|
||||
|
||||
# Use litellm's internal method to disable debugging
|
||||
if hasattr(litellm, "_logging") and hasattr(litellm._logging, "_disable_debugging"):
|
||||
litellm._logging._disable_debugging()
|
||||
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
|
|
@ -83,148 +99,13 @@ from ra_aid.tools.human import ask_human
|
|||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Configure litellm to suppress debug logs
|
||||
os.environ["LITELLM_LOG"] = "ERROR"
|
||||
litellm.suppress_debug_info = True
|
||||
litellm.set_verbose = False
|
||||
|
||||
# Explicitly configure LiteLLM's loggers
|
||||
for logger_name in ["litellm", "LiteLLM"]:
|
||||
litellm_logger = logging.getLogger(logger_name)
|
||||
litellm_logger.setLevel(logging.WARNING)
|
||||
litellm_logger.propagate = True
|
||||
|
||||
# Use litellm's internal method to disable debugging
|
||||
if hasattr(litellm, "_logging") and hasattr(litellm._logging, "_disable_debugging"):
|
||||
litellm._logging._disable_debugging()
|
||||
|
||||
|
||||
def launch_server(host: str, port: int, args):
|
||||
def launch_webui(host: str, port: int):
|
||||
"""Launch the RA.Aid web interface."""
|
||||
from ra_aid.server import run_server
|
||||
from ra_aid.database.connection import DatabaseManager
|
||||
from ra_aid.database.repositories.session_repository import SessionRepositoryManager
|
||||
from ra_aid.database.repositories.key_fact_repository import KeyFactRepositoryManager
|
||||
from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepositoryManager
|
||||
from ra_aid.database.repositories.human_input_repository import HumanInputRepositoryManager
|
||||
from ra_aid.database.repositories.research_note_repository import ResearchNoteRepositoryManager
|
||||
from ra_aid.database.repositories.related_files_repository import RelatedFilesRepositoryManager
|
||||
from ra_aid.database.repositories.trajectory_repository import TrajectoryRepositoryManager
|
||||
from ra_aid.database.repositories.work_log_repository import WorkLogRepositoryManager
|
||||
from ra_aid.database.repositories.config_repository import ConfigRepositoryManager
|
||||
from ra_aid.env_inv_context import EnvInvManager
|
||||
from ra_aid.env_inv import EnvDiscovery
|
||||
|
||||
# Set the console handler level to INFO for server mode
|
||||
# Get the root logger and modify the console handler
|
||||
root_logger = logging.getLogger()
|
||||
for handler in root_logger.handlers:
|
||||
# Check if this is a console handler (outputs to stdout/stderr)
|
||||
if isinstance(handler, logging.StreamHandler) and handler.stream in [sys.stdout, sys.stderr]:
|
||||
# Set console handler to INFO level for better visibility in server mode
|
||||
handler.setLevel(logging.INFO)
|
||||
logger.debug("Modified console logging level to INFO for server mode")
|
||||
|
||||
# Apply any pending database migrations
|
||||
from ra_aid.database import ensure_migrations_applied
|
||||
try:
|
||||
migration_result = ensure_migrations_applied()
|
||||
if not migration_result:
|
||||
logger.warning("Database migrations failed but execution will continue")
|
||||
except Exception as e:
|
||||
logger.error(f"Database migration error: {str(e)}")
|
||||
|
||||
# Check dependencies before proceeding
|
||||
check_dependencies()
|
||||
from ra_aid.webui import run_server
|
||||
|
||||
# Validate environment (expert_enabled, web_research_enabled)
|
||||
(
|
||||
expert_enabled,
|
||||
expert_missing,
|
||||
web_research_enabled,
|
||||
web_research_missing,
|
||||
) = validate_environment(
|
||||
args
|
||||
) # Will exit if main env vars missing
|
||||
logger.debug("Environment validation successful")
|
||||
|
||||
# Validate model configuration early
|
||||
model_config = models_params.get(args.provider, {}).get(
|
||||
args.model or "", {}
|
||||
)
|
||||
supports_temperature = model_config.get(
|
||||
"supports_temperature",
|
||||
args.provider
|
||||
in [
|
||||
"anthropic",
|
||||
"openai",
|
||||
"openrouter",
|
||||
"openai-compatible",
|
||||
"deepseek",
|
||||
],
|
||||
)
|
||||
|
||||
if supports_temperature and args.temperature is None:
|
||||
args.temperature = model_config.get("default_temperature")
|
||||
if args.temperature is None:
|
||||
cpm(
|
||||
f"This model supports temperature argument but none was given. Setting default temperature to {DEFAULT_TEMPERATURE}."
|
||||
)
|
||||
args.temperature = DEFAULT_TEMPERATURE
|
||||
logger.debug(
|
||||
f"Using default temperature {args.temperature} for model {args.model}"
|
||||
)
|
||||
|
||||
# Initialize config dictionary with values from args and environment validation
|
||||
config = {
|
||||
"provider": args.provider,
|
||||
"model": args.model,
|
||||
"expert_provider": args.expert_provider,
|
||||
"expert_model": args.expert_model,
|
||||
"temperature": args.temperature,
|
||||
"experimental_fallback_handler": args.experimental_fallback_handler,
|
||||
"expert_enabled": expert_enabled,
|
||||
"web_research_enabled": web_research_enabled,
|
||||
"show_thoughts": args.show_thoughts,
|
||||
"show_cost": args.show_cost,
|
||||
"force_reasoning_assistance": args.reasoning_assistance,
|
||||
"disable_reasoning_assistance": args.no_reasoning_assistance
|
||||
}
|
||||
|
||||
# Initialize environment discovery
|
||||
env_discovery = EnvDiscovery()
|
||||
env_discovery.discover()
|
||||
env_data = env_discovery.format_markdown()
|
||||
|
||||
print(f"Starting RA.Aid web interface on http://{host}:{port}")
|
||||
|
||||
# Initialize database connection and repositories
|
||||
with DatabaseManager() as db, \
|
||||
SessionRepositoryManager(db) as session_repo, \
|
||||
KeyFactRepositoryManager(db) as key_fact_repo, \
|
||||
KeySnippetRepositoryManager(db) as key_snippet_repo, \
|
||||
HumanInputRepositoryManager(db) as human_input_repo, \
|
||||
ResearchNoteRepositoryManager(db) as research_note_repo, \
|
||||
RelatedFilesRepositoryManager() as related_files_repo, \
|
||||
TrajectoryRepositoryManager(db) as trajectory_repo, \
|
||||
WorkLogRepositoryManager() as work_log_repo, \
|
||||
ConfigRepositoryManager(config) as config_repo, \
|
||||
EnvInvManager(env_data) as env_inv:
|
||||
|
||||
# This initializes all repositories and makes them available via their respective get methods
|
||||
logger.debug("Initialized SessionRepository")
|
||||
logger.debug("Initialized KeyFactRepository")
|
||||
logger.debug("Initialized KeySnippetRepository")
|
||||
logger.debug("Initialized HumanInputRepository")
|
||||
logger.debug("Initialized ResearchNoteRepository")
|
||||
logger.debug("Initialized RelatedFilesRepository")
|
||||
logger.debug("Initialized TrajectoryRepository")
|
||||
logger.debug("Initialized WorkLogRepository")
|
||||
logger.debug("Initialized ConfigRepository")
|
||||
logger.debug("Initialized Environment Inventory")
|
||||
|
||||
# Run the server within the context managers
|
||||
run_server(host=host, port=port)
|
||||
run_server(host=host, port=port)
|
||||
|
||||
|
||||
def parse_arguments(args=None):
|
||||
|
|
@ -394,21 +275,21 @@ Examples:
|
|||
help=f"Timeout in seconds for test command execution (default: {DEFAULT_TEST_CMD_TIMEOUT})",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--server",
|
||||
"--webui",
|
||||
action="store_true",
|
||||
help="Launch the web interface",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--server-host",
|
||||
"--webui-host",
|
||||
type=str,
|
||||
default="0.0.0.0",
|
||||
help="Host to listen on for web interface (default: 0.0.0.0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--server-port",
|
||||
"--webui-port",
|
||||
type=int,
|
||||
default=1818,
|
||||
help="Port to listen on for web interface (default: 1818)",
|
||||
default=8080,
|
||||
help="Port to listen on for web interface (default: 8080)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--wipe-project-memory",
|
||||
|
|
@ -640,8 +521,8 @@ def main():
|
|||
print(f"📋 {result}")
|
||||
|
||||
# Launch web interface if requested
|
||||
if args.server:
|
||||
launch_server(args.server_host, args.server_port, args)
|
||||
if args.webui:
|
||||
launch_webui(args.webui_host, args.webui_port)
|
||||
return
|
||||
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -1,3 +1,3 @@
|
|||
"""Version information."""
|
||||
|
||||
__version__ = "0.17.1"
|
||||
__version__ = "0.17.0"
|
||||
|
|
|
|||
|
|
@ -825,8 +825,7 @@ class CiaynAgent:
|
|||
try:
|
||||
last_result = self._execute_tool(response)
|
||||
self.chat_history.append(response)
|
||||
if hasattr(self.fallback_handler, 'reset_fallback_handler'):
|
||||
self.fallback_handler.reset_fallback_handler()
|
||||
self.fallback_handler.reset_fallback_handler()
|
||||
yield {}
|
||||
|
||||
except ToolExecutionError as e:
|
||||
|
|
|
|||
|
|
@ -51,13 +51,7 @@ from ra_aid.database.repositories.human_input_repository import (
|
|||
)
|
||||
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||
from ra_aid.anthropic_token_limiter import (
|
||||
get_model_name_from_chat_model,
|
||||
sonnet_35_state_modifier,
|
||||
state_modifier,
|
||||
get_model_token_limit,
|
||||
)
|
||||
from ra_aid.model_detection import is_anthropic_claude
|
||||
from ra_aid.anthropic_token_limiter import sonnet_35_state_modifier, state_modifier, get_model_token_limit
|
||||
|
||||
console = Console()
|
||||
|
||||
|
|
@ -73,6 +67,8 @@ def output_markdown_message(message: str) -> str:
|
|||
return "Message output."
|
||||
|
||||
|
||||
|
||||
|
||||
def build_agent_kwargs(
|
||||
checkpointer: Optional[Any] = None,
|
||||
model: ChatAnthropic = None,
|
||||
|
|
@ -103,15 +99,8 @@ def build_agent_kwargs(
|
|||
):
|
||||
|
||||
def wrapped_state_modifier(state: AgentState) -> list[BaseMessage]:
|
||||
model_name = get_model_name_from_chat_model(model)
|
||||
|
||||
if any(
|
||||
pattern in model_name
|
||||
for pattern in ["claude-3.5", "claude3.5", "claude-3-5"]
|
||||
):
|
||||
return sonnet_35_state_modifier(
|
||||
state, max_input_tokens=max_input_tokens
|
||||
)
|
||||
if any(pattern in model.model for pattern in ["claude-3.5", "claude3.5", "claude-3-5"]):
|
||||
return sonnet_35_state_modifier(state, max_input_tokens=max_input_tokens)
|
||||
|
||||
return state_modifier(state, model, max_input_tokens=max_input_tokens)
|
||||
|
||||
|
|
@ -121,6 +110,27 @@ def build_agent_kwargs(
|
|||
return agent_kwargs
|
||||
|
||||
|
||||
def is_anthropic_claude(config: Dict[str, Any]) -> bool:
|
||||
"""Check if the provider and model name indicate an Anthropic Claude model.
|
||||
|
||||
Args:
|
||||
config: Configuration dictionary containing provider and model information
|
||||
|
||||
Returns:
|
||||
bool: True if this is an Anthropic Claude model
|
||||
"""
|
||||
# For backwards compatibility, allow passing of config directly
|
||||
provider = config.get("provider", "")
|
||||
model_name = config.get("model", "")
|
||||
result = (
|
||||
provider.lower() == "anthropic"
|
||||
and model_name
|
||||
and "claude" in model_name.lower()
|
||||
) or (
|
||||
provider.lower() == "openrouter"
|
||||
and model_name.lower().startswith("anthropic/claude-")
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def create_agent(
|
||||
|
|
@ -159,7 +169,7 @@ def create_agent(
|
|||
# So we'll use the passed config directly
|
||||
pass
|
||||
max_input_tokens = (
|
||||
get_model_token_limit(config, agent_type, model) or DEFAULT_TOKEN_LIMIT
|
||||
get_model_token_limit(config, agent_type) or DEFAULT_TOKEN_LIMIT
|
||||
)
|
||||
|
||||
# Use REACT agent for Anthropic Claude models, otherwise use CIAYN
|
||||
|
|
@ -178,7 +188,7 @@ def create_agent(
|
|||
# Default to REACT agent if provider/model detection fails
|
||||
logger.warning(f"Failed to detect model type: {e}. Defaulting to REACT agent.")
|
||||
config = get_config_repository().get_all()
|
||||
max_input_tokens = get_model_token_limit(config, agent_type, model)
|
||||
max_input_tokens = get_model_token_limit(config, agent_type)
|
||||
agent_kwargs = build_agent_kwargs(checkpointer, model, max_input_tokens)
|
||||
return create_react_agent(
|
||||
model, tools, interrupt_after=["tools"], **agent_kwargs
|
||||
|
|
@ -279,7 +289,7 @@ def _handle_api_error(e, attempt, max_retries, base_delay):
|
|||
logger.warning("API error (attempt %d/%d): %s", attempt + 1, max_retries, str(e))
|
||||
delay = base_delay * (2**attempt)
|
||||
error_message = f"Encountered {e.__class__.__name__}: {e}. Retrying in {delay}s... (Attempt {attempt+1}/{max_retries})"
|
||||
|
||||
|
||||
# Record error in trajectory
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
|
|
@ -291,9 +301,9 @@ def _handle_api_error(e, attempt, max_retries, base_delay):
|
|||
record_type="error",
|
||||
human_input_id=human_input_id,
|
||||
is_error=True,
|
||||
error_message=error_message,
|
||||
error_message=error_message
|
||||
)
|
||||
|
||||
|
||||
print_error(error_message)
|
||||
start = time.monotonic()
|
||||
while time.monotonic() - start < delay:
|
||||
|
|
@ -454,9 +464,7 @@ def run_agent_with_retry(
|
|||
|
||||
try:
|
||||
_run_agent_stream(agent, msg_list)
|
||||
if fallback_handler and hasattr(
|
||||
fallback_handler, "reset_fallback_handler"
|
||||
):
|
||||
if fallback_handler:
|
||||
fallback_handler.reset_fallback_handler()
|
||||
should_break, prompt, auto_test, test_attempts = (
|
||||
_execute_test_command_wrapper(
|
||||
|
|
|
|||
|
|
@ -33,7 +33,6 @@ from ra_aid.logging_config import get_logger
|
|||
from ra_aid.model_formatters import format_key_facts_dict
|
||||
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
|
||||
from ra_aid.model_formatters.research_notes_formatter import format_research_notes_dict
|
||||
from ra_aid.text.processing import process_thinking_content
|
||||
from ra_aid.models_params import models_params
|
||||
from ra_aid.project_info import format_project_info, get_project_info
|
||||
from ra_aid.prompts.expert_prompts import EXPERT_PROMPT_SECTION_PLANNING
|
||||
|
|
@ -287,7 +286,7 @@ def run_planning_agent(
|
|||
content = "\n".join(str(item) for item in content)
|
||||
elif supports_think_tag or supports_thinking:
|
||||
# Process thinking content using the centralized function
|
||||
content, _ = process_thinking_content(
|
||||
content, _ = agent_utils.process_thinking_content(
|
||||
content=content,
|
||||
supports_think_tag=supports_think_tag,
|
||||
supports_thinking=supports_thinking,
|
||||
|
|
|
|||
|
|
@ -34,7 +34,6 @@ from ra_aid.logging_config import get_logger
|
|||
from ra_aid.model_formatters import format_key_facts_dict
|
||||
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
|
||||
from ra_aid.model_formatters.research_notes_formatter import format_research_notes_dict
|
||||
from ra_aid.text.processing import process_thinking_content
|
||||
from ra_aid.models_params import models_params
|
||||
from ra_aid.project_info import display_project_status, format_project_info, get_project_info
|
||||
from ra_aid.prompts.expert_prompts import EXPERT_PROMPT_SECTION_RESEARCH
|
||||
|
|
@ -294,7 +293,7 @@ def run_research_agent(
|
|||
content = "\n".join(str(item) for item in content)
|
||||
elif supports_think_tag or supports_thinking:
|
||||
# Process thinking content using the centralized function
|
||||
content, _ = process_thinking_content(
|
||||
content, _ = agent_utils.process_thinking_content(
|
||||
content=content,
|
||||
supports_think_tag=supports_think_tag,
|
||||
supports_thinking=supports_thinking,
|
||||
|
|
|
|||
|
|
@ -1,27 +1,31 @@
|
|||
"""Utilities for handling token limits with Anthropic models."""
|
||||
|
||||
from functools import partial
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from ra_aid.config import DEFAULT_MODEL
|
||||
from ra_aid.model_detection import is_claude_37
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
from dataclasses import dataclass
|
||||
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
RemoveMessage,
|
||||
ToolMessage,
|
||||
trim_messages,
|
||||
)
|
||||
from langchain_core.messages.base import message_to_dict
|
||||
|
||||
from ra_aid.anthropic_message_utils import (
|
||||
anthropic_trim_messages,
|
||||
has_tool_use,
|
||||
)
|
||||
from langgraph.prebuilt.chat_agent_executor import AgentState
|
||||
from litellm import token_counter, get_model_info
|
||||
from litellm import token_counter
|
||||
|
||||
from ra_aid.agent_backends.ciayn_agent import CiaynAgent
|
||||
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||
from ra_aid.logging_config import get_logger
|
||||
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
|
||||
from ra_aid.console.output import cpm, print_messages_compact
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
|
@ -91,7 +95,7 @@ def create_token_counter_wrapper(model: str):
|
|||
|
||||
|
||||
def state_modifier(
|
||||
state: AgentState, model: BaseChatModel, max_input_tokens: int = DEFAULT_TOKEN_LIMIT
|
||||
state: AgentState, model: ChatAnthropic, max_input_tokens: int = DEFAULT_TOKEN_LIMIT
|
||||
) -> list[BaseMessage]:
|
||||
"""Given the agent state and max_tokens, return a trimmed list of messages.
|
||||
|
||||
|
|
@ -110,8 +114,7 @@ def state_modifier(
|
|||
if not messages:
|
||||
return []
|
||||
|
||||
model_name = get_model_name_from_chat_model(model)
|
||||
wrapped_token_counter = create_token_counter_wrapper(model_name)
|
||||
wrapped_token_counter = create_token_counter_wrapper(model.model)
|
||||
|
||||
result = anthropic_trim_messages(
|
||||
messages,
|
||||
|
|
@ -124,9 +127,7 @@ def state_modifier(
|
|||
)
|
||||
|
||||
if len(result) < len(messages):
|
||||
logger.info(
|
||||
f"Anthropic Token Limiter Trimmed: {len(messages)} messages → {len(result)} messages"
|
||||
)
|
||||
logger.info(f"Anthropic Token Limiter Trimmed: {len(messages)} messages → {len(result)} messages")
|
||||
|
||||
return result
|
||||
|
||||
|
|
@ -167,89 +168,14 @@ def sonnet_35_state_modifier(
|
|||
return result
|
||||
|
||||
|
||||
def get_provider_and_model_for_agent_type(
|
||||
config: Dict[str, Any], agent_type: str
|
||||
) -> Tuple[str, str]:
|
||||
"""Get the provider and model name for the specified agent type.
|
||||
|
||||
Args:
|
||||
config: Configuration dictionary containing provider and model information
|
||||
agent_type: Type of agent ("default", "research", or "planner")
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: A tuple containing (provider, model_name)
|
||||
"""
|
||||
if agent_type == "research":
|
||||
provider = config.get("research_provider", "") or config.get("provider", "")
|
||||
model_name = config.get("research_model", "") or config.get("model", "")
|
||||
elif agent_type == "planner":
|
||||
provider = config.get("planner_provider", "") or config.get("provider", "")
|
||||
model_name = config.get("planner_model", "") or config.get("model", "")
|
||||
else:
|
||||
provider = config.get("provider", "")
|
||||
model_name = config.get("model", "")
|
||||
|
||||
return provider, model_name
|
||||
|
||||
|
||||
def get_model_name_from_chat_model(model: Optional[BaseChatModel]) -> str:
|
||||
"""Extract the model name from a BaseChatModel instance.
|
||||
|
||||
Args:
|
||||
model: The BaseChatModel instance
|
||||
|
||||
Returns:
|
||||
str: The model name extracted from the instance, or DEFAULT_MODEL if not found
|
||||
"""
|
||||
if model is None:
|
||||
return DEFAULT_MODEL
|
||||
|
||||
if hasattr(model, "model"):
|
||||
return model.model
|
||||
elif hasattr(model, "model_name"):
|
||||
return model.model_name
|
||||
else:
|
||||
logger.debug(f"Could not extract model name from {model}, using DEFAULT_MODEL")
|
||||
return DEFAULT_MODEL
|
||||
|
||||
|
||||
def adjust_claude_37_token_limit(
|
||||
max_input_tokens: int, model: Optional[BaseChatModel]
|
||||
) -> Optional[int]:
|
||||
"""Adjust token limit for Claude 3.7 models by subtracting max_tokens.
|
||||
|
||||
Args:
|
||||
max_input_tokens: The original token limit
|
||||
model: The model instance to check
|
||||
|
||||
Returns:
|
||||
Optional[int]: Adjusted token limit if model is Claude 3.7, otherwise original limit
|
||||
"""
|
||||
if not max_input_tokens:
|
||||
return max_input_tokens
|
||||
|
||||
if model and hasattr(model, "model") and is_claude_37(model.model):
|
||||
if hasattr(model, "max_tokens") and model.max_tokens:
|
||||
effective_max_input_tokens = max_input_tokens - model.max_tokens
|
||||
logger.debug(
|
||||
f"Adjusting token limit for Claude 3.7 model: {max_input_tokens} - {model.max_tokens} = {effective_max_input_tokens}"
|
||||
)
|
||||
return effective_max_input_tokens
|
||||
|
||||
return max_input_tokens
|
||||
|
||||
|
||||
def get_model_token_limit(
|
||||
config: Dict[str, Any],
|
||||
agent_type: str = "default",
|
||||
model: Optional[BaseChatModel] = None,
|
||||
config: Dict[str, Any], agent_type: str = "default"
|
||||
) -> Optional[int]:
|
||||
"""Get the token limit for the current model configuration based on agent type.
|
||||
|
||||
Args:
|
||||
config: Configuration dictionary containing provider and model information
|
||||
agent_type: Type of agent ("default", "research", or "planner")
|
||||
model: Optional BaseChatModel instance to check for model-specific attributes
|
||||
|
||||
Returns:
|
||||
Optional[int]: The token limit if found, None otherwise
|
||||
|
|
@ -264,20 +190,27 @@ def get_model_token_limit(
|
|||
# In tests, this may fail because the repository isn't set up
|
||||
# So we'll use the passed config directly
|
||||
pass
|
||||
|
||||
provider, model_name = get_provider_and_model_for_agent_type(config, agent_type)
|
||||
|
||||
# Always attempt to get model info from litellm first
|
||||
provider_model = model_name if not provider else f"{provider}/{model_name}"
|
||||
if agent_type == "research":
|
||||
provider = config.get("research_provider", "") or config.get("provider", "")
|
||||
model_name = config.get("research_model", "") or config.get("model", "")
|
||||
elif agent_type == "planner":
|
||||
provider = config.get("planner_provider", "") or config.get("provider", "")
|
||||
model_name = config.get("planner_model", "") or config.get("model", "")
|
||||
else:
|
||||
provider = config.get("provider", "")
|
||||
model_name = config.get("model", "")
|
||||
|
||||
try:
|
||||
from litellm import get_model_info
|
||||
|
||||
provider_model = model_name if not provider else f"{provider}/{model_name}"
|
||||
model_info = get_model_info(provider_model)
|
||||
max_input_tokens = model_info.get("max_input_tokens")
|
||||
if max_input_tokens:
|
||||
logger.debug(
|
||||
f"Using litellm token limit for {model_name}: {max_input_tokens}"
|
||||
)
|
||||
return adjust_claude_37_token_limit(max_input_tokens, model)
|
||||
return max_input_tokens
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
f"Error getting model info from litellm: {e}, falling back to models_params"
|
||||
|
|
@ -296,7 +229,7 @@ def get_model_token_limit(
|
|||
max_input_tokens = None
|
||||
logger.debug(f"Could not find token limit for {provider}/{model_name}")
|
||||
|
||||
return adjust_claude_37_token_limit(max_input_tokens, model)
|
||||
return max_input_tokens
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get model token limit: {e}")
|
||||
|
|
|
|||
|
|
@ -1,376 +0,0 @@
|
|||
"""
|
||||
Pydantic models for ra_aid database entities.
|
||||
|
||||
This module defines Pydantic models that correspond to Peewee ORM models,
|
||||
providing validation, serialization, and deserialization capabilities.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import json
|
||||
from typing import Dict, List, Any, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, field_serializer, field_validator
|
||||
|
||||
|
||||
class SessionModel(BaseModel):
|
||||
"""
|
||||
Pydantic model representing a Session.
|
||||
|
||||
This model corresponds to the Session Peewee ORM model and provides
|
||||
validation and serialization capabilities. It handles the conversion
|
||||
between JSON-encoded strings and Python dictionaries for the machine_info field.
|
||||
|
||||
Attributes:
|
||||
id: Unique identifier for the session
|
||||
created_at: When the session record was created
|
||||
updated_at: When the session record was last updated
|
||||
start_time: When the program session started
|
||||
command_line: Command line arguments used to start the program
|
||||
program_version: Version of the program
|
||||
machine_info: Dictionary containing machine-specific metadata
|
||||
"""
|
||||
id: Optional[int] = None
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
start_time: datetime.datetime
|
||||
command_line: Optional[str] = None
|
||||
program_version: Optional[str] = None
|
||||
machine_info: Optional[Dict[str, Any]] = None
|
||||
|
||||
# Configure the model to work with ORM objects
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_validator("machine_info", mode="before")
|
||||
@classmethod
|
||||
def parse_machine_info(cls, value: Any) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Parse the machine_info field from a JSON string to a dictionary.
|
||||
|
||||
Args:
|
||||
value: The value to parse, can be a string, dict, or None
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: The parsed dictionary or None
|
||||
|
||||
Raises:
|
||||
ValueError: If the JSON string is invalid
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Invalid JSON in machine_info: {e}")
|
||||
|
||||
raise ValueError(f"Unexpected type for machine_info: {type(value)}")
|
||||
|
||||
@field_serializer("machine_info")
|
||||
def serialize_machine_info(self, machine_info: Optional[Dict[str, Any]]) -> Optional[str]:
|
||||
"""
|
||||
Serialize the machine_info dictionary to a JSON string for storage.
|
||||
|
||||
Args:
|
||||
machine_info: Dictionary to serialize
|
||||
|
||||
Returns:
|
||||
Optional[str]: JSON-encoded string or None
|
||||
"""
|
||||
if machine_info is None:
|
||||
return None
|
||||
|
||||
return json.dumps(machine_info)
|
||||
|
||||
|
||||
class HumanInputModel(BaseModel):
|
||||
"""
|
||||
Pydantic model representing a HumanInput.
|
||||
|
||||
This model corresponds to the HumanInput Peewee ORM model and provides
|
||||
validation and serialization capabilities.
|
||||
|
||||
Attributes:
|
||||
id: Unique identifier for the human input
|
||||
created_at: When the record was created
|
||||
updated_at: When the record was last updated
|
||||
content: The text content of the input
|
||||
source: The source of the input ('cli', 'chat', or 'hil')
|
||||
session_id: Optional reference to the associated session
|
||||
"""
|
||||
id: Optional[int] = None
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
content: str
|
||||
source: str
|
||||
session_id: Optional[int] = None
|
||||
|
||||
# Configure the model to work with ORM objects
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class KeyFactModel(BaseModel):
|
||||
"""
|
||||
Pydantic model representing a KeyFact.
|
||||
|
||||
This model corresponds to the KeyFact Peewee ORM model and provides
|
||||
validation and serialization capabilities.
|
||||
|
||||
Attributes:
|
||||
id: Unique identifier for the key fact
|
||||
created_at: When the record was created
|
||||
updated_at: When the record was last updated
|
||||
content: The text content of the key fact
|
||||
human_input_id: Optional reference to the associated human input
|
||||
session_id: Optional reference to the associated session
|
||||
"""
|
||||
id: Optional[int] = None
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
content: str
|
||||
human_input_id: Optional[int] = None
|
||||
session_id: Optional[int] = None
|
||||
|
||||
# Configure the model to work with ORM objects
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class KeySnippetModel(BaseModel):
|
||||
"""
|
||||
Pydantic model representing a KeySnippet.
|
||||
|
||||
This model corresponds to the KeySnippet Peewee ORM model and provides
|
||||
validation and serialization capabilities.
|
||||
|
||||
Attributes:
|
||||
id: Unique identifier for the key snippet
|
||||
created_at: When the record was created
|
||||
updated_at: When the record was last updated
|
||||
filepath: Path to the source file
|
||||
line_number: Line number where the snippet starts
|
||||
snippet: The source code snippet text
|
||||
description: Optional description of the significance
|
||||
human_input_id: Optional reference to the associated human input
|
||||
session_id: Optional reference to the associated session
|
||||
"""
|
||||
id: Optional[int] = None
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
filepath: str
|
||||
line_number: int
|
||||
snippet: str
|
||||
description: Optional[str] = None
|
||||
human_input_id: Optional[int] = None
|
||||
session_id: Optional[int] = None
|
||||
|
||||
# Configure the model to work with ORM objects
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class ResearchNoteModel(BaseModel):
|
||||
"""
|
||||
Pydantic model representing a ResearchNote.
|
||||
|
||||
This model corresponds to the ResearchNote Peewee ORM model and provides
|
||||
validation and serialization capabilities.
|
||||
|
||||
Attributes:
|
||||
id: Unique identifier for the research note
|
||||
created_at: When the record was created
|
||||
updated_at: When the record was last updated
|
||||
content: The text content of the research note
|
||||
human_input_id: Optional reference to the associated human input
|
||||
session_id: Optional reference to the associated session
|
||||
"""
|
||||
id: Optional[int] = None
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
content: str
|
||||
human_input_id: Optional[int] = None
|
||||
session_id: Optional[int] = None
|
||||
|
||||
# Configure the model to work with ORM objects
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class TrajectoryModel(BaseModel):
|
||||
"""
|
||||
Pydantic model representing a Trajectory.
|
||||
|
||||
This model corresponds to the Trajectory Peewee ORM model and provides
|
||||
validation and serialization capabilities. It handles the conversion
|
||||
between JSON-encoded strings and Python dictionaries for the tool_parameters,
|
||||
tool_result, and step_data fields.
|
||||
|
||||
Attributes:
|
||||
id: Unique identifier for the trajectory
|
||||
created_at: When the record was created
|
||||
updated_at: When the record was last updated
|
||||
human_input_id: Optional reference to the associated human input
|
||||
tool_name: Name of the tool that was executed
|
||||
tool_parameters: Dictionary containing the parameters passed to the tool
|
||||
tool_result: Dictionary containing the result returned by the tool
|
||||
step_data: Dictionary containing UI rendering data
|
||||
record_type: Type of trajectory record
|
||||
cost: Optional cost of the tool execution
|
||||
tokens: Optional token usage of the tool execution
|
||||
is_error: Flag indicating if this record represents an error
|
||||
error_message: The error message if is_error is True
|
||||
error_type: The type/class of the error if is_error is True
|
||||
error_details: Additional error details if is_error is True
|
||||
session_id: Optional reference to the associated session
|
||||
"""
|
||||
id: Optional[int] = None
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
human_input_id: Optional[int] = None
|
||||
tool_name: Optional[str] = None
|
||||
tool_parameters: Optional[Dict[str, Any]] = None
|
||||
tool_result: Optional[Any] = None
|
||||
step_data: Optional[Dict[str, Any]] = None
|
||||
record_type: Optional[str] = None
|
||||
cost: Optional[float] = None
|
||||
tokens: Optional[int] = None
|
||||
is_error: bool = False
|
||||
error_message: Optional[str] = None
|
||||
error_type: Optional[str] = None
|
||||
error_details: Optional[str] = None
|
||||
session_id: Optional[int] = None
|
||||
|
||||
# Configure the model to work with ORM objects
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_validator("tool_parameters", mode="before")
|
||||
@classmethod
|
||||
def parse_tool_parameters(cls, value: Any) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Parse the tool_parameters field from a JSON string to a dictionary.
|
||||
|
||||
Args:
|
||||
value: The value to parse, can be a string, dict, or None
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: The parsed dictionary or None
|
||||
|
||||
Raises:
|
||||
ValueError: If the JSON string is invalid
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Invalid JSON in tool_parameters: {e}")
|
||||
|
||||
raise ValueError(f"Unexpected type for tool_parameters: {type(value)}")
|
||||
|
||||
@field_validator("tool_result", mode="before")
|
||||
@classmethod
|
||||
def parse_tool_result(cls, value: Any) -> Optional[Any]:
|
||||
"""
|
||||
Parse the tool_result field from a JSON string to a Python object.
|
||||
|
||||
Args:
|
||||
value: The value to parse, can be a string, dict, list, or None
|
||||
|
||||
Returns:
|
||||
Optional[Any]: The parsed object or None
|
||||
|
||||
Raises:
|
||||
ValueError: If the JSON string is invalid
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
if not isinstance(value, str):
|
||||
return value
|
||||
|
||||
try:
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Invalid JSON in tool_result: {e}")
|
||||
|
||||
@field_validator("step_data", mode="before")
|
||||
@classmethod
|
||||
def parse_step_data(cls, value: Any) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Parse the step_data field from a JSON string to a dictionary.
|
||||
|
||||
Args:
|
||||
value: The value to parse, can be a string, dict, or None
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: The parsed dictionary or None
|
||||
|
||||
Raises:
|
||||
ValueError: If the JSON string is invalid
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Invalid JSON in step_data: {e}")
|
||||
|
||||
raise ValueError(f"Unexpected type for step_data: {type(value)}")
|
||||
|
||||
@field_serializer("tool_parameters")
|
||||
def serialize_tool_parameters(self, tool_parameters: Optional[Dict[str, Any]]) -> Optional[str]:
|
||||
"""
|
||||
Serialize the tool_parameters dictionary to a JSON string for storage.
|
||||
|
||||
Args:
|
||||
tool_parameters: Dictionary to serialize
|
||||
|
||||
Returns:
|
||||
Optional[str]: JSON-encoded string or None
|
||||
"""
|
||||
if tool_parameters is None:
|
||||
return None
|
||||
|
||||
return json.dumps(tool_parameters)
|
||||
|
||||
@field_serializer("tool_result")
|
||||
def serialize_tool_result(self, tool_result: Optional[Any]) -> Optional[str]:
|
||||
"""
|
||||
Serialize the tool_result object to a JSON string for storage.
|
||||
|
||||
Args:
|
||||
tool_result: Object to serialize
|
||||
|
||||
Returns:
|
||||
Optional[str]: JSON-encoded string or None
|
||||
"""
|
||||
if tool_result is None:
|
||||
return None
|
||||
|
||||
return json.dumps(tool_result)
|
||||
|
||||
@field_serializer("step_data")
|
||||
def serialize_step_data(self, step_data: Optional[Dict[str, Any]]) -> Optional[str]:
|
||||
"""
|
||||
Serialize the step_data dictionary to a JSON string for storage.
|
||||
|
||||
Args:
|
||||
step_data: Dictionary to serialize
|
||||
|
||||
Returns:
|
||||
Optional[str]: JSON-encoded string or None
|
||||
"""
|
||||
if step_data is None:
|
||||
return None
|
||||
|
||||
return json.dumps(step_data)
|
||||
|
|
@ -11,7 +11,6 @@ import contextvars
|
|||
import peewee
|
||||
|
||||
from ra_aid.database.models import HumanInput
|
||||
from ra_aid.database.pydantic_models import HumanInputModel
|
||||
from ra_aid.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
|
@ -119,23 +118,8 @@ class HumanInputRepository:
|
|||
if db is None:
|
||||
raise ValueError("Database connection is required for HumanInputRepository")
|
||||
self.db = db
|
||||
|
||||
def _to_model(self, human_input: Optional[HumanInput]) -> Optional[HumanInputModel]:
|
||||
"""
|
||||
Convert a Peewee HumanInput object to a Pydantic HumanInputModel.
|
||||
|
||||
Args:
|
||||
human_input: Peewee HumanInput instance or None
|
||||
|
||||
Returns:
|
||||
Optional[HumanInputModel]: Pydantic model representation or None if human_input is None
|
||||
"""
|
||||
if human_input is None:
|
||||
return None
|
||||
|
||||
return HumanInputModel.model_validate(human_input, from_attributes=True)
|
||||
|
||||
def create(self, content: str, source: str) -> HumanInputModel:
|
||||
def create(self, content: str, source: str) -> HumanInput:
|
||||
"""
|
||||
Create a new human input record in the database.
|
||||
|
||||
|
|
@ -144,7 +128,7 @@ class HumanInputRepository:
|
|||
source: The source of the input (e.g., "cli", "chat", "hil")
|
||||
|
||||
Returns:
|
||||
HumanInputModel: The newly created human input instance
|
||||
HumanInput: The newly created human input instance
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error creating the record
|
||||
|
|
@ -152,12 +136,12 @@ class HumanInputRepository:
|
|||
try:
|
||||
input_record = HumanInput.create(content=content, source=source)
|
||||
logger.debug(f"Created human input ID {input_record.id} from {source}")
|
||||
return self._to_model(input_record)
|
||||
return input_record
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to create human input record: {str(e)}")
|
||||
raise
|
||||
|
||||
def get(self, input_id: int) -> Optional[HumanInputModel]:
|
||||
def get(self, input_id: int) -> Optional[HumanInput]:
|
||||
"""
|
||||
Retrieve a human input record by its ID.
|
||||
|
||||
|
|
@ -165,19 +149,18 @@ class HumanInputRepository:
|
|||
input_id: The ID of the human input to retrieve
|
||||
|
||||
Returns:
|
||||
Optional[HumanInputModel]: The human input instance if found, None otherwise
|
||||
Optional[HumanInput]: The human input instance if found, None otherwise
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error accessing the database
|
||||
"""
|
||||
try:
|
||||
human_input = HumanInput.get_or_none(HumanInput.id == input_id)
|
||||
return self._to_model(human_input)
|
||||
return HumanInput.get_or_none(HumanInput.id == input_id)
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to fetch human input {input_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def update(self, input_id: int, content: str = None, source: str = None) -> Optional[HumanInputModel]:
|
||||
def update(self, input_id: int, content: str = None, source: str = None) -> Optional[HumanInput]:
|
||||
"""
|
||||
Update an existing human input record.
|
||||
|
||||
|
|
@ -187,14 +170,14 @@ class HumanInputRepository:
|
|||
source: The new source for the human input
|
||||
|
||||
Returns:
|
||||
Optional[HumanInputModel]: The updated human input if found, None otherwise
|
||||
Optional[HumanInput]: The updated human input if found, None otherwise
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error updating the record
|
||||
"""
|
||||
try:
|
||||
# We need to get the raw Peewee object for updating
|
||||
input_record = HumanInput.get_or_none(HumanInput.id == input_id)
|
||||
# First check if the record exists
|
||||
input_record = self.get(input_id)
|
||||
if not input_record:
|
||||
logger.warning(f"Attempted to update non-existent human input {input_id}")
|
||||
return None
|
||||
|
|
@ -207,7 +190,7 @@ class HumanInputRepository:
|
|||
|
||||
input_record.save()
|
||||
logger.debug(f"Updated human input ID {input_id}")
|
||||
return self._to_model(input_record)
|
||||
return input_record
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to update human input {input_id}: {str(e)}")
|
||||
raise
|
||||
|
|
@ -240,24 +223,23 @@ class HumanInputRepository:
|
|||
logger.error(f"Failed to delete human input {input_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_all(self) -> List[HumanInputModel]:
|
||||
def get_all(self) -> List[HumanInput]:
|
||||
"""
|
||||
Retrieve all human input records from the database.
|
||||
|
||||
Returns:
|
||||
List[HumanInputModel]: List of all human input instances
|
||||
List[HumanInput]: List of all human input instances
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error accessing the database
|
||||
"""
|
||||
try:
|
||||
human_inputs = list(HumanInput.select().order_by(HumanInput.created_at.desc()))
|
||||
return [self._to_model(input) for input in human_inputs]
|
||||
return list(HumanInput.select().order_by(HumanInput.created_at.desc()))
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to fetch all human inputs: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_recent(self, limit: int = 10) -> List[HumanInputModel]:
|
||||
def get_recent(self, limit: int = 10) -> List[HumanInput]:
|
||||
"""
|
||||
Retrieve the most recent human input records.
|
||||
|
||||
|
|
@ -265,14 +247,13 @@ class HumanInputRepository:
|
|||
limit: Maximum number of records to retrieve (default: 10)
|
||||
|
||||
Returns:
|
||||
List[HumanInputModel]: List of the most recent human input records
|
||||
List[HumanInput]: List of the most recent human input records
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error accessing the database
|
||||
"""
|
||||
try:
|
||||
human_inputs = list(HumanInput.select().order_by(HumanInput.created_at.desc()).limit(limit))
|
||||
return [self._to_model(input) for input in human_inputs]
|
||||
return list(HumanInput.select().order_by(HumanInput.created_at.desc()).limit(limit))
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to fetch recent human inputs: {str(e)}")
|
||||
raise
|
||||
|
|
@ -296,7 +277,7 @@ class HumanInputRepository:
|
|||
logger.error(f"Failed to fetch most recent human input ID: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_by_source(self, source: str) -> List[HumanInputModel]:
|
||||
def get_by_source(self, source: str) -> List[HumanInput]:
|
||||
"""
|
||||
Retrieve human input records by source.
|
||||
|
||||
|
|
@ -304,14 +285,13 @@ class HumanInputRepository:
|
|||
source: The source to filter by (e.g., "cli", "chat", "hil")
|
||||
|
||||
Returns:
|
||||
List[HumanInputModel]: List of human input records from the specified source
|
||||
List[HumanInput]: List of human input records from the specified source
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error accessing the database
|
||||
"""
|
||||
try:
|
||||
human_inputs = list(HumanInput.select().where(HumanInput.source == source).order_by(HumanInput.created_at.desc()))
|
||||
return [self._to_model(input) for input in human_inputs]
|
||||
return list(HumanInput.select().where(HumanInput.source == source).order_by(HumanInput.created_at.desc()))
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to fetch human inputs by source {source}: {str(e)}")
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -12,7 +12,6 @@ from contextlib import contextmanager
|
|||
import peewee
|
||||
|
||||
from ra_aid.database.models import KeyFact
|
||||
from ra_aid.database.pydantic_models import KeyFactModel
|
||||
from ra_aid.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
|
@ -121,22 +120,7 @@ class KeyFactRepository:
|
|||
raise ValueError("Database connection is required for KeyFactRepository")
|
||||
self.db = db
|
||||
|
||||
def _to_model(self, fact: Optional[KeyFact]) -> Optional[KeyFactModel]:
|
||||
"""
|
||||
Convert a Peewee KeyFact object to a Pydantic KeyFactModel.
|
||||
|
||||
Args:
|
||||
fact: Peewee KeyFact instance or None
|
||||
|
||||
Returns:
|
||||
Optional[KeyFactModel]: Pydantic model representation or None if fact is None
|
||||
"""
|
||||
if fact is None:
|
||||
return None
|
||||
|
||||
return KeyFactModel.model_validate(fact, from_attributes=True)
|
||||
|
||||
def create(self, content: str, human_input_id: Optional[int] = None) -> KeyFactModel:
|
||||
def create(self, content: str, human_input_id: Optional[int] = None) -> KeyFact:
|
||||
"""
|
||||
Create a new key fact in the database.
|
||||
|
||||
|
|
@ -145,7 +129,7 @@ class KeyFactRepository:
|
|||
human_input_id: Optional ID of the associated human input
|
||||
|
||||
Returns:
|
||||
KeyFactModel: The newly created key fact instance
|
||||
KeyFact: The newly created key fact instance
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error creating the fact
|
||||
|
|
@ -153,12 +137,12 @@ class KeyFactRepository:
|
|||
try:
|
||||
fact = KeyFact.create(content=content, human_input_id=human_input_id)
|
||||
logger.debug(f"Created key fact ID {fact.id}: {content}")
|
||||
return self._to_model(fact)
|
||||
return fact
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to create key fact: {str(e)}")
|
||||
raise
|
||||
|
||||
def get(self, fact_id: int) -> Optional[KeyFactModel]:
|
||||
def get(self, fact_id: int) -> Optional[KeyFact]:
|
||||
"""
|
||||
Retrieve a key fact by its ID.
|
||||
|
||||
|
|
@ -166,19 +150,18 @@ class KeyFactRepository:
|
|||
fact_id: The ID of the key fact to retrieve
|
||||
|
||||
Returns:
|
||||
Optional[KeyFactModel]: The key fact instance if found, None otherwise
|
||||
Optional[KeyFact]: The key fact instance if found, None otherwise
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error accessing the database
|
||||
"""
|
||||
try:
|
||||
fact = KeyFact.get_or_none(KeyFact.id == fact_id)
|
||||
return self._to_model(fact)
|
||||
return KeyFact.get_or_none(KeyFact.id == fact_id)
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to fetch key fact {fact_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def update(self, fact_id: int, content: str) -> Optional[KeyFactModel]:
|
||||
def update(self, fact_id: int, content: str) -> Optional[KeyFact]:
|
||||
"""
|
||||
Update an existing key fact.
|
||||
|
||||
|
|
@ -187,14 +170,14 @@ class KeyFactRepository:
|
|||
content: The new content for the key fact
|
||||
|
||||
Returns:
|
||||
Optional[KeyFactModel]: The updated key fact if found, None otherwise
|
||||
Optional[KeyFact]: The updated key fact if found, None otherwise
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error updating the fact
|
||||
"""
|
||||
try:
|
||||
# First check if the fact exists
|
||||
fact = KeyFact.get_or_none(KeyFact.id == fact_id)
|
||||
fact = self.get(fact_id)
|
||||
if not fact:
|
||||
logger.warning(f"Attempted to update non-existent key fact {fact_id}")
|
||||
return None
|
||||
|
|
@ -203,7 +186,7 @@ class KeyFactRepository:
|
|||
fact.content = content
|
||||
fact.save()
|
||||
logger.debug(f"Updated key fact ID {fact_id}: {content}")
|
||||
return self._to_model(fact)
|
||||
return fact
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to update key fact {fact_id}: {str(e)}")
|
||||
raise
|
||||
|
|
@ -223,7 +206,7 @@ class KeyFactRepository:
|
|||
"""
|
||||
try:
|
||||
# First check if the fact exists
|
||||
fact = KeyFact.get_or_none(KeyFact.id == fact_id)
|
||||
fact = self.get(fact_id)
|
||||
if not fact:
|
||||
logger.warning(f"Attempted to delete non-existent key fact {fact_id}")
|
||||
return False
|
||||
|
|
@ -236,19 +219,18 @@ class KeyFactRepository:
|
|||
logger.error(f"Failed to delete key fact {fact_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_all(self) -> List[KeyFactModel]:
|
||||
def get_all(self) -> List[KeyFact]:
|
||||
"""
|
||||
Retrieve all key facts from the database.
|
||||
|
||||
Returns:
|
||||
List[KeyFactModel]: List of all key fact instances
|
||||
List[KeyFact]: List of all key fact instances
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error accessing the database
|
||||
"""
|
||||
try:
|
||||
facts = list(KeyFact.select().order_by(KeyFact.id))
|
||||
return [self._to_model(fact) for fact in facts]
|
||||
return list(KeyFact.select().order_by(KeyFact.id))
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to fetch all key facts: {str(e)}")
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ import contextvars
|
|||
import peewee
|
||||
|
||||
from ra_aid.database.models import KeySnippet
|
||||
from ra_aid.database.pydantic_models import KeySnippetModel
|
||||
from ra_aid.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
|
@ -130,25 +129,10 @@ class KeySnippetRepository:
|
|||
raise ValueError("Database connection is required for KeySnippetRepository")
|
||||
self.db = db
|
||||
|
||||
def _to_model(self, snippet: Optional[KeySnippet]) -> Optional[KeySnippetModel]:
|
||||
"""
|
||||
Convert a Peewee KeySnippet object to a Pydantic KeySnippetModel.
|
||||
|
||||
Args:
|
||||
snippet: Peewee KeySnippet instance or None
|
||||
|
||||
Returns:
|
||||
Optional[KeySnippetModel]: Pydantic model representation or None if snippet is None
|
||||
"""
|
||||
if snippet is None:
|
||||
return None
|
||||
|
||||
return KeySnippetModel.model_validate(snippet, from_attributes=True)
|
||||
|
||||
def create(
|
||||
self, filepath: str, line_number: int, snippet: str, description: Optional[str] = None,
|
||||
human_input_id: Optional[int] = None
|
||||
) -> KeySnippetModel:
|
||||
) -> KeySnippet:
|
||||
"""
|
||||
Create a new key snippet in the database.
|
||||
|
||||
|
|
@ -160,7 +144,7 @@ class KeySnippetRepository:
|
|||
human_input_id: Optional ID of the associated human input
|
||||
|
||||
Returns:
|
||||
KeySnippetModel: The newly created key snippet instance
|
||||
KeySnippet: The newly created key snippet instance
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error creating the snippet
|
||||
|
|
@ -174,12 +158,12 @@ class KeySnippetRepository:
|
|||
human_input_id=human_input_id
|
||||
)
|
||||
logger.debug(f"Created key snippet ID {key_snippet.id}: {filepath}:{line_number}")
|
||||
return self._to_model(key_snippet)
|
||||
return key_snippet
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to create key snippet: {str(e)}")
|
||||
raise
|
||||
|
||||
def get(self, snippet_id: int) -> Optional[KeySnippetModel]:
|
||||
def get(self, snippet_id: int) -> Optional[KeySnippet]:
|
||||
"""
|
||||
Retrieve a key snippet by its ID.
|
||||
|
||||
|
|
@ -187,14 +171,13 @@ class KeySnippetRepository:
|
|||
snippet_id: The ID of the key snippet to retrieve
|
||||
|
||||
Returns:
|
||||
Optional[KeySnippetModel]: The key snippet instance if found, None otherwise
|
||||
Optional[KeySnippet]: The key snippet instance if found, None otherwise
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error accessing the database
|
||||
"""
|
||||
try:
|
||||
snippet = KeySnippet.get_or_none(KeySnippet.id == snippet_id)
|
||||
return self._to_model(snippet)
|
||||
return KeySnippet.get_or_none(KeySnippet.id == snippet_id)
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to fetch key snippet {snippet_id}: {str(e)}")
|
||||
raise
|
||||
|
|
@ -206,7 +189,7 @@ class KeySnippetRepository:
|
|||
line_number: int,
|
||||
snippet: str,
|
||||
description: Optional[str] = None
|
||||
) -> Optional[KeySnippetModel]:
|
||||
) -> Optional[KeySnippet]:
|
||||
"""
|
||||
Update an existing key snippet.
|
||||
|
||||
|
|
@ -218,14 +201,14 @@ class KeySnippetRepository:
|
|||
description: Optional description of the significance
|
||||
|
||||
Returns:
|
||||
Optional[KeySnippetModel]: The updated key snippet if found, None otherwise
|
||||
Optional[KeySnippet]: The updated key snippet if found, None otherwise
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error updating the snippet
|
||||
"""
|
||||
try:
|
||||
# First check if the snippet exists
|
||||
key_snippet = KeySnippet.get_or_none(KeySnippet.id == snippet_id)
|
||||
key_snippet = self.get(snippet_id)
|
||||
if not key_snippet:
|
||||
logger.warning(f"Attempted to update non-existent key snippet {snippet_id}")
|
||||
return None
|
||||
|
|
@ -237,7 +220,7 @@ class KeySnippetRepository:
|
|||
key_snippet.description = description
|
||||
key_snippet.save()
|
||||
logger.debug(f"Updated key snippet ID {snippet_id}: {filepath}:{line_number}")
|
||||
return self._to_model(key_snippet)
|
||||
return key_snippet
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to update key snippet {snippet_id}: {str(e)}")
|
||||
raise
|
||||
|
|
@ -257,7 +240,7 @@ class KeySnippetRepository:
|
|||
"""
|
||||
try:
|
||||
# First check if the snippet exists
|
||||
key_snippet = KeySnippet.get_or_none(KeySnippet.id == snippet_id)
|
||||
key_snippet = self.get(snippet_id)
|
||||
if not key_snippet:
|
||||
logger.warning(f"Attempted to delete non-existent key snippet {snippet_id}")
|
||||
return False
|
||||
|
|
@ -270,19 +253,18 @@ class KeySnippetRepository:
|
|||
logger.error(f"Failed to delete key snippet {snippet_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_all(self) -> List[KeySnippetModel]:
|
||||
def get_all(self) -> List[KeySnippet]:
|
||||
"""
|
||||
Retrieve all key snippets from the database.
|
||||
|
||||
Returns:
|
||||
List[KeySnippetModel]: List of all key snippet instances
|
||||
List[KeySnippet]: List of all key snippet instances
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error accessing the database
|
||||
"""
|
||||
try:
|
||||
snippets = list(KeySnippet.select().order_by(KeySnippet.id))
|
||||
return [self._to_model(snippet) for snippet in snippets]
|
||||
return list(KeySnippet.select().order_by(KeySnippet.id))
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to fetch all key snippets: {str(e)}")
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -12,7 +12,6 @@ from contextlib import contextmanager
|
|||
import peewee
|
||||
|
||||
from ra_aid.database.models import ResearchNote
|
||||
from ra_aid.database.pydantic_models import ResearchNoteModel
|
||||
from ra_aid.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
|
@ -121,22 +120,7 @@ class ResearchNoteRepository:
|
|||
raise ValueError("Database connection is required for ResearchNoteRepository")
|
||||
self.db = db
|
||||
|
||||
def _to_model(self, note: Optional[ResearchNote]) -> Optional[ResearchNoteModel]:
|
||||
"""
|
||||
Convert a Peewee ResearchNote object to a Pydantic ResearchNoteModel.
|
||||
|
||||
Args:
|
||||
note: Peewee ResearchNote instance or None
|
||||
|
||||
Returns:
|
||||
Optional[ResearchNoteModel]: Pydantic model representation or None if note is None
|
||||
"""
|
||||
if note is None:
|
||||
return None
|
||||
|
||||
return ResearchNoteModel.model_validate(note, from_attributes=True)
|
||||
|
||||
def create(self, content: str, human_input_id: Optional[int] = None) -> ResearchNoteModel:
|
||||
def create(self, content: str, human_input_id: Optional[int] = None) -> ResearchNote:
|
||||
"""
|
||||
Create a new research note in the database.
|
||||
|
||||
|
|
@ -145,7 +129,7 @@ class ResearchNoteRepository:
|
|||
human_input_id: Optional ID of the associated human input
|
||||
|
||||
Returns:
|
||||
ResearchNoteModel: The newly created research note instance
|
||||
ResearchNote: The newly created research note instance
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error creating the note
|
||||
|
|
@ -153,12 +137,12 @@ class ResearchNoteRepository:
|
|||
try:
|
||||
note = ResearchNote.create(content=content, human_input_id=human_input_id)
|
||||
logger.debug(f"Created research note ID {note.id}: {content[:50]}...")
|
||||
return self._to_model(note)
|
||||
return note
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to create research note: {str(e)}")
|
||||
raise
|
||||
|
||||
def get(self, note_id: int) -> Optional[ResearchNoteModel]:
|
||||
def get(self, note_id: int) -> Optional[ResearchNote]:
|
||||
"""
|
||||
Retrieve a research note by its ID.
|
||||
|
||||
|
|
@ -166,19 +150,18 @@ class ResearchNoteRepository:
|
|||
note_id: The ID of the research note to retrieve
|
||||
|
||||
Returns:
|
||||
Optional[ResearchNoteModel]: The research note instance if found, None otherwise
|
||||
Optional[ResearchNote]: The research note instance if found, None otherwise
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error accessing the database
|
||||
"""
|
||||
try:
|
||||
note = ResearchNote.get_or_none(ResearchNote.id == note_id)
|
||||
return self._to_model(note)
|
||||
return ResearchNote.get_or_none(ResearchNote.id == note_id)
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to fetch research note {note_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def update(self, note_id: int, content: str) -> Optional[ResearchNoteModel]:
|
||||
def update(self, note_id: int, content: str) -> Optional[ResearchNote]:
|
||||
"""
|
||||
Update an existing research note.
|
||||
|
||||
|
|
@ -187,14 +170,14 @@ class ResearchNoteRepository:
|
|||
content: The new content for the research note
|
||||
|
||||
Returns:
|
||||
Optional[ResearchNoteModel]: The updated research note if found, None otherwise
|
||||
Optional[ResearchNote]: The updated research note if found, None otherwise
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error updating the note
|
||||
"""
|
||||
try:
|
||||
# First check if the note exists
|
||||
note = ResearchNote.get_or_none(ResearchNote.id == note_id)
|
||||
note = self.get(note_id)
|
||||
if not note:
|
||||
logger.warning(f"Attempted to update non-existent research note {note_id}")
|
||||
return None
|
||||
|
|
@ -203,7 +186,7 @@ class ResearchNoteRepository:
|
|||
note.content = content
|
||||
note.save()
|
||||
logger.debug(f"Updated research note ID {note_id}: {content[:50]}...")
|
||||
return self._to_model(note)
|
||||
return note
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to update research note {note_id}: {str(e)}")
|
||||
raise
|
||||
|
|
@ -223,7 +206,7 @@ class ResearchNoteRepository:
|
|||
"""
|
||||
try:
|
||||
# First check if the note exists
|
||||
note = ResearchNote.get_or_none(ResearchNote.id == note_id)
|
||||
note = self.get(note_id)
|
||||
if not note:
|
||||
logger.warning(f"Attempted to delete non-existent research note {note_id}")
|
||||
return False
|
||||
|
|
@ -236,19 +219,18 @@ class ResearchNoteRepository:
|
|||
logger.error(f"Failed to delete research note {note_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_all(self) -> List[ResearchNoteModel]:
|
||||
def get_all(self) -> List[ResearchNote]:
|
||||
"""
|
||||
Retrieve all research notes from the database.
|
||||
|
||||
Returns:
|
||||
List[ResearchNoteModel]: List of all research note instances
|
||||
List[ResearchNote]: List of all research note instances
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error accessing the database
|
||||
"""
|
||||
try:
|
||||
notes = list(ResearchNote.select().order_by(ResearchNote.id))
|
||||
return [self._to_model(note) for note in notes]
|
||||
return list(ResearchNote.select().order_by(ResearchNote.id))
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to fetch all research notes: {str(e)}")
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@ import sys
|
|||
import peewee
|
||||
|
||||
from ra_aid.database.models import Session
|
||||
from ra_aid.database.pydantic_models import SessionModel
|
||||
from ra_aid.__version__ import __version__
|
||||
from ra_aid.logging_config import get_logger
|
||||
|
||||
|
|
@ -121,23 +120,8 @@ class SessionRepository:
|
|||
raise ValueError("Database connection is required for SessionRepository")
|
||||
self.db = db
|
||||
self.current_session = None
|
||||
|
||||
def _to_model(self, session: Optional[Session]) -> Optional[SessionModel]:
|
||||
"""
|
||||
Convert a Peewee Session object to a Pydantic SessionModel.
|
||||
|
||||
Args:
|
||||
session: Peewee Session instance or None
|
||||
|
||||
Returns:
|
||||
Optional[SessionModel]: Pydantic model representation or None if session is None
|
||||
"""
|
||||
if session is None:
|
||||
return None
|
||||
|
||||
return SessionModel.model_validate(session, from_attributes=True)
|
||||
|
||||
def create_session(self, metadata: Optional[Dict[str, Any]] = None) -> SessionModel:
|
||||
def create_session(self, metadata: Optional[Dict[str, Any]] = None) -> Session:
|
||||
"""
|
||||
Create a new session record in the database.
|
||||
|
||||
|
|
@ -145,7 +129,7 @@ class SessionRepository:
|
|||
metadata: Optional dictionary of additional metadata to store with the session
|
||||
|
||||
Returns:
|
||||
SessionModel: The newly created session instance
|
||||
Session: The newly created session instance
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error creating the record
|
||||
|
|
@ -171,12 +155,12 @@ class SessionRepository:
|
|||
self.current_session = session
|
||||
|
||||
logger.debug(f"Created new session with ID {session.id}")
|
||||
return self._to_model(session)
|
||||
return session
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to create session record: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_current_session(self) -> Optional[SessionModel]:
|
||||
def get_current_session(self) -> Optional[Session]:
|
||||
"""
|
||||
Get the current active session.
|
||||
|
||||
|
|
@ -184,17 +168,17 @@ class SessionRepository:
|
|||
retrieves the most recent session from the database.
|
||||
|
||||
Returns:
|
||||
Optional[SessionModel]: The current session or None if no sessions exist
|
||||
Optional[Session]: The current session or None if no sessions exist
|
||||
"""
|
||||
if self.current_session is not None:
|
||||
return self._to_model(self.current_session)
|
||||
return self.current_session
|
||||
|
||||
try:
|
||||
# Find the most recent session
|
||||
session = Session.select().order_by(Session.created_at.desc()).first()
|
||||
if session:
|
||||
self.current_session = session
|
||||
return self._to_model(session)
|
||||
return session
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to get current session: {str(e)}")
|
||||
return None
|
||||
|
|
@ -209,7 +193,7 @@ class SessionRepository:
|
|||
session = self.get_current_session()
|
||||
return session.id if session else None
|
||||
|
||||
def get(self, session_id: int) -> Optional[SessionModel]:
|
||||
def get(self, session_id: int) -> Optional[Session]:
|
||||
"""
|
||||
Get a session by its ID.
|
||||
|
||||
|
|
@ -217,44 +201,28 @@ class SessionRepository:
|
|||
session_id: The ID of the session to retrieve
|
||||
|
||||
Returns:
|
||||
Optional[SessionModel]: The session with the given ID or None if not found
|
||||
Optional[Session]: The session with the given ID or None if not found
|
||||
"""
|
||||
try:
|
||||
session = Session.get_or_none(Session.id == session_id)
|
||||
return self._to_model(session)
|
||||
return Session.get_or_none(Session.id == session_id)
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Database error getting session {session_id}: {str(e)}")
|
||||
return None
|
||||
|
||||
def get_all(self, offset: int = 0, limit: int = 10) -> tuple[List[SessionModel], int]:
|
||||
def get_all(self) -> List[Session]:
|
||||
"""
|
||||
Get all sessions from the database with pagination support.
|
||||
Get all sessions from the database.
|
||||
|
||||
Args:
|
||||
offset: Number of sessions to skip (default: 0)
|
||||
limit: Maximum number of sessions to return (default: 10)
|
||||
|
||||
Returns:
|
||||
tuple: (List[SessionModel], int) containing the list of sessions and the total count
|
||||
List[Session]: List of all sessions
|
||||
"""
|
||||
try:
|
||||
# Get total count for pagination info
|
||||
total_count = Session.select().count()
|
||||
|
||||
# Get paginated sessions ordered by created_at in descending order (newest first)
|
||||
sessions = list(
|
||||
Session.select()
|
||||
.order_by(Session.created_at.desc())
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
return [self._to_model(session) for session in sessions], total_count
|
||||
return list(Session.select().order_by(Session.created_at.desc()))
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to get all sessions with pagination: {str(e)}")
|
||||
return [], 0
|
||||
logger.error(f"Failed to get all sessions: {str(e)}")
|
||||
return []
|
||||
|
||||
def get_recent(self, limit: int = 10) -> List[SessionModel]:
|
||||
def get_recent(self, limit: int = 10) -> List[Session]:
|
||||
"""
|
||||
Get the most recent sessions from the database.
|
||||
|
||||
|
|
@ -262,15 +230,14 @@ class SessionRepository:
|
|||
limit: Maximum number of sessions to return (default: 10)
|
||||
|
||||
Returns:
|
||||
List[SessionModel]: List of the most recent sessions
|
||||
List[Session]: List of the most recent sessions
|
||||
"""
|
||||
try:
|
||||
sessions = list(
|
||||
return list(
|
||||
Session.select()
|
||||
.order_by(Session.created_at.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
return [self._to_model(session) for session in sessions]
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to get recent sessions: {str(e)}")
|
||||
return []
|
||||
|
|
@ -14,7 +14,6 @@ import logging
|
|||
import peewee
|
||||
|
||||
from ra_aid.database.models import Trajectory, HumanInput
|
||||
from ra_aid.database.pydantic_models import TrajectoryModel
|
||||
from ra_aid.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
|
@ -131,21 +130,6 @@ class TrajectoryRepository:
|
|||
raise ValueError("Database connection is required for TrajectoryRepository")
|
||||
self.db = db
|
||||
|
||||
def _to_model(self, trajectory: Optional[Trajectory]) -> Optional[TrajectoryModel]:
|
||||
"""
|
||||
Convert a Peewee Trajectory object to a Pydantic TrajectoryModel.
|
||||
|
||||
Args:
|
||||
trajectory: Peewee Trajectory instance or None
|
||||
|
||||
Returns:
|
||||
Optional[TrajectoryModel]: Pydantic model representation or None if trajectory is None
|
||||
"""
|
||||
if trajectory is None:
|
||||
return None
|
||||
|
||||
return TrajectoryModel.model_validate(trajectory, from_attributes=True)
|
||||
|
||||
def create(
|
||||
self,
|
||||
tool_name: Optional[str] = None,
|
||||
|
|
@ -160,7 +144,7 @@ class TrajectoryRepository:
|
|||
error_message: Optional[str] = None,
|
||||
error_type: Optional[str] = None,
|
||||
error_details: Optional[str] = None
|
||||
) -> TrajectoryModel:
|
||||
) -> Trajectory:
|
||||
"""
|
||||
Create a new trajectory record in the database.
|
||||
|
||||
|
|
@ -179,7 +163,7 @@ class TrajectoryRepository:
|
|||
error_details: Additional error details like stack traces (if is_error is True)
|
||||
|
||||
Returns:
|
||||
TrajectoryModel: The newly created trajectory instance as a Pydantic model
|
||||
Trajectory: The newly created trajectory instance
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error creating the record
|
||||
|
|
@ -217,12 +201,12 @@ class TrajectoryRepository:
|
|||
logger.debug(f"Created trajectory record ID {trajectory.id} for tool: {tool_name}")
|
||||
else:
|
||||
logger.debug(f"Created trajectory record ID {trajectory.id} of type: {record_type}")
|
||||
return self._to_model(trajectory)
|
||||
return trajectory
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to create trajectory record: {str(e)}")
|
||||
raise
|
||||
|
||||
def get(self, trajectory_id: int) -> Optional[TrajectoryModel]:
|
||||
def get(self, trajectory_id: int) -> Optional[Trajectory]:
|
||||
"""
|
||||
Retrieve a trajectory record by its ID.
|
||||
|
||||
|
|
@ -230,14 +214,13 @@ class TrajectoryRepository:
|
|||
trajectory_id: The ID of the trajectory record to retrieve
|
||||
|
||||
Returns:
|
||||
Optional[TrajectoryModel]: The trajectory instance as a Pydantic model if found, None otherwise
|
||||
Optional[Trajectory]: The trajectory instance if found, None otherwise
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error accessing the database
|
||||
"""
|
||||
try:
|
||||
trajectory = Trajectory.get_or_none(Trajectory.id == trajectory_id)
|
||||
return self._to_model(trajectory)
|
||||
return Trajectory.get_or_none(Trajectory.id == trajectory_id)
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to fetch trajectory {trajectory_id}: {str(e)}")
|
||||
raise
|
||||
|
|
@ -253,7 +236,7 @@ class TrajectoryRepository:
|
|||
error_message: Optional[str] = None,
|
||||
error_type: Optional[str] = None,
|
||||
error_details: Optional[str] = None
|
||||
) -> Optional[TrajectoryModel]:
|
||||
) -> Optional[Trajectory]:
|
||||
"""
|
||||
Update an existing trajectory record.
|
||||
|
||||
|
|
@ -271,15 +254,15 @@ class TrajectoryRepository:
|
|||
error_details: Additional error details like stack traces
|
||||
|
||||
Returns:
|
||||
Optional[TrajectoryModel]: The updated trajectory as a Pydantic model if found, None otherwise
|
||||
Optional[Trajectory]: The updated trajectory if found, None otherwise
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error updating the record
|
||||
"""
|
||||
try:
|
||||
# First check if the trajectory exists
|
||||
peewee_trajectory = Trajectory.get_or_none(Trajectory.id == trajectory_id)
|
||||
if not peewee_trajectory:
|
||||
trajectory = self.get(trajectory_id)
|
||||
if not trajectory:
|
||||
logger.warning(f"Attempted to update non-existent trajectory {trajectory_id}")
|
||||
return None
|
||||
|
||||
|
|
@ -316,7 +299,7 @@ class TrajectoryRepository:
|
|||
logger.debug(f"Updated trajectory record ID {trajectory_id}")
|
||||
return self.get(trajectory_id)
|
||||
|
||||
return self._to_model(peewee_trajectory)
|
||||
return trajectory
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to update trajectory {trajectory_id}: {str(e)}")
|
||||
raise
|
||||
|
|
@ -336,7 +319,7 @@ class TrajectoryRepository:
|
|||
"""
|
||||
try:
|
||||
# First check if the trajectory exists
|
||||
trajectory = Trajectory.get_or_none(Trajectory.id == trajectory_id)
|
||||
trajectory = self.get(trajectory_id)
|
||||
if not trajectory:
|
||||
logger.warning(f"Attempted to delete non-existent trajectory {trajectory_id}")
|
||||
return False
|
||||
|
|
@ -349,24 +332,23 @@ class TrajectoryRepository:
|
|||
logger.error(f"Failed to delete trajectory {trajectory_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_all(self) -> Dict[int, TrajectoryModel]:
|
||||
def get_all(self) -> Dict[int, Trajectory]:
|
||||
"""
|
||||
Retrieve all trajectory records from the database.
|
||||
|
||||
Returns:
|
||||
Dict[int, TrajectoryModel]: Dictionary mapping trajectory IDs to trajectory Pydantic models
|
||||
Dict[int, Trajectory]: Dictionary mapping trajectory IDs to trajectory instances
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error accessing the database
|
||||
"""
|
||||
try:
|
||||
trajectories = Trajectory.select().order_by(Trajectory.id)
|
||||
return {trajectory.id: self._to_model(trajectory) for trajectory in trajectories}
|
||||
return {trajectory.id: trajectory for trajectory in Trajectory.select().order_by(Trajectory.id)}
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to fetch all trajectories: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_trajectories_by_human_input(self, human_input_id: int) -> List[TrajectoryModel]:
|
||||
def get_trajectories_by_human_input(self, human_input_id: int) -> List[Trajectory]:
|
||||
"""
|
||||
Retrieve all trajectory records associated with a specific human input.
|
||||
|
||||
|
|
@ -374,19 +356,37 @@ class TrajectoryRepository:
|
|||
human_input_id: The ID of the human input to get trajectories for
|
||||
|
||||
Returns:
|
||||
List[TrajectoryModel]: List of trajectory Pydantic models associated with the human input
|
||||
List[Trajectory]: List of trajectory instances associated with the human input
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error accessing the database
|
||||
"""
|
||||
try:
|
||||
trajectories = list(Trajectory.select().where(Trajectory.human_input == human_input_id).order_by(Trajectory.id))
|
||||
return [self._to_model(trajectory) for trajectory in trajectories]
|
||||
return list(Trajectory.select().where(Trajectory.human_input == human_input_id).order_by(Trajectory.id))
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to fetch trajectories for human input {human_input_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_parsed_trajectory(self, trajectory_id: int) -> Optional[TrajectoryModel]:
|
||||
def parse_json_field(self, json_str: Optional[str]) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Parse a JSON string into a Python dictionary.
|
||||
|
||||
Args:
|
||||
json_str: JSON string to parse
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: Parsed dictionary or None if input is None or invalid
|
||||
"""
|
||||
if not json_str:
|
||||
return None
|
||||
|
||||
try:
|
||||
return json.loads(json_str)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Error parsing JSON field: {str(e)}")
|
||||
return None
|
||||
|
||||
def get_parsed_trajectory(self, trajectory_id: int) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get a trajectory record with JSON fields parsed into dictionaries.
|
||||
|
||||
|
|
@ -394,7 +394,27 @@ class TrajectoryRepository:
|
|||
trajectory_id: ID of the trajectory to retrieve
|
||||
|
||||
Returns:
|
||||
Optional[TrajectoryModel]: The trajectory as a Pydantic model with parsed JSON fields,
|
||||
or None if not found
|
||||
Optional[Dict[str, Any]]: Dictionary with trajectory data and parsed JSON fields,
|
||||
or None if not found
|
||||
"""
|
||||
return self.get(trajectory_id)
|
||||
trajectory = self.get(trajectory_id)
|
||||
if trajectory is None:
|
||||
return None
|
||||
|
||||
return {
|
||||
"id": trajectory.id,
|
||||
"created_at": trajectory.created_at,
|
||||
"updated_at": trajectory.updated_at,
|
||||
"tool_name": trajectory.tool_name,
|
||||
"tool_parameters": self.parse_json_field(trajectory.tool_parameters),
|
||||
"tool_result": self.parse_json_field(trajectory.tool_result),
|
||||
"step_data": self.parse_json_field(trajectory.step_data),
|
||||
"record_type": trajectory.record_type,
|
||||
"cost": trajectory.cost,
|
||||
"tokens": trajectory.tokens,
|
||||
"human_input_id": trajectory.human_input.id if trajectory.human_input else None,
|
||||
"is_error": trajectory.is_error,
|
||||
"error_message": trajectory.error_message,
|
||||
"error_type": trajectory.error_type,
|
||||
"error_details": trajectory.error_details,
|
||||
}
|
||||
|
|
@ -10,7 +10,6 @@ from openai import OpenAI
|
|||
from ra_aid.chat_models.deepseek_chat import ChatDeepseekReasoner
|
||||
from ra_aid.console.output import cpm
|
||||
from ra_aid.logging_config import get_logger
|
||||
from ra_aid.model_detection import is_claude_37
|
||||
|
||||
from .models_params import models_params
|
||||
|
||||
|
|
@ -219,6 +218,7 @@ def create_llm_client(
|
|||
is_expert,
|
||||
)
|
||||
|
||||
# Get model configuration
|
||||
model_config = models_params.get(provider, {}).get(model_name, {})
|
||||
|
||||
# Default to True for known providers that support temperature if not specified
|
||||
|
|
@ -228,10 +228,6 @@ def create_llm_client(
|
|||
supports_temperature = model_config["supports_temperature"]
|
||||
supports_thinking = model_config.get("supports_thinking", False)
|
||||
|
||||
other_kwargs = {}
|
||||
if is_claude_37(model_name):
|
||||
other_kwargs = {"max_tokens": 64000}
|
||||
|
||||
# Handle temperature settings
|
||||
if is_expert:
|
||||
temp_kwargs = {"temperature": 0} if supports_temperature else {}
|
||||
|
|
@ -239,26 +235,22 @@ def create_llm_client(
|
|||
if temperature is None:
|
||||
temperature = 0.7
|
||||
# Import repository classes directly to avoid circular imports
|
||||
from ra_aid.database.repositories.trajectory_repository import (
|
||||
TrajectoryRepository,
|
||||
)
|
||||
from ra_aid.database.repositories.human_input_repository import (
|
||||
HumanInputRepository,
|
||||
)
|
||||
from ra_aid.database.repositories.trajectory_repository import TrajectoryRepository
|
||||
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||
from ra_aid.database.connection import get_db
|
||||
|
||||
|
||||
# Create repositories directly
|
||||
trajectory_repo = TrajectoryRepository(get_db())
|
||||
human_input_repo = HumanInputRepository(get_db())
|
||||
human_input_id = human_input_repo.get_most_recent_id()
|
||||
|
||||
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"message": "This model supports temperature argument but none was given. Setting default temperature to 0.7.",
|
||||
"display_title": "Information",
|
||||
},
|
||||
record_type="info",
|
||||
human_input_id=human_input_id,
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
cpm(
|
||||
"This model supports temperature argument but none was given. Setting default temperature to 0.7."
|
||||
|
|
@ -310,9 +302,9 @@ def create_llm_client(
|
|||
model_name=model_name,
|
||||
timeout=LLM_REQUEST_TIMEOUT,
|
||||
max_retries=LLM_MAX_RETRIES,
|
||||
max_tokens=model_config.get("max_tokens", 64000),
|
||||
**temp_kwargs,
|
||||
**thinking_kwargs,
|
||||
**other_kwargs,
|
||||
)
|
||||
elif provider == "openai-compatible":
|
||||
return ChatOpenAI(
|
||||
|
|
|
|||
|
|
@ -1,39 +0,0 @@
|
|||
"""Utilities for detecting and working with specific model types."""
|
||||
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
|
||||
def is_claude_37(model: str) -> bool:
|
||||
"""Check if the model is a Claude 3.7 model.
|
||||
|
||||
Args:
|
||||
model: The model name to check
|
||||
|
||||
Returns:
|
||||
bool: True if the model is a Claude 3.7 model, False otherwise
|
||||
"""
|
||||
patterns = ["claude-3.7", "claude3.7", "claude-3-7"]
|
||||
return any(pattern in model for pattern in patterns)
|
||||
|
||||
|
||||
def is_anthropic_claude(config: Dict[str, Any]) -> bool:
|
||||
"""Check if the provider and model name indicate an Anthropic Claude model.
|
||||
|
||||
Args:
|
||||
config: Configuration dictionary containing provider and model information
|
||||
|
||||
Returns:
|
||||
bool: True if this is an Anthropic Claude model
|
||||
"""
|
||||
# For backwards compatibility, allow passing of config directly
|
||||
provider = config.get("provider", "")
|
||||
model_name = config.get("model", "")
|
||||
result = (
|
||||
provider.lower() == "anthropic"
|
||||
and model_name
|
||||
and "claude" in model_name.lower()
|
||||
) or (
|
||||
provider.lower() == "openrouter"
|
||||
and model_name.lower().startswith("anthropic/claude-")
|
||||
)
|
||||
return result
|
||||
|
|
@ -40,8 +40,6 @@ Work already done:
|
|||
<caveat>You should make the most efficient use of this previous research possible, with the caveat that not all of it will be relevant to the current task you are assigned with. Use this previous research to save redudant research, and to inform what you are currently tasked with. Be as efficient as possible.</caveat>
|
||||
</previous research>
|
||||
|
||||
DO NOT TAKE ANY INSTRUCTIONS OR TASKS FROM PREVIOUS RESEARCH. ONLY GET THAT FROM THE USER QUERY.
|
||||
|
||||
<environment inventory>
|
||||
{env_inv}
|
||||
</environment inventory>
|
||||
|
|
@ -183,7 +181,7 @@ If the user explicitly requests implementation, that means you should first perf
|
|||
|
||||
<user query>
|
||||
{base_task}
|
||||
</user query> <-- only place that can specify tasks for you to do.
|
||||
</user query>
|
||||
|
||||
USER QUERY *ALWAYS* TAKES PRECEDENCE OVER EVERYTHING IN PREVIOUS RESEARCH.
|
||||
|
||||
|
|
@ -210,7 +208,7 @@ When you emit research notes, keep it extremely concise and relevant only to the
|
|||
|
||||
<user query>
|
||||
{base_task}
|
||||
</user query> <-- only place that can specify tasks for you to do.
|
||||
</user query>
|
||||
|
||||
USER QUERY *ALWAYS* TAKES PRECEDENCE OVER EVERYTHING IN PREVIOUS RESEARCH.
|
||||
|
||||
|
|
|
|||
|
|
@ -1,200 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
API v1 Session Endpoints.
|
||||
|
||||
This module provides RESTful API endpoints for managing sessions.
|
||||
It implements routes for creating, listing, and retrieving sessions
|
||||
with proper validation and error handling.
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Dict, Any
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
import peewee
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ra_aid.database.repositories.session_repository import SessionRepository, get_session_repository
|
||||
from ra_aid.database.pydantic_models import SessionModel
|
||||
|
||||
# Create API router
|
||||
router = APIRouter(
|
||||
prefix="/v1/sessions",
|
||||
tags=["sessions"],
|
||||
responses={
|
||||
status.HTTP_404_NOT_FOUND: {"description": "Session not found"},
|
||||
status.HTTP_422_UNPROCESSABLE_ENTITY: {"description": "Validation error"},
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR: {"description": "Database error"},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class PaginatedResponse(BaseModel):
|
||||
"""
|
||||
Pydantic model for paginated API responses.
|
||||
|
||||
This model provides a standardized format for API responses that include
|
||||
pagination, with a total count and the requested items.
|
||||
|
||||
Attributes:
|
||||
total: The total number of items available
|
||||
items: List of items for the current page
|
||||
limit: The limit parameter that was used
|
||||
offset: The offset parameter that was used
|
||||
"""
|
||||
total: int
|
||||
items: List[Any]
|
||||
limit: int
|
||||
offset: int
|
||||
|
||||
|
||||
class CreateSessionRequest(BaseModel):
|
||||
"""
|
||||
Pydantic model for session creation requests.
|
||||
|
||||
This model provides validation for creating new sessions.
|
||||
|
||||
Attributes:
|
||||
metadata: Optional dictionary of additional metadata to store with the session
|
||||
"""
|
||||
metadata: Optional[Dict[str, Any]] = Field(
|
||||
default=None,
|
||||
description="Optional dictionary of additional metadata to store with the session"
|
||||
)
|
||||
|
||||
|
||||
class PaginatedSessionResponse(PaginatedResponse):
|
||||
"""
|
||||
Pydantic model for paginated session responses.
|
||||
|
||||
This model specializes the generic PaginatedResponse for SessionModel items.
|
||||
|
||||
Attributes:
|
||||
items: List of SessionModel items for the current page
|
||||
"""
|
||||
items: List[SessionModel]
|
||||
|
||||
|
||||
# Dependency to get the session repository
|
||||
def get_repository() -> SessionRepository:
|
||||
"""
|
||||
Get the SessionRepository instance.
|
||||
|
||||
This function is used as a FastAPI dependency and can be overridden
|
||||
in tests using dependency_overrides.
|
||||
|
||||
Returns:
|
||||
SessionRepository: The repository instance
|
||||
"""
|
||||
return get_session_repository()
|
||||
|
||||
|
||||
@router.get(
|
||||
"",
|
||||
response_model=PaginatedSessionResponse,
|
||||
summary="List sessions",
|
||||
description="Get a paginated list of sessions",
|
||||
)
|
||||
async def list_sessions(
|
||||
offset: int = Query(0, ge=0, description="Number of sessions to skip"),
|
||||
limit: int = Query(10, ge=1, le=100, description="Maximum number of sessions to return"),
|
||||
repo: SessionRepository = Depends(get_repository),
|
||||
) -> PaginatedSessionResponse:
|
||||
"""
|
||||
Get a paginated list of sessions.
|
||||
|
||||
Args:
|
||||
offset: Number of sessions to skip (default: 0)
|
||||
limit: Maximum number of sessions to return (default: 10)
|
||||
repo: SessionRepository dependency injection
|
||||
|
||||
Returns:
|
||||
PaginatedSessionResponse: Response with paginated sessions
|
||||
|
||||
Raises:
|
||||
HTTPException: With a 500 status code if there's a database error
|
||||
"""
|
||||
try:
|
||||
sessions, total = repo.get_all(offset=offset, limit=limit)
|
||||
return PaginatedSessionResponse(
|
||||
total=total,
|
||||
items=sessions,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
except peewee.DatabaseError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Database error: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{session_id}",
|
||||
response_model=SessionModel,
|
||||
summary="Get session",
|
||||
description="Get a specific session by ID",
|
||||
)
|
||||
async def get_session(
|
||||
session_id: int,
|
||||
repo: SessionRepository = Depends(get_repository),
|
||||
) -> SessionModel:
|
||||
"""
|
||||
Get a specific session by ID.
|
||||
|
||||
Args:
|
||||
session_id: The ID of the session to retrieve
|
||||
repo: SessionRepository dependency injection
|
||||
|
||||
Returns:
|
||||
SessionModel: The requested session
|
||||
|
||||
Raises:
|
||||
HTTPException: With a 404 status code if the session is not found
|
||||
HTTPException: With a 500 status code if there's a database error
|
||||
"""
|
||||
try:
|
||||
session = repo.get(session_id)
|
||||
if not session:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Session with ID {session_id} not found",
|
||||
)
|
||||
return session
|
||||
except peewee.DatabaseError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Database error: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"",
|
||||
response_model=SessionModel,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Create session",
|
||||
description="Create a new session",
|
||||
)
|
||||
async def create_session(
|
||||
request: Optional[CreateSessionRequest] = None,
|
||||
repo: SessionRepository = Depends(get_repository),
|
||||
) -> SessionModel:
|
||||
"""
|
||||
Create a new session.
|
||||
|
||||
Args:
|
||||
request: Optional request body with session metadata
|
||||
repo: SessionRepository dependency injection
|
||||
|
||||
Returns:
|
||||
SessionModel: The newly created session
|
||||
|
||||
Raises:
|
||||
HTTPException: With a 500 status code if there's a database error
|
||||
"""
|
||||
try:
|
||||
metadata = request.metadata if request else None
|
||||
return repo.create_session(metadata=metadata)
|
||||
except peewee.DatabaseError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Database error: {str(e)}",
|
||||
)
|
||||
|
|
@ -1,198 +0,0 @@
|
|||
"""API router for spawning an RA.Aid agent."""
|
||||
|
||||
import threading
|
||||
import logging
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ra_aid.database.repositories.session_repository import SessionRepository, get_session_repository
|
||||
from ra_aid.database.connection import DatabaseManager
|
||||
from ra_aid.database.repositories.session_repository import SessionRepositoryManager
|
||||
from ra_aid.database.repositories.key_fact_repository import KeyFactRepositoryManager
|
||||
from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepositoryManager
|
||||
from ra_aid.database.repositories.human_input_repository import HumanInputRepositoryManager
|
||||
from ra_aid.database.repositories.research_note_repository import ResearchNoteRepositoryManager
|
||||
from ra_aid.database.repositories.related_files_repository import RelatedFilesRepositoryManager
|
||||
from ra_aid.database.repositories.trajectory_repository import TrajectoryRepositoryManager
|
||||
from ra_aid.database.repositories.work_log_repository import WorkLogRepositoryManager
|
||||
from ra_aid.database.repositories.config_repository import ConfigRepositoryManager, get_config_repository
|
||||
from ra_aid.env_inv_context import EnvInvManager
|
||||
from ra_aid.env_inv import EnvDiscovery
|
||||
from ra_aid.llm import initialize_llm
|
||||
|
||||
# Create logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Create API router
|
||||
router = APIRouter(
|
||||
prefix="/v1/spawn-agent",
|
||||
tags=["agent"],
|
||||
responses={
|
||||
status.HTTP_422_UNPROCESSABLE_ENTITY: {"description": "Validation error"},
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR: {"description": "Agent spawn error"},
|
||||
},
|
||||
)
|
||||
|
||||
class SpawnAgentRequest(BaseModel):
|
||||
"""
|
||||
Pydantic model for agent spawn requests.
|
||||
|
||||
This model provides validation for spawning a new agent.
|
||||
|
||||
Attributes:
|
||||
message: The message or task for the agent to process
|
||||
research_only: Whether to use research-only mode (default: False)
|
||||
"""
|
||||
message: str = Field(
|
||||
description="The message or task for the agent to process"
|
||||
)
|
||||
research_only: bool = Field(
|
||||
default=False,
|
||||
description="Whether to use research-only mode"
|
||||
)
|
||||
|
||||
class SpawnAgentResponse(BaseModel):
|
||||
"""
|
||||
Pydantic model for agent spawn responses.
|
||||
|
||||
This model defines the response format for the spawn-agent endpoint.
|
||||
|
||||
Attributes:
|
||||
session_id: The ID of the created session
|
||||
"""
|
||||
session_id: str = Field(
|
||||
description="The ID of the created session"
|
||||
)
|
||||
|
||||
|
||||
def run_agent_thread(
|
||||
message: str,
|
||||
session_id: str,
|
||||
research_only: bool = False,
|
||||
):
|
||||
"""
|
||||
Run a research agent in a separate thread with proper repository initialization.
|
||||
|
||||
Args:
|
||||
message: The message or task for the agent to process
|
||||
session_id: The ID of the session to associate with this agent
|
||||
research_only: Whether to use research-only mode
|
||||
|
||||
Note:
|
||||
Values for expert_enabled and web_research_enabled are retrieved from the
|
||||
config repository, which stores the values set during server startup.
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Starting agent thread for session {session_id}")
|
||||
|
||||
# Initialize environment discovery
|
||||
env_discovery = EnvDiscovery()
|
||||
env_discovery.discover()
|
||||
env_data = env_discovery.format_markdown()
|
||||
|
||||
# Initialize empty config dictionary
|
||||
config = {}
|
||||
|
||||
# Initialize database connection and repositories
|
||||
with DatabaseManager() as db, \
|
||||
SessionRepositoryManager(db) as session_repo, \
|
||||
KeyFactRepositoryManager(db) as key_fact_repo, \
|
||||
KeySnippetRepositoryManager(db) as key_snippet_repo, \
|
||||
HumanInputRepositoryManager(db) as human_input_repo, \
|
||||
ResearchNoteRepositoryManager(db) as research_note_repo, \
|
||||
RelatedFilesRepositoryManager() as related_files_repo, \
|
||||
TrajectoryRepositoryManager(db) as trajectory_repo, \
|
||||
WorkLogRepositoryManager() as work_log_repo, \
|
||||
ConfigRepositoryManager(config) as config_repo, \
|
||||
EnvInvManager(env_data) as env_inv:
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from ra_aid.__main__ import run_research_agent
|
||||
|
||||
# Get configuration values from config repository
|
||||
provider = get_config_repository().get("provider", "anthropic")
|
||||
model_name = get_config_repository().get("model", "claude-3-7-sonnet-20250219")
|
||||
temperature = get_config_repository().get("temperature")
|
||||
|
||||
# Get expert_enabled and web_research_enabled from config repository
|
||||
expert_enabled = get_config_repository().get("expert_enabled", True)
|
||||
web_research_enabled = get_config_repository().get("web_research_enabled", False)
|
||||
|
||||
# Initialize model with provider and model name from config
|
||||
model = initialize_llm(provider, model_name, temperature=temperature)
|
||||
|
||||
# Run the research agent
|
||||
run_research_agent(
|
||||
base_task_or_query=message,
|
||||
model=model, # Use the initialized model from config
|
||||
expert_enabled=expert_enabled,
|
||||
research_only=research_only,
|
||||
hil=False, # No human-in-the-loop for API
|
||||
web_research_enabled=web_research_enabled,
|
||||
thread_id=session_id
|
||||
)
|
||||
|
||||
logger.info(f"Agent completed for session {session_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error in agent thread for session {session_id}: {str(e)}")
|
||||
|
||||
@router.post(
|
||||
"",
|
||||
response_model=SpawnAgentResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Spawn agent",
|
||||
description="Spawn a new RA.Aid agent to process a message or task",
|
||||
)
|
||||
async def spawn_agent(
|
||||
request: SpawnAgentRequest,
|
||||
repo: SessionRepository = Depends(get_session_repository),
|
||||
) -> SpawnAgentResponse:
|
||||
"""
|
||||
Spawn a new RA.Aid agent to process a message or task.
|
||||
|
||||
Args:
|
||||
request: Request body with message and agent configuration.
|
||||
repo: SessionRepository dependency injection
|
||||
|
||||
Returns:
|
||||
SpawnAgentResponse: Response with session ID
|
||||
|
||||
Raises:
|
||||
HTTPException: With a 500 status code if there's an error spawning the agent
|
||||
"""
|
||||
try:
|
||||
# Get configuration values from config repository
|
||||
config_repo = get_config_repository()
|
||||
expert_enabled = config_repo.get("expert_enabled", True)
|
||||
web_research_enabled = config_repo.get("web_research_enabled", False)
|
||||
|
||||
# Create a new session with config values (not request parameters)
|
||||
metadata = {
|
||||
"agent_type": "research-only" if request.research_only else "research",
|
||||
"expert_enabled": expert_enabled,
|
||||
"web_research_enabled": web_research_enabled,
|
||||
}
|
||||
session = repo.create_session(metadata=metadata)
|
||||
|
||||
# Start the agent thread
|
||||
thread = threading.Thread(
|
||||
target=run_agent_thread,
|
||||
args=(
|
||||
request.message,
|
||||
str(session.id),
|
||||
request.research_only,
|
||||
)
|
||||
)
|
||||
thread.daemon = True # Thread will terminate when main process exits
|
||||
thread.start()
|
||||
|
||||
# Return the session ID
|
||||
return SpawnAgentResponse(session_id=str(session.id))
|
||||
except Exception as e:
|
||||
logger.error(f"Error spawning agent: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error spawning agent: {str(e)}",
|
||||
)
|
||||
|
|
@ -116,6 +116,7 @@ def request_research(query: str) -> ResearchResult:
|
|||
research_only=True,
|
||||
hil=config.get("hil", False),
|
||||
console_message=query,
|
||||
config=config,
|
||||
)
|
||||
except AgentInterrupt:
|
||||
print()
|
||||
|
|
@ -311,6 +312,7 @@ def request_research_and_implementation(query: str) -> Dict[str, Any]:
|
|||
research_only=False,
|
||||
hil=config.get("hil", False),
|
||||
console_message=query,
|
||||
config=config,
|
||||
)
|
||||
|
||||
success = True
|
||||
|
|
|
|||
|
|
@ -20,6 +20,31 @@ from ra_aid.database.repositories.related_files_repository import get_related_fi
|
|||
console = Console()
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def get_aider_executable() -> str:
|
||||
"""Get the path to the aider executable in the same bin/Scripts directory as Python.
|
||||
|
||||
Returns:
|
||||
str: Full path to aider executable
|
||||
"""
|
||||
# Get directory containing Python executable
|
||||
bin_dir = Path(sys.executable).parent
|
||||
|
||||
# Check for platform-specific executable name
|
||||
if sys.platform == "win32":
|
||||
aider_exe = bin_dir / "aider.exe"
|
||||
else:
|
||||
aider_exe = bin_dir / "aider"
|
||||
|
||||
if not aider_exe.exists():
|
||||
raise RuntimeError(f"Could not find aider executable at {aider_exe}")
|
||||
|
||||
if not os.access(aider_exe, os.X_OK):
|
||||
raise RuntimeError(f"Aider executable at {aider_exe} is not executable")
|
||||
|
||||
return str(aider_exe)
|
||||
|
||||
|
||||
def _truncate_for_log(text: str, max_length: int = 300) -> str:
|
||||
"""Truncate text for logging, adding [truncated] if necessary."""
|
||||
if len(text) <= max_length:
|
||||
|
|
@ -54,8 +79,9 @@ def run_programming_task(
|
|||
files: Optional; if not provided, uses related_files
|
||||
"""
|
||||
# Build command
|
||||
aider_exe = get_aider_executable()
|
||||
command = [
|
||||
"aider",
|
||||
aider_exe,
|
||||
"--yes-always",
|
||||
"--no-git",
|
||||
"--no-auto-commits",
|
||||
|
|
@ -208,4 +234,4 @@ def parse_aider_flags(aider_flags: str) -> List[str]:
|
|||
|
||||
|
||||
# Export the functions
|
||||
__all__ = ["run_programming_task"]
|
||||
__all__ = ["run_programming_task", "get_aider_executable"]
|
||||
|
|
@ -1,5 +1,3 @@
|
|||
import platform
|
||||
import shutil
|
||||
from typing import Dict, Union
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
|
@ -17,16 +15,6 @@ from ra_aid.database.repositories.human_input_repository import get_human_input_
|
|||
|
||||
console = Console()
|
||||
|
||||
def _detect_shell():
|
||||
"""Detect the appropriate shell to use based on the environment."""
|
||||
if platform.system().lower().startswith("win"):
|
||||
# Check if pwsh is available, otherwise fall back to powershell
|
||||
if shutil.which("pwsh"):
|
||||
return ["pwsh", "-c"]
|
||||
else:
|
||||
return ["powershell", "-c"]
|
||||
else:
|
||||
return ["/bin/bash", "-c"]
|
||||
|
||||
def _truncate_for_log(text: str, max_length: int = 300) -> str:
|
||||
"""Truncate text for logging, adding [truncated] if necessary."""
|
||||
|
|
@ -110,9 +98,8 @@ def run_shell_command(
|
|||
|
||||
try:
|
||||
print()
|
||||
shell_cmd = _detect_shell()
|
||||
output, return_code = run_interactive_command(
|
||||
shell_cmd + [command],
|
||||
["/bin/bash", "-c", command],
|
||||
expected_runtime_seconds=timeout,
|
||||
)
|
||||
print()
|
||||
|
|
@ -144,4 +131,4 @@ def run_shell_command(
|
|||
)
|
||||
|
||||
console.print(Panel(str(e), title="❌ Error", border_style="red"))
|
||||
return {"output": str(e), "return_code": 1, "success": False}
|
||||
return {"output": str(e), "return_code": 1, "success": False}
|
||||
|
|
@ -2,4 +2,4 @@
|
|||
|
||||
from .server import run_server
|
||||
|
||||
__all__ = ["run_server"]
|
||||
__all__ = ["run_server"]
|
||||
|
|
@ -33,14 +33,7 @@ from fastapi.responses import HTMLResponse
|
|||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.templating import Jinja2Templates
|
||||
|
||||
from ra_aid.server.api_v1_sessions import router as sessions_router
|
||||
from ra_aid.server.api_v1_spawn_agent import router as spawn_agent_router
|
||||
|
||||
app = FastAPI(
|
||||
title="RA.Aid API",
|
||||
description="API for RA.Aid - AI Programming Assistant",
|
||||
version="1.0.0",
|
||||
)
|
||||
app = FastAPI()
|
||||
|
||||
# Add CORS middleware
|
||||
app.add_middleware(
|
||||
|
|
@ -51,10 +44,6 @@ app.add_middleware(
|
|||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Include API routers
|
||||
app.include_router(sessions_router)
|
||||
app.include_router(spawn_agent_router)
|
||||
|
||||
# Setup templates and static files directories
|
||||
CURRENT_DIR = Path(__file__).parent
|
||||
templates = Jinja2Templates(directory=CURRENT_DIR)
|
||||
|
|
@ -162,7 +151,7 @@ def run_ra_aid(message_content, output_queue):
|
|||
async def get_root(request: Request):
|
||||
"""Serve the index.html file with port parameter."""
|
||||
return templates.TemplateResponse(
|
||||
"index.html", {"request": request, "server_port": request.url.port or 1818}
|
||||
"index.html", {"request": request, "server_port": request.url.port or 8080}
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -254,6 +243,24 @@ async def get_config(request: Request):
|
|||
return {"host": request.client.host, "port": request.scope.get("server")[1]}
|
||||
|
||||
|
||||
def run_server(host: str = "0.0.0.0", port: int = 1818):
|
||||
def run_server(host: str = "0.0.0.0", port: int = 8080):
|
||||
"""Run the FastAPI server."""
|
||||
uvicorn.run(app, host=host, port=port)
|
||||
uvicorn.run(app, host=host, port=port)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="RA.Aid Web Interface Server")
|
||||
parser.add_argument(
|
||||
"--port", type=int, default=8080, help="Port to listen on (default: 8080)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default="0.0.0.0",
|
||||
help="Host to listen on (default: 0.0.0.0)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
run_server(host=args.host, port=args.port)
|
||||
|
|
@ -1,198 +0,0 @@
|
|||
"""
|
||||
Tests for the human input repository.
|
||||
|
||||
This module provides tests for the HumanInputRepository class,
|
||||
ensuring it correctly interfaces with the database and returns
|
||||
appropriate Pydantic models.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from typing import List, Dict, Any
|
||||
|
||||
import pytest
|
||||
from peewee import SqliteDatabase
|
||||
|
||||
from ra_aid.database.models import HumanInput, Session, database_proxy
|
||||
from ra_aid.database.pydantic_models import HumanInputModel, SessionModel
|
||||
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||
from ra_aid.database.repositories.session_repository import SessionRepository
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_db():
|
||||
"""Fixture for creating a test database."""
|
||||
# Create an in-memory SQLite database for testing
|
||||
test_db = SqliteDatabase(':memory:')
|
||||
|
||||
# Register the models with the test database
|
||||
with test_db.bind_ctx([HumanInput, Session]):
|
||||
# Create the tables
|
||||
test_db.create_tables([HumanInput, Session])
|
||||
|
||||
# Return the test database for use in the tests
|
||||
yield test_db
|
||||
|
||||
# Drop the tables after the tests
|
||||
test_db.drop_tables([HumanInput, Session])
|
||||
|
||||
|
||||
class TestHumanInputRepository(unittest.TestCase):
|
||||
"""Test case for the HumanInputRepository class."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up the test case with a test database and repositories."""
|
||||
# Create an in-memory database for testing
|
||||
self.db = SqliteDatabase(':memory:')
|
||||
|
||||
# Register the models with the test database
|
||||
self.models = [HumanInput, Session]
|
||||
self.db.bind(self.models)
|
||||
|
||||
# Create the tables
|
||||
self.db.create_tables(self.models)
|
||||
|
||||
# Create repository instances for testing
|
||||
self.repository = HumanInputRepository(self.db)
|
||||
self.session_repository = SessionRepository(self.db)
|
||||
|
||||
# Bind the test database to the repository model
|
||||
database_proxy.initialize(self.db)
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up after the test case."""
|
||||
# Close the database connection
|
||||
self.db.close()
|
||||
|
||||
def test_create(self):
|
||||
"""Test creating a human input record."""
|
||||
# Create a session first
|
||||
session_model = self.session_repository.create_session()
|
||||
|
||||
# Create a human input
|
||||
content = "Test human input"
|
||||
source = "cli"
|
||||
human_input = self.repository.create(content=content, source=source)
|
||||
|
||||
# Verify the human input was created
|
||||
self.assertIsInstance(human_input, HumanInputModel)
|
||||
self.assertEqual(human_input.content, content)
|
||||
self.assertEqual(human_input.source, source)
|
||||
|
||||
def test_get(self):
|
||||
"""Test retrieving a human input record by ID."""
|
||||
# Create a session first
|
||||
session_model = self.session_repository.create_session()
|
||||
|
||||
# Create a human input
|
||||
content = "Test human input"
|
||||
source = "chat"
|
||||
created_input = self.repository.create(content=content, source=source)
|
||||
|
||||
# Get the human input by ID
|
||||
retrieved_input = self.repository.get(created_input.id)
|
||||
|
||||
# Verify the human input was retrieved correctly
|
||||
self.assertIsInstance(retrieved_input, HumanInputModel)
|
||||
self.assertEqual(retrieved_input.id, created_input.id)
|
||||
self.assertEqual(retrieved_input.content, content)
|
||||
self.assertEqual(retrieved_input.source, source)
|
||||
|
||||
def test_update(self):
|
||||
"""Test updating a human input record."""
|
||||
# Create a session first
|
||||
session_model = self.session_repository.create_session()
|
||||
|
||||
# Create a human input
|
||||
content = "Original content"
|
||||
source = "cli"
|
||||
created_input = self.repository.create(content=content, source=source)
|
||||
|
||||
# Update the human input
|
||||
new_content = "Updated content"
|
||||
updated_input = self.repository.update(created_input.id, content=new_content)
|
||||
|
||||
# Verify the human input was updated correctly
|
||||
self.assertIsInstance(updated_input, HumanInputModel)
|
||||
self.assertEqual(updated_input.id, created_input.id)
|
||||
self.assertEqual(updated_input.content, new_content)
|
||||
self.assertEqual(updated_input.source, source)
|
||||
|
||||
def test_get_all(self):
|
||||
"""Test retrieving all human input records."""
|
||||
# Create a session first
|
||||
session_model = self.session_repository.create_session()
|
||||
|
||||
# Create multiple human inputs
|
||||
self.repository.create(content="Input 1", source="cli")
|
||||
self.repository.create(content="Input 2", source="chat")
|
||||
self.repository.create(content="Input 3", source="hil")
|
||||
|
||||
# Get all human inputs
|
||||
all_inputs = self.repository.get_all()
|
||||
|
||||
# Verify all human inputs were retrieved
|
||||
self.assertEqual(len(all_inputs), 3)
|
||||
self.assertIsInstance(all_inputs[0], HumanInputModel)
|
||||
|
||||
# Verify the inputs are ordered by created_at in descending order
|
||||
self.assertEqual(all_inputs[0].content, "Input 3")
|
||||
self.assertEqual(all_inputs[1].content, "Input 2")
|
||||
self.assertEqual(all_inputs[2].content, "Input 1")
|
||||
|
||||
def test_get_recent(self):
|
||||
"""Test retrieving the most recent human input records."""
|
||||
# Create a session first
|
||||
session_model = self.session_repository.create_session()
|
||||
|
||||
# Create multiple human inputs
|
||||
self.repository.create(content="Input 1", source="cli")
|
||||
self.repository.create(content="Input 2", source="chat")
|
||||
self.repository.create(content="Input 3", source="hil")
|
||||
self.repository.create(content="Input 4", source="cli")
|
||||
self.repository.create(content="Input 5", source="chat")
|
||||
|
||||
# Get recent human inputs with a limit of 3
|
||||
recent_inputs = self.repository.get_recent(limit=3)
|
||||
|
||||
# Verify only the 3 most recent inputs were retrieved
|
||||
self.assertEqual(len(recent_inputs), 3)
|
||||
self.assertIsInstance(recent_inputs[0], HumanInputModel)
|
||||
self.assertEqual(recent_inputs[0].content, "Input 5")
|
||||
self.assertEqual(recent_inputs[1].content, "Input 4")
|
||||
self.assertEqual(recent_inputs[2].content, "Input 3")
|
||||
|
||||
def test_get_by_source(self):
|
||||
"""Test retrieving human input records by source."""
|
||||
# Create a session first
|
||||
session_model = self.session_repository.create_session()
|
||||
|
||||
# Create human inputs with different sources
|
||||
self.repository.create(content="CLI Input 1", source="cli")
|
||||
self.repository.create(content="Chat Input 1", source="chat")
|
||||
self.repository.create(content="HIL Input", source="hil")
|
||||
self.repository.create(content="CLI Input 2", source="cli")
|
||||
self.repository.create(content="Chat Input 2", source="chat")
|
||||
|
||||
# Get human inputs for the 'cli' source
|
||||
cli_inputs = self.repository.get_by_source("cli")
|
||||
|
||||
# Verify only cli inputs were retrieved
|
||||
self.assertEqual(len(cli_inputs), 2)
|
||||
self.assertIsInstance(cli_inputs[0], HumanInputModel)
|
||||
self.assertEqual(cli_inputs[0].content, "CLI Input 2")
|
||||
self.assertEqual(cli_inputs[1].content, "CLI Input 1")
|
||||
|
||||
def test_get_most_recent_id(self):
|
||||
"""Test retrieving the ID of the most recent human input record."""
|
||||
# Create a session first
|
||||
session_model = self.session_repository.create_session()
|
||||
|
||||
# Create multiple human inputs
|
||||
self.repository.create(content="Input 1", source="cli")
|
||||
input2 = self.repository.create(content="Input 2", source="chat")
|
||||
|
||||
# Get the most recent ID
|
||||
most_recent_id = self.repository.get_most_recent_id()
|
||||
|
||||
# Verify the correct ID was retrieved
|
||||
self.assertEqual(most_recent_id, input2.id)
|
||||
|
|
@ -15,7 +15,6 @@ from ra_aid.database.repositories.key_fact_repository import (
|
|||
get_key_fact_repository,
|
||||
key_fact_repo_var
|
||||
)
|
||||
from ra_aid.database.pydantic_models import KeyFactModel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -88,8 +87,7 @@ def test_create_key_fact(setup_db):
|
|||
content = "Test key fact"
|
||||
fact = repo.create(content)
|
||||
|
||||
# Verify the fact was created correctly and is a KeyFactModel
|
||||
assert isinstance(fact, KeyFactModel)
|
||||
# Verify the fact was created correctly
|
||||
assert fact.id is not None
|
||||
assert fact.content == content
|
||||
|
||||
|
|
@ -110,8 +108,7 @@ def test_get_key_fact(setup_db):
|
|||
# Retrieve the fact by ID
|
||||
retrieved_fact = repo.get(fact.id)
|
||||
|
||||
# Verify the retrieved fact matches the original and is a KeyFactModel
|
||||
assert isinstance(retrieved_fact, KeyFactModel)
|
||||
# Verify the retrieved fact matches the original
|
||||
assert retrieved_fact is not None
|
||||
assert retrieved_fact.id == fact.id
|
||||
assert retrieved_fact.content == content
|
||||
|
|
@ -134,8 +131,7 @@ def test_update_key_fact(setup_db):
|
|||
new_content = "Updated content"
|
||||
updated_fact = repo.update(fact.id, new_content)
|
||||
|
||||
# Verify the fact was updated correctly and is a KeyFactModel
|
||||
assert isinstance(updated_fact, KeyFactModel)
|
||||
# Verify the fact was updated correctly
|
||||
assert updated_fact is not None
|
||||
assert updated_fact.id == fact.id
|
||||
assert updated_fact.content == new_content
|
||||
|
|
@ -188,10 +184,8 @@ def test_get_all_key_facts(setup_db):
|
|||
# Retrieve all facts
|
||||
all_facts = repo.get_all()
|
||||
|
||||
# Verify we got the correct number of facts and they are KeyFactModel instances
|
||||
# Verify we got the correct number of facts
|
||||
assert len(all_facts) == len(contents)
|
||||
for fact in all_facts:
|
||||
assert isinstance(fact, KeyFactModel)
|
||||
|
||||
# Verify the content of each fact
|
||||
fact_contents = [fact.content for fact in all_facts]
|
||||
|
|
@ -243,7 +237,6 @@ def test_key_fact_repository_manager(setup_db, cleanup_repo):
|
|||
# Verify we can use the repository
|
||||
content = "Test fact via context manager"
|
||||
fact = repo.create(content)
|
||||
assert isinstance(fact, KeyFactModel)
|
||||
assert fact.id is not None
|
||||
assert fact.content == content
|
||||
|
||||
|
|
@ -265,26 +258,4 @@ def test_get_key_fact_repository_when_not_set(cleanup_repo):
|
|||
get_key_fact_repository()
|
||||
|
||||
# Verify the correct error message
|
||||
assert "No KeyFactRepository available" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_to_model_method(setup_db):
|
||||
"""Test the _to_model method converts KeyFact to KeyFactModel correctly."""
|
||||
# Set up repository
|
||||
repo = KeyFactRepository(db=setup_db)
|
||||
|
||||
# Create a Peewee KeyFact directly
|
||||
peewee_fact = KeyFact.create(content="Test fact for conversion")
|
||||
|
||||
# Convert to Pydantic model
|
||||
pydantic_fact = repo._to_model(peewee_fact)
|
||||
|
||||
# Verify conversion was correct
|
||||
assert isinstance(pydantic_fact, KeyFactModel)
|
||||
assert pydantic_fact.id == peewee_fact.id
|
||||
assert pydantic_fact.content == peewee_fact.content
|
||||
assert pydantic_fact.created_at == peewee_fact.created_at
|
||||
assert pydantic_fact.updated_at == peewee_fact.updated_at
|
||||
|
||||
# Test with None input
|
||||
assert repo._to_model(None) is None
|
||||
assert "No KeyFactRepository available" in str(excinfo.value)
|
||||
|
|
@ -6,7 +6,6 @@ import pytest
|
|||
|
||||
from ra_aid.database.connection import DatabaseManager, db_var
|
||||
from ra_aid.database.models import KeySnippet
|
||||
from ra_aid.database.pydantic_models import KeySnippetModel
|
||||
from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository
|
||||
|
||||
|
||||
|
|
@ -80,9 +79,6 @@ def test_create_key_snippet(setup_db):
|
|||
assert key_snippet.snippet == snippet
|
||||
assert key_snippet.description == description
|
||||
|
||||
# Verify the return type is a Pydantic model
|
||||
assert isinstance(key_snippet, KeySnippetModel)
|
||||
|
||||
# Verify we can retrieve it from the database
|
||||
snippet_from_db = KeySnippet.get_by_id(key_snippet.id)
|
||||
assert snippet_from_db.filepath == filepath
|
||||
|
|
@ -120,9 +116,6 @@ def test_get_key_snippet(setup_db):
|
|||
assert retrieved_snippet.snippet == snippet
|
||||
assert retrieved_snippet.description == description
|
||||
|
||||
# Verify the return type is a Pydantic model
|
||||
assert isinstance(retrieved_snippet, KeySnippetModel)
|
||||
|
||||
# Try to retrieve a non-existent snippet
|
||||
non_existent_snippet = repo.get(999)
|
||||
assert non_existent_snippet is None
|
||||
|
|
@ -168,9 +161,6 @@ def test_update_key_snippet(setup_db):
|
|||
assert updated_snippet.snippet == new_snippet
|
||||
assert updated_snippet.description == new_description
|
||||
|
||||
# Verify the return type is a Pydantic model
|
||||
assert isinstance(updated_snippet, KeySnippetModel)
|
||||
|
||||
# Verify we can retrieve the updated content from the database
|
||||
snippet_from_db = KeySnippet.get_by_id(key_snippet.id)
|
||||
assert snippet_from_db.filepath == new_filepath
|
||||
|
|
@ -260,9 +250,6 @@ def test_get_all_key_snippets(setup_db):
|
|||
# Verify we got the correct number of snippets
|
||||
assert len(all_snippets) == len(snippets_data)
|
||||
|
||||
# Verify all returned snippets are Pydantic models
|
||||
assert all(isinstance(snippet, KeySnippetModel) for snippet in all_snippets)
|
||||
|
||||
# Verify the content of each snippet
|
||||
for i, snippet in enumerate(all_snippets):
|
||||
assert snippet.filepath == snippets_data[i]["filepath"]
|
||||
|
|
@ -314,31 +301,4 @@ def test_get_snippets_dict(setup_db):
|
|||
assert snippets_dict[snippet.id]["filepath"] == snippets_data[i]["filepath"]
|
||||
assert snippets_dict[snippet.id]["line_number"] == snippets_data[i]["line_number"]
|
||||
assert snippets_dict[snippet.id]["snippet"] == snippets_data[i]["snippet"]
|
||||
assert snippets_dict[snippet.id]["description"] == snippets_data[i]["description"]
|
||||
|
||||
|
||||
def test_to_model_conversion(setup_db):
|
||||
"""Test conversion from Peewee model to Pydantic model."""
|
||||
repo = KeySnippetRepository(db=setup_db)
|
||||
|
||||
# Create a snippet in the database using Peewee directly
|
||||
peewee_snippet = KeySnippet.create(
|
||||
filepath="conversion_test.py",
|
||||
line_number=100,
|
||||
snippet="def conversion_test():",
|
||||
description="Test model conversion"
|
||||
)
|
||||
|
||||
# Use the _to_model method to convert it
|
||||
pydantic_snippet = repo._to_model(peewee_snippet)
|
||||
|
||||
# Verify the conversion was successful
|
||||
assert isinstance(pydantic_snippet, KeySnippetModel)
|
||||
assert pydantic_snippet.id == peewee_snippet.id
|
||||
assert pydantic_snippet.filepath == peewee_snippet.filepath
|
||||
assert pydantic_snippet.line_number == peewee_snippet.line_number
|
||||
assert pydantic_snippet.snippet == peewee_snippet.snippet
|
||||
assert pydantic_snippet.description == peewee_snippet.description
|
||||
|
||||
# Test conversion of None
|
||||
assert repo._to_model(None) is None
|
||||
assert snippets_dict[snippet.id]["description"] == snippets_data[i]["description"]
|
||||
|
|
@ -1,110 +0,0 @@
|
|||
"""
|
||||
Tests for the Pydantic models in ra_aid.database.pydantic_models
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import json
|
||||
import pytest
|
||||
|
||||
from ra_aid.database.models import Session
|
||||
from ra_aid.database.pydantic_models import SessionModel
|
||||
|
||||
|
||||
class TestSessionModel:
|
||||
"""Tests for the SessionModel Pydantic model"""
|
||||
|
||||
def test_from_peewee_model(self):
|
||||
"""Test conversion from a Peewee model instance"""
|
||||
# Create a Peewee Session instance
|
||||
now = datetime.datetime.now()
|
||||
metadata = {"os": "Linux", "cpu_cores": 8, "memory_gb": 16}
|
||||
session = Session(
|
||||
id=1,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
start_time=now,
|
||||
command_line="ra-aid run",
|
||||
program_version="1.0.0",
|
||||
machine_info=json.dumps(metadata)
|
||||
)
|
||||
|
||||
# Convert to Pydantic model
|
||||
session_model = SessionModel.model_validate(session, from_attributes=True)
|
||||
|
||||
# Verify fields
|
||||
assert session_model.id == 1
|
||||
assert session_model.created_at == now
|
||||
assert session_model.updated_at == now
|
||||
assert session_model.start_time == now
|
||||
assert session_model.command_line == "ra-aid run"
|
||||
assert session_model.program_version == "1.0.0"
|
||||
assert session_model.machine_info == metadata
|
||||
|
||||
def test_with_dict_machine_info(self):
|
||||
"""Test creating a model with a dict for machine_info"""
|
||||
# Create directly with a dict for machine_info
|
||||
now = datetime.datetime.now()
|
||||
metadata = {"os": "Windows", "cpu_cores": 4, "memory_gb": 8}
|
||||
|
||||
session_model = SessionModel(
|
||||
id=2,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
start_time=now,
|
||||
command_line="ra-aid --debug",
|
||||
program_version="1.0.1",
|
||||
machine_info=metadata
|
||||
)
|
||||
|
||||
# Verify fields
|
||||
assert session_model.id == 2
|
||||
assert session_model.machine_info == metadata
|
||||
|
||||
def test_with_none_machine_info(self):
|
||||
"""Test creating a model with None for machine_info"""
|
||||
now = datetime.datetime.now()
|
||||
|
||||
session_model = SessionModel(
|
||||
id=3,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
start_time=now,
|
||||
command_line="ra-aid",
|
||||
program_version="1.0.0",
|
||||
machine_info=None
|
||||
)
|
||||
|
||||
assert session_model.id == 3
|
||||
assert session_model.machine_info is None
|
||||
|
||||
def test_invalid_json_machine_info(self):
|
||||
"""Test error handling for invalid JSON in machine_info"""
|
||||
now = datetime.datetime.now()
|
||||
|
||||
# Invalid JSON string should raise ValueError
|
||||
with pytest.raises(ValueError):
|
||||
SessionModel(
|
||||
id=4,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
start_time=now,
|
||||
command_line="ra-aid",
|
||||
program_version="1.0.0",
|
||||
machine_info="{invalid json}"
|
||||
)
|
||||
|
||||
def test_unexpected_type_machine_info(self):
|
||||
"""Test error handling for unexpected type in machine_info"""
|
||||
now = datetime.datetime.now()
|
||||
|
||||
# Integer type should raise ValueError
|
||||
with pytest.raises(ValueError):
|
||||
SessionModel(
|
||||
id=5,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
start_time=now,
|
||||
command_line="ra-aid",
|
||||
program_version="1.0.0",
|
||||
machine_info=123 # Not a dict or string
|
||||
)
|
||||
|
|
@ -9,7 +9,6 @@ import peewee
|
|||
|
||||
from ra_aid.database.connection import DatabaseManager, db_var
|
||||
from ra_aid.database.models import ResearchNote, BaseModel
|
||||
from ra_aid.database.pydantic_models import ResearchNoteModel
|
||||
from ra_aid.database.repositories.research_note_repository import (
|
||||
ResearchNoteRepository,
|
||||
ResearchNoteRepositoryManager,
|
||||
|
|
@ -91,12 +90,10 @@ def test_create_research_note(setup_db):
|
|||
# Verify the note was created correctly
|
||||
assert note.id is not None
|
||||
assert note.content == content
|
||||
assert isinstance(note, ResearchNoteModel)
|
||||
|
||||
# Verify we can retrieve it from the database using the repository
|
||||
note_from_db = repo.get(note.id)
|
||||
assert note_from_db.content == content
|
||||
assert isinstance(note_from_db, ResearchNoteModel)
|
||||
|
||||
|
||||
def test_get_research_note(setup_db):
|
||||
|
|
@ -115,7 +112,6 @@ def test_get_research_note(setup_db):
|
|||
assert retrieved_note is not None
|
||||
assert retrieved_note.id == note.id
|
||||
assert retrieved_note.content == content
|
||||
assert isinstance(retrieved_note, ResearchNoteModel)
|
||||
|
||||
# Try to retrieve a non-existent note
|
||||
non_existent_note = repo.get(999)
|
||||
|
|
@ -139,12 +135,10 @@ def test_update_research_note(setup_db):
|
|||
assert updated_note is not None
|
||||
assert updated_note.id == note.id
|
||||
assert updated_note.content == new_content
|
||||
assert isinstance(updated_note, ResearchNoteModel)
|
||||
|
||||
# Verify we can retrieve the updated content from the database using the repository
|
||||
note_from_db = repo.get(note.id)
|
||||
assert note_from_db.content == new_content
|
||||
assert isinstance(note_from_db, ResearchNoteModel)
|
||||
|
||||
# Try to update a non-existent note
|
||||
non_existent_update = repo.update(999, "This shouldn't work")
|
||||
|
|
@ -193,11 +187,8 @@ def test_get_all_research_notes(setup_db):
|
|||
# Verify we got the correct number of notes
|
||||
assert len(all_notes) == len(contents)
|
||||
|
||||
# Verify the content of each note and that they are Pydantic models
|
||||
# Verify the content of each note
|
||||
note_contents = [note.content for note in all_notes]
|
||||
for note in all_notes:
|
||||
assert isinstance(note, ResearchNoteModel)
|
||||
|
||||
for content in contents:
|
||||
assert content in note_contents
|
||||
|
||||
|
|
@ -248,7 +239,6 @@ def test_research_note_repository_manager(setup_db, cleanup_repo):
|
|||
note = repo.create(content)
|
||||
assert note.id is not None
|
||||
assert note.content == content
|
||||
assert isinstance(note, ResearchNoteModel)
|
||||
|
||||
# Verify we can get the repository using get_research_note_repository
|
||||
repo_from_var = get_research_note_repository()
|
||||
|
|
@ -268,26 +258,4 @@ def test_get_research_note_repository_when_not_set(cleanup_repo):
|
|||
get_research_note_repository()
|
||||
|
||||
# Verify the correct error message
|
||||
assert "No ResearchNoteRepository available" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_to_model_method(setup_db):
|
||||
"""Test the _to_model method converts Peewee models to Pydantic models correctly."""
|
||||
# Set up repository
|
||||
repo = ResearchNoteRepository(db=setup_db)
|
||||
|
||||
# Create a Peewee ResearchNote directly
|
||||
peewee_note = ResearchNote.create(content="Test note for conversion")
|
||||
|
||||
# Convert it using _to_model
|
||||
pydantic_note = repo._to_model(peewee_note)
|
||||
|
||||
# Verify the conversion
|
||||
assert isinstance(pydantic_note, ResearchNoteModel)
|
||||
assert pydantic_note.id == peewee_note.id
|
||||
assert pydantic_note.content == peewee_note.content
|
||||
assert pydantic_note.created_at == peewee_note.created_at
|
||||
assert pydantic_note.updated_at == peewee_note.updated_at
|
||||
|
||||
# Test with None
|
||||
assert repo._to_model(None) is None
|
||||
assert "No ResearchNoteRepository available" in str(excinfo.value)
|
||||
|
|
@ -1,364 +0,0 @@
|
|||
"""
|
||||
Tests for the SessionRepository class.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import datetime
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
import peewee
|
||||
|
||||
from ra_aid.database.connection import DatabaseManager, db_var
|
||||
from ra_aid.database.models import Session, BaseModel
|
||||
from ra_aid.database.repositories.session_repository import (
|
||||
SessionRepository,
|
||||
SessionRepositoryManager,
|
||||
get_session_repository,
|
||||
session_repo_var
|
||||
)
|
||||
from ra_aid.database.pydantic_models import SessionModel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cleanup_db():
|
||||
"""Reset the database contextvar and connection state after each test."""
|
||||
# Reset before the test
|
||||
db = db_var.get()
|
||||
if db is not None:
|
||||
try:
|
||||
if not db.is_closed():
|
||||
db.close()
|
||||
except Exception:
|
||||
# Ignore errors when closing the database
|
||||
pass
|
||||
db_var.set(None)
|
||||
|
||||
# Run the test
|
||||
yield
|
||||
|
||||
# Reset after the test
|
||||
db = db_var.get()
|
||||
if db is not None:
|
||||
try:
|
||||
if not db.is_closed():
|
||||
db.close()
|
||||
except Exception:
|
||||
# Ignore errors when closing the database
|
||||
pass
|
||||
db_var.set(None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cleanup_repo():
|
||||
"""Reset the repository contextvar after each test."""
|
||||
# Reset before the test
|
||||
session_repo_var.set(None)
|
||||
|
||||
# Run the test
|
||||
yield
|
||||
|
||||
# Reset after the test
|
||||
session_repo_var.set(None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_db(cleanup_db):
|
||||
"""Set up an in-memory database with the Session table and patch the BaseModel.Meta.database."""
|
||||
# Initialize an in-memory database connection
|
||||
with DatabaseManager(in_memory=True) as db:
|
||||
# Patch the BaseModel.Meta.database to use our in-memory database
|
||||
with patch.object(BaseModel._meta, 'database', db):
|
||||
# Create the Session table
|
||||
with db.atomic():
|
||||
db.create_tables([Session], safe=True)
|
||||
|
||||
yield db
|
||||
|
||||
# Clean up
|
||||
with db.atomic():
|
||||
Session.drop_table(safe=True)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_metadata():
|
||||
"""Return test metadata for sessions."""
|
||||
return {
|
||||
"os": "Test OS",
|
||||
"version": "1.0",
|
||||
"cpu_cores": 4,
|
||||
"memory_gb": 16,
|
||||
"additional_info": {
|
||||
"gpu": "Test GPU",
|
||||
"display_resolution": "1920x1080"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_session(setup_db, test_metadata):
|
||||
"""Create a sample session in the database."""
|
||||
now = datetime.datetime.now()
|
||||
return Session.create(
|
||||
start_time=now,
|
||||
command_line="ra-aid test",
|
||||
program_version="1.0.0",
|
||||
machine_info=json.dumps(test_metadata)
|
||||
)
|
||||
|
||||
|
||||
def test_create_session_with_metadata(setup_db, test_metadata):
|
||||
"""Test creating a session with metadata."""
|
||||
# Set up repository
|
||||
repo = SessionRepository(db=setup_db)
|
||||
|
||||
# Create a session with metadata
|
||||
session = repo.create_session(metadata=test_metadata)
|
||||
|
||||
# Verify type is SessionModel, not Session
|
||||
assert isinstance(session, SessionModel)
|
||||
|
||||
# Verify the session was created correctly
|
||||
assert session.id is not None
|
||||
assert session.command_line is not None
|
||||
assert session.program_version is not None
|
||||
|
||||
# Verify machine_info is a dict, not a JSON string
|
||||
assert isinstance(session.machine_info, dict)
|
||||
assert session.machine_info == test_metadata
|
||||
|
||||
# Verify the dictionary structure is preserved
|
||||
assert "additional_info" in session.machine_info
|
||||
assert session.machine_info["additional_info"]["gpu"] == "Test GPU"
|
||||
|
||||
|
||||
def test_create_session_without_metadata(setup_db):
|
||||
"""Test creating a session without metadata."""
|
||||
# Set up repository
|
||||
repo = SessionRepository(db=setup_db)
|
||||
|
||||
# Create a session without metadata
|
||||
session = repo.create_session()
|
||||
|
||||
# Verify type is SessionModel, not Session
|
||||
assert isinstance(session, SessionModel)
|
||||
|
||||
# Verify the session was created correctly
|
||||
assert session.id is not None
|
||||
assert session.command_line is not None
|
||||
assert session.program_version is not None
|
||||
|
||||
# Verify machine_info is None
|
||||
assert session.machine_info is None
|
||||
|
||||
|
||||
def test_get_current_session(setup_db, sample_session):
|
||||
"""Test retrieving the current session."""
|
||||
# Set up repository
|
||||
repo = SessionRepository(db=setup_db)
|
||||
|
||||
# Set the current session
|
||||
repo.current_session = sample_session
|
||||
|
||||
# Get the current session
|
||||
current_session = repo.get_current_session()
|
||||
|
||||
# Verify type is SessionModel, not Session
|
||||
assert isinstance(current_session, SessionModel)
|
||||
|
||||
# Verify the retrieved session matches the original
|
||||
assert current_session.id == sample_session.id
|
||||
assert current_session.command_line == sample_session.command_line
|
||||
assert current_session.program_version == sample_session.program_version
|
||||
|
||||
# Verify machine_info is a dict, not a JSON string
|
||||
assert isinstance(current_session.machine_info, dict)
|
||||
|
||||
|
||||
def test_get_current_session_from_db(setup_db, sample_session):
|
||||
"""Test retrieving the current session from the database when no current session is set."""
|
||||
# Set up repository
|
||||
repo = SessionRepository(db=setup_db)
|
||||
|
||||
# Get the current session (should retrieve the most recent from DB)
|
||||
current_session = repo.get_current_session()
|
||||
|
||||
# Verify type is SessionModel, not Session
|
||||
assert isinstance(current_session, SessionModel)
|
||||
|
||||
# Verify the retrieved session matches the sample session
|
||||
assert current_session.id == sample_session.id
|
||||
|
||||
# Verify machine_info is a dict, not a JSON string
|
||||
assert isinstance(current_session.machine_info, dict)
|
||||
|
||||
|
||||
def test_get_by_id(setup_db, sample_session):
|
||||
"""Test retrieving a session by ID."""
|
||||
# Set up repository
|
||||
repo = SessionRepository(db=setup_db)
|
||||
|
||||
# Get the session by ID
|
||||
session = repo.get(sample_session.id)
|
||||
|
||||
# Verify type is SessionModel, not Session
|
||||
assert isinstance(session, SessionModel)
|
||||
|
||||
# Verify the retrieved session matches the original
|
||||
assert session.id == sample_session.id
|
||||
assert session.command_line == sample_session.command_line
|
||||
assert session.program_version == sample_session.program_version
|
||||
|
||||
# Verify machine_info is a dict, not a JSON string
|
||||
assert isinstance(session.machine_info, dict)
|
||||
|
||||
# Verify getting a non-existent session returns None
|
||||
non_existent_session = repo.get(999)
|
||||
assert non_existent_session is None
|
||||
|
||||
|
||||
def test_get_all(setup_db):
|
||||
"""Test retrieving all sessions."""
|
||||
# Set up repository
|
||||
repo = SessionRepository(db=setup_db)
|
||||
|
||||
# Create multiple sessions
|
||||
metadata1 = {"os": "Linux", "cpu_cores": 8}
|
||||
metadata2 = {"os": "Windows", "cpu_cores": 4}
|
||||
metadata3 = {"os": "macOS", "cpu_cores": 10}
|
||||
|
||||
repo.create_session(metadata=metadata1)
|
||||
repo.create_session(metadata=metadata2)
|
||||
repo.create_session(metadata=metadata3)
|
||||
|
||||
# Get all sessions with default pagination
|
||||
sessions, total_count = repo.get_all()
|
||||
|
||||
# Verify total count
|
||||
assert total_count == 3
|
||||
|
||||
# Verify we got a list of SessionModel objects
|
||||
assert len(sessions) == 3
|
||||
for session in sessions:
|
||||
assert isinstance(session, SessionModel)
|
||||
assert isinstance(session.machine_info, dict)
|
||||
|
||||
# Verify the sessions are in descending order of creation time
|
||||
assert sessions[0].created_at >= sessions[1].created_at
|
||||
assert sessions[1].created_at >= sessions[2].created_at
|
||||
|
||||
# Verify the machine_info fields
|
||||
os_values = [session.machine_info["os"] for session in sessions]
|
||||
assert "Linux" in os_values
|
||||
assert "Windows" in os_values
|
||||
assert "macOS" in os_values
|
||||
|
||||
# Test pagination with limit
|
||||
sessions_limited, total_count = repo.get_all(limit=2)
|
||||
assert total_count == 3 # Total count should still be 3
|
||||
assert len(sessions_limited) == 2 # But only 2 returned
|
||||
|
||||
# Test pagination with offset
|
||||
sessions_offset, total_count = repo.get_all(offset=1, limit=2)
|
||||
assert total_count == 3
|
||||
assert len(sessions_offset) == 2
|
||||
|
||||
# The second item in the full list should be the first item in the offset list
|
||||
assert sessions[1].id == sessions_offset[0].id
|
||||
|
||||
|
||||
def test_get_all_empty(setup_db):
|
||||
"""Test retrieving all sessions when none exist."""
|
||||
# Set up repository
|
||||
repo = SessionRepository(db=setup_db)
|
||||
|
||||
# Get all sessions
|
||||
sessions, total_count = repo.get_all()
|
||||
|
||||
# Verify we got an empty list and zero count
|
||||
assert isinstance(sessions, list)
|
||||
assert len(sessions) == 0
|
||||
assert total_count == 0
|
||||
|
||||
|
||||
def test_get_recent(setup_db):
|
||||
"""Test retrieving recent sessions with a limit."""
|
||||
# Set up repository
|
||||
repo = SessionRepository(db=setup_db)
|
||||
|
||||
# Create multiple sessions
|
||||
for i in range(5):
|
||||
metadata = {"index": i, "os": f"OS {i}"}
|
||||
repo.create_session(metadata=metadata)
|
||||
|
||||
# Get recent sessions with limit=3
|
||||
sessions = repo.get_recent(limit=3)
|
||||
|
||||
# Verify we got the correct number of SessionModel objects
|
||||
assert len(sessions) == 3
|
||||
for session in sessions:
|
||||
assert isinstance(session, SessionModel)
|
||||
assert isinstance(session.machine_info, dict)
|
||||
|
||||
# Verify the sessions are in descending order and are the most recent ones
|
||||
indexes = [session.machine_info["index"] for session in sessions]
|
||||
assert indexes == [4, 3, 2] # Most recent first
|
||||
|
||||
|
||||
def test_session_repository_manager(setup_db, cleanup_repo):
|
||||
"""Test the SessionRepositoryManager context manager."""
|
||||
# Use the context manager to create a repository
|
||||
with SessionRepositoryManager(setup_db) as repo:
|
||||
# Verify the repository was created correctly
|
||||
assert isinstance(repo, SessionRepository)
|
||||
assert repo.db is setup_db
|
||||
|
||||
# Create a session and verify it's a SessionModel
|
||||
metadata = {"test": "manager"}
|
||||
session = repo.create_session(metadata=metadata)
|
||||
assert isinstance(session, SessionModel)
|
||||
assert session.machine_info["test"] == "manager"
|
||||
|
||||
# Verify we can get the repository using get_session_repository
|
||||
repo_from_var = get_session_repository()
|
||||
assert repo_from_var is repo
|
||||
|
||||
# Verify the repository was removed from the context var
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
get_session_repository()
|
||||
|
||||
assert "No SessionRepository available" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_repository_init_without_db():
|
||||
"""Test that SessionRepository raises an error when initialized without a db parameter."""
|
||||
# Attempt to create a repository without a database connection
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
SessionRepository(db=None)
|
||||
|
||||
# Verify the correct error message
|
||||
assert "Database connection is required" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_get_current_session_id(setup_db, sample_session):
|
||||
"""Test retrieving the ID of the current session."""
|
||||
# Set up repository
|
||||
repo = SessionRepository(db=setup_db)
|
||||
|
||||
# Set the current session
|
||||
repo.current_session = sample_session
|
||||
|
||||
# Get the current session ID
|
||||
session_id = repo.get_current_session_id()
|
||||
|
||||
# Verify the ID matches
|
||||
assert session_id == sample_session.id
|
||||
|
||||
# Test when no current session exists
|
||||
repo.current_session = None
|
||||
# Delete all sessions
|
||||
Session.delete().execute()
|
||||
|
||||
# Verify None is returned when no session exists
|
||||
session_id = repo.get_current_session_id()
|
||||
assert session_id is None
|
||||
|
|
@ -1,458 +0,0 @@
|
|||
"""
|
||||
Tests for the TrajectoryRepository class.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import datetime
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
import peewee
|
||||
|
||||
from ra_aid.database.connection import DatabaseManager, db_var
|
||||
from ra_aid.database.models import Trajectory, HumanInput, Session, BaseModel
|
||||
from ra_aid.database.repositories.trajectory_repository import (
|
||||
TrajectoryRepository,
|
||||
TrajectoryRepositoryManager,
|
||||
get_trajectory_repository,
|
||||
trajectory_repo_var
|
||||
)
|
||||
from ra_aid.database.pydantic_models import TrajectoryModel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cleanup_db():
|
||||
"""Reset the database contextvar and connection state after each test."""
|
||||
# Reset before the test
|
||||
db = db_var.get()
|
||||
if db is not None:
|
||||
try:
|
||||
if not db.is_closed():
|
||||
db.close()
|
||||
except Exception:
|
||||
# Ignore errors when closing the database
|
||||
pass
|
||||
db_var.set(None)
|
||||
|
||||
# Run the test
|
||||
yield
|
||||
|
||||
# Reset after the test
|
||||
db = db_var.get()
|
||||
if db is not None:
|
||||
try:
|
||||
if not db.is_closed():
|
||||
db.close()
|
||||
except Exception:
|
||||
# Ignore errors when closing the database
|
||||
pass
|
||||
db_var.set(None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cleanup_repo():
|
||||
"""Reset the repository contextvar after each test."""
|
||||
# Reset before the test
|
||||
trajectory_repo_var.set(None)
|
||||
|
||||
# Run the test
|
||||
yield
|
||||
|
||||
# Reset after the test
|
||||
trajectory_repo_var.set(None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_db(cleanup_db):
|
||||
"""Set up an in-memory database with the necessary tables and patch the BaseModel.Meta.database."""
|
||||
# Initialize an in-memory database connection
|
||||
with DatabaseManager(in_memory=True) as db:
|
||||
# Patch the BaseModel.Meta.database to use our in-memory database
|
||||
with patch.object(BaseModel._meta, 'database', db):
|
||||
# Create the required tables
|
||||
with db.atomic():
|
||||
db.create_tables([Trajectory, HumanInput, Session], safe=True)
|
||||
|
||||
yield db
|
||||
|
||||
# Clean up
|
||||
with db.atomic():
|
||||
Trajectory.drop_table(safe=True)
|
||||
HumanInput.drop_table(safe=True)
|
||||
Session.drop_table(safe=True)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_human_input(setup_db):
|
||||
"""Create a sample human input in the database."""
|
||||
return HumanInput.create(
|
||||
content="Test human input",
|
||||
source="test"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_tool_parameters():
|
||||
"""Return test tool parameters."""
|
||||
return {
|
||||
"pattern": "test pattern",
|
||||
"file_path": "/path/to/file",
|
||||
"options": {
|
||||
"case_sensitive": True,
|
||||
"whole_words": False
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_tool_result():
|
||||
"""Return test tool result."""
|
||||
return {
|
||||
"matches": [
|
||||
{"line": 10, "content": "This is a test pattern"},
|
||||
{"line": 20, "content": "Another test pattern here"}
|
||||
],
|
||||
"total_matches": 2,
|
||||
"execution_time": 0.5
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_step_data():
|
||||
"""Return test step data for UI rendering."""
|
||||
return {
|
||||
"display_type": "text",
|
||||
"content": "Tool execution results",
|
||||
"highlights": [
|
||||
{"start": 10, "end": 15, "color": "red"}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_trajectory(setup_db, sample_human_input, test_tool_parameters, test_tool_result, test_step_data):
|
||||
"""Create a sample trajectory in the database."""
|
||||
return Trajectory.create(
|
||||
human_input=sample_human_input,
|
||||
tool_name="ripgrep_search",
|
||||
tool_parameters=json.dumps(test_tool_parameters),
|
||||
tool_result=json.dumps(test_tool_result),
|
||||
step_data=json.dumps(test_step_data),
|
||||
record_type="tool_execution",
|
||||
cost=0.001,
|
||||
tokens=100,
|
||||
is_error=False
|
||||
)
|
||||
|
||||
|
||||
def test_create_trajectory(setup_db, sample_human_input, test_tool_parameters, test_tool_result, test_step_data):
|
||||
"""Test creating a trajectory with all fields."""
|
||||
# Set up repository
|
||||
repo = TrajectoryRepository(db=setup_db)
|
||||
|
||||
# Create a trajectory
|
||||
trajectory = repo.create(
|
||||
tool_name="ripgrep_search",
|
||||
tool_parameters=test_tool_parameters,
|
||||
tool_result=test_tool_result,
|
||||
step_data=test_step_data,
|
||||
record_type="tool_execution",
|
||||
human_input_id=sample_human_input.id,
|
||||
cost=0.001,
|
||||
tokens=100
|
||||
)
|
||||
|
||||
# Verify type is TrajectoryModel, not Trajectory
|
||||
assert isinstance(trajectory, TrajectoryModel)
|
||||
|
||||
# Verify the trajectory was created correctly
|
||||
assert trajectory.id is not None
|
||||
assert trajectory.tool_name == "ripgrep_search"
|
||||
|
||||
# Verify the JSON fields are dictionaries, not strings
|
||||
assert isinstance(trajectory.tool_parameters, dict)
|
||||
assert isinstance(trajectory.tool_result, dict)
|
||||
assert isinstance(trajectory.step_data, dict)
|
||||
|
||||
# Verify the nested structure of tool parameters
|
||||
assert trajectory.tool_parameters["options"]["case_sensitive"] == True
|
||||
assert trajectory.tool_result["total_matches"] == 2
|
||||
assert trajectory.step_data["highlights"][0]["color"] == "red"
|
||||
|
||||
# Verify foreign key reference
|
||||
assert trajectory.human_input_id == sample_human_input.id
|
||||
|
||||
|
||||
def test_create_trajectory_minimal(setup_db):
|
||||
"""Test creating a trajectory with minimal fields."""
|
||||
# Set up repository
|
||||
repo = TrajectoryRepository(db=setup_db)
|
||||
|
||||
# Create a trajectory with minimal fields
|
||||
trajectory = repo.create(
|
||||
tool_name="simple_tool"
|
||||
)
|
||||
|
||||
# Verify type is TrajectoryModel, not Trajectory
|
||||
assert isinstance(trajectory, TrajectoryModel)
|
||||
|
||||
# Verify the trajectory was created correctly
|
||||
assert trajectory.id is not None
|
||||
assert trajectory.tool_name == "simple_tool"
|
||||
|
||||
# Verify optional fields are None
|
||||
assert trajectory.tool_parameters is None
|
||||
assert trajectory.tool_result is None
|
||||
assert trajectory.step_data is None
|
||||
assert trajectory.human_input_id is None
|
||||
assert trajectory.cost is None
|
||||
assert trajectory.tokens is None
|
||||
assert trajectory.is_error is False
|
||||
|
||||
|
||||
def test_get_trajectory(setup_db, sample_trajectory, test_tool_parameters, test_tool_result, test_step_data):
|
||||
"""Test retrieving a trajectory by ID."""
|
||||
# Set up repository
|
||||
repo = TrajectoryRepository(db=setup_db)
|
||||
|
||||
# Get the trajectory by ID
|
||||
trajectory = repo.get(sample_trajectory.id)
|
||||
|
||||
# Verify type is TrajectoryModel, not Trajectory
|
||||
assert isinstance(trajectory, TrajectoryModel)
|
||||
|
||||
# Verify the retrieved trajectory matches the original
|
||||
assert trajectory.id == sample_trajectory.id
|
||||
assert trajectory.tool_name == sample_trajectory.tool_name
|
||||
|
||||
# Verify the JSON fields are dictionaries, not strings
|
||||
assert isinstance(trajectory.tool_parameters, dict)
|
||||
assert isinstance(trajectory.tool_result, dict)
|
||||
assert isinstance(trajectory.step_data, dict)
|
||||
|
||||
# Verify the content of JSON fields
|
||||
assert trajectory.tool_parameters == test_tool_parameters
|
||||
assert trajectory.tool_result == test_tool_result
|
||||
assert trajectory.step_data == test_step_data
|
||||
|
||||
# Verify non-existent trajectory returns None
|
||||
non_existent_trajectory = repo.get(999)
|
||||
assert non_existent_trajectory is None
|
||||
|
||||
|
||||
def test_update_trajectory(setup_db, sample_trajectory):
|
||||
"""Test updating a trajectory."""
|
||||
# Set up repository
|
||||
repo = TrajectoryRepository(db=setup_db)
|
||||
|
||||
# New data for update
|
||||
new_tool_result = {
|
||||
"matches": [
|
||||
{"line": 15, "content": "Updated test pattern"}
|
||||
],
|
||||
"total_matches": 1,
|
||||
"execution_time": 0.3
|
||||
}
|
||||
|
||||
new_step_data = {
|
||||
"display_type": "html",
|
||||
"content": "Updated UI rendering",
|
||||
"highlights": []
|
||||
}
|
||||
|
||||
# Update the trajectory
|
||||
updated_trajectory = repo.update(
|
||||
trajectory_id=sample_trajectory.id,
|
||||
tool_result=new_tool_result,
|
||||
step_data=new_step_data,
|
||||
cost=0.002,
|
||||
tokens=200,
|
||||
is_error=True,
|
||||
error_message="Test error",
|
||||
error_type="TestErrorType",
|
||||
error_details="Detailed error information"
|
||||
)
|
||||
|
||||
# Verify type is TrajectoryModel, not Trajectory
|
||||
assert isinstance(updated_trajectory, TrajectoryModel)
|
||||
|
||||
# Verify the fields were updated
|
||||
assert updated_trajectory.tool_result == new_tool_result
|
||||
assert updated_trajectory.step_data == new_step_data
|
||||
assert updated_trajectory.cost == 0.002
|
||||
assert updated_trajectory.tokens == 200
|
||||
assert updated_trajectory.is_error is True
|
||||
assert updated_trajectory.error_message == "Test error"
|
||||
assert updated_trajectory.error_type == "TestErrorType"
|
||||
assert updated_trajectory.error_details == "Detailed error information"
|
||||
|
||||
# Original tool parameters should not change
|
||||
# We need to parse the JSON string from the Peewee object for comparison
|
||||
original_params = json.loads(sample_trajectory.tool_parameters)
|
||||
assert updated_trajectory.tool_parameters == original_params
|
||||
|
||||
# Verify updating a non-existent trajectory returns None
|
||||
non_existent_update = repo.update(trajectory_id=999, cost=0.005)
|
||||
assert non_existent_update is None
|
||||
|
||||
|
||||
def test_delete_trajectory(setup_db, sample_trajectory):
|
||||
"""Test deleting a trajectory."""
|
||||
# Set up repository
|
||||
repo = TrajectoryRepository(db=setup_db)
|
||||
|
||||
# Verify the trajectory exists
|
||||
assert repo.get(sample_trajectory.id) is not None
|
||||
|
||||
# Delete the trajectory
|
||||
result = repo.delete(sample_trajectory.id)
|
||||
|
||||
# Verify the trajectory was deleted
|
||||
assert result is True
|
||||
assert repo.get(sample_trajectory.id) is None
|
||||
|
||||
# Verify deleting a non-existent trajectory returns False
|
||||
result = repo.delete(999)
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_get_all_trajectories(setup_db, sample_human_input):
|
||||
"""Test retrieving all trajectories."""
|
||||
# Set up repository
|
||||
repo = TrajectoryRepository(db=setup_db)
|
||||
|
||||
# Create multiple trajectories
|
||||
for i in range(3):
|
||||
repo.create(
|
||||
tool_name=f"tool_{i}",
|
||||
tool_parameters={"index": i},
|
||||
human_input_id=sample_human_input.id
|
||||
)
|
||||
|
||||
# Get all trajectories
|
||||
trajectories = repo.get_all()
|
||||
|
||||
# Verify we got a dictionary of TrajectoryModel objects
|
||||
assert len(trajectories) == 3
|
||||
for trajectory_id, trajectory in trajectories.items():
|
||||
assert isinstance(trajectory, TrajectoryModel)
|
||||
assert isinstance(trajectory.tool_parameters, dict)
|
||||
|
||||
# Verify the trajectories have the correct tool names
|
||||
tool_names = {trajectory.tool_name for trajectory in trajectories.values()}
|
||||
assert "tool_0" in tool_names
|
||||
assert "tool_1" in tool_names
|
||||
assert "tool_2" in tool_names
|
||||
|
||||
|
||||
def test_get_trajectories_by_human_input(setup_db, sample_human_input):
|
||||
"""Test retrieving trajectories by human input ID."""
|
||||
# Set up repository
|
||||
repo = TrajectoryRepository(db=setup_db)
|
||||
|
||||
# Create another human input
|
||||
other_human_input = HumanInput.create(
|
||||
content="Another human input",
|
||||
source="test"
|
||||
)
|
||||
|
||||
# Create trajectories for both human inputs
|
||||
for i in range(2):
|
||||
repo.create(
|
||||
tool_name=f"tool_1_{i}",
|
||||
human_input_id=sample_human_input.id
|
||||
)
|
||||
|
||||
for i in range(3):
|
||||
repo.create(
|
||||
tool_name=f"tool_2_{i}",
|
||||
human_input_id=other_human_input.id
|
||||
)
|
||||
|
||||
# Get trajectories for the first human input
|
||||
trajectories = repo.get_trajectories_by_human_input(sample_human_input.id)
|
||||
|
||||
# Verify we got a list of TrajectoryModel objects for the first human input
|
||||
assert len(trajectories) == 2
|
||||
for trajectory in trajectories:
|
||||
assert isinstance(trajectory, TrajectoryModel)
|
||||
assert trajectory.human_input_id == sample_human_input.id
|
||||
assert trajectory.tool_name.startswith("tool_1")
|
||||
|
||||
# Get trajectories for the second human input
|
||||
trajectories = repo.get_trajectories_by_human_input(other_human_input.id)
|
||||
|
||||
# Verify we got a list of TrajectoryModel objects for the second human input
|
||||
assert len(trajectories) == 3
|
||||
for trajectory in trajectories:
|
||||
assert isinstance(trajectory, TrajectoryModel)
|
||||
assert trajectory.human_input_id == other_human_input.id
|
||||
assert trajectory.tool_name.startswith("tool_2")
|
||||
|
||||
|
||||
def test_get_parsed_trajectory(setup_db, sample_trajectory, test_tool_parameters, test_tool_result, test_step_data):
|
||||
"""Test retrieving a parsed trajectory."""
|
||||
# Set up repository
|
||||
repo = TrajectoryRepository(db=setup_db)
|
||||
|
||||
# Get the parsed trajectory
|
||||
trajectory = repo.get_parsed_trajectory(sample_trajectory.id)
|
||||
|
||||
# Verify type is TrajectoryModel, not Trajectory
|
||||
assert isinstance(trajectory, TrajectoryModel)
|
||||
|
||||
# Verify the retrieved trajectory matches the original
|
||||
assert trajectory.id == sample_trajectory.id
|
||||
assert trajectory.tool_name == sample_trajectory.tool_name
|
||||
|
||||
# Verify the JSON fields are dictionaries, not strings
|
||||
assert isinstance(trajectory.tool_parameters, dict)
|
||||
assert isinstance(trajectory.tool_result, dict)
|
||||
assert isinstance(trajectory.step_data, dict)
|
||||
|
||||
# Verify the content of JSON fields
|
||||
assert trajectory.tool_parameters == test_tool_parameters
|
||||
assert trajectory.tool_result == test_tool_result
|
||||
assert trajectory.step_data == test_step_data
|
||||
|
||||
# Verify non-existent trajectory returns None
|
||||
non_existent_trajectory = repo.get_parsed_trajectory(999)
|
||||
assert non_existent_trajectory is None
|
||||
|
||||
|
||||
def test_trajectory_repository_manager(setup_db, cleanup_repo):
|
||||
"""Test the TrajectoryRepositoryManager context manager."""
|
||||
# Use the context manager to create a repository
|
||||
with TrajectoryRepositoryManager(setup_db) as repo:
|
||||
# Verify the repository was created correctly
|
||||
assert isinstance(repo, TrajectoryRepository)
|
||||
assert repo.db is setup_db
|
||||
|
||||
# Create a trajectory and verify it's a TrajectoryModel
|
||||
tool_parameters = {"test": "manager"}
|
||||
trajectory = repo.create(
|
||||
tool_name="manager_test",
|
||||
tool_parameters=tool_parameters
|
||||
)
|
||||
assert isinstance(trajectory, TrajectoryModel)
|
||||
assert trajectory.tool_parameters["test"] == "manager"
|
||||
|
||||
# Verify we can get the repository using get_trajectory_repository
|
||||
repo_from_var = get_trajectory_repository()
|
||||
assert repo_from_var is repo
|
||||
|
||||
# Verify the repository was removed from the context var
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
get_trajectory_repository()
|
||||
|
||||
assert "No TrajectoryRepository available" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_repository_init_without_db():
|
||||
"""Test that TrajectoryRepository raises an error when initialized without a db parameter."""
|
||||
# Attempt to create a repository without a database connection
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
TrajectoryRepository(db=None)
|
||||
|
||||
# Verify the correct error message
|
||||
assert "Database connection is required" in str(excinfo.value)
|
||||
|
|
@ -1,132 +0,0 @@
|
|||
"""
|
||||
Tests for the Sessions API v1 endpoints.
|
||||
|
||||
This module contains tests for the sessions API endpoints in ra_aid/server/api_v1_sessions.py.
|
||||
It tests the creation, listing, and retrieval of sessions through the API.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import MagicMock
|
||||
import datetime
|
||||
|
||||
from ra_aid.server.api_v1_sessions import router, get_repository
|
||||
from ra_aid.database.pydantic_models import SessionModel
|
||||
|
||||
|
||||
# Mock session data for testing
|
||||
@pytest.fixture
|
||||
def mock_session():
|
||||
"""Return a mock session for testing."""
|
||||
return SessionModel(
|
||||
id=1,
|
||||
created_at=datetime.datetime(2025, 1, 1, 0, 0, 0),
|
||||
updated_at=datetime.datetime(2025, 1, 1, 0, 0, 0),
|
||||
start_time=datetime.datetime(2025, 1, 1, 0, 0, 0),
|
||||
command_line="ra-aid test",
|
||||
program_version="1.0.0",
|
||||
machine_info={"os": "test"}
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_sessions():
|
||||
"""Return a list of mock sessions for testing."""
|
||||
return [
|
||||
SessionModel(
|
||||
id=1,
|
||||
created_at=datetime.datetime(2025, 1, 1, 0, 0, 0),
|
||||
updated_at=datetime.datetime(2025, 1, 1, 0, 0, 0),
|
||||
start_time=datetime.datetime(2025, 1, 1, 0, 0, 0),
|
||||
command_line="ra-aid test1",
|
||||
program_version="1.0.0",
|
||||
machine_info={"os": "test"}
|
||||
),
|
||||
SessionModel(
|
||||
id=2,
|
||||
created_at=datetime.datetime(2025, 1, 2, 0, 0, 0),
|
||||
updated_at=datetime.datetime(2025, 1, 2, 0, 0, 0),
|
||||
start_time=datetime.datetime(2025, 1, 2, 0, 0, 0),
|
||||
command_line="ra-aid test2",
|
||||
program_version="1.0.0",
|
||||
machine_info={"os": "test"}
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_repo(mock_session, mock_sessions):
|
||||
"""Mock the SessionRepository for testing."""
|
||||
repo = MagicMock()
|
||||
repo.get.return_value = mock_session
|
||||
repo.get_all.return_value = (mock_sessions, len(mock_sessions))
|
||||
repo.create_session.return_value = mock_session
|
||||
return repo
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(mock_repo):
|
||||
"""Return a TestClient for the API router with dependency override."""
|
||||
from fastapi import FastAPI
|
||||
app = FastAPI()
|
||||
app.include_router(router)
|
||||
|
||||
# Override the dependency
|
||||
app.dependency_overrides[get_repository] = lambda: mock_repo
|
||||
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def test_get_session(client, mock_repo, mock_session):
|
||||
"""Test getting a specific session by ID."""
|
||||
response = client.get("/v1/sessions/1")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["id"] == mock_session.id
|
||||
assert response.json()["command_line"] == mock_session.command_line
|
||||
mock_repo.get.assert_called_once_with(1)
|
||||
|
||||
|
||||
def test_get_session_not_found(client, mock_repo):
|
||||
"""Test getting a session that doesn't exist."""
|
||||
mock_repo.get.return_value = None
|
||||
|
||||
response = client.get("/v1/sessions/999")
|
||||
|
||||
assert response.status_code == 404
|
||||
assert "not found" in response.json()["detail"]
|
||||
mock_repo.get.assert_called_once_with(999)
|
||||
|
||||
|
||||
def test_list_sessions(client, mock_repo, mock_sessions):
|
||||
"""Test listing sessions with pagination."""
|
||||
response = client.get("/v1/sessions?offset=0&limit=10")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] == len(mock_sessions)
|
||||
assert len(data["items"]) == len(mock_sessions)
|
||||
assert data["limit"] == 10
|
||||
assert data["offset"] == 0
|
||||
mock_repo.get_all.assert_called_once_with(offset=0, limit=10)
|
||||
|
||||
|
||||
def test_create_session(client, mock_repo, mock_session):
|
||||
"""Test creating a new session."""
|
||||
response = client.post(
|
||||
"/v1/sessions",
|
||||
json={"metadata": {"test": "data"}}
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
assert response.json()["id"] == mock_session.id
|
||||
mock_repo.create_session.assert_called_once_with(metadata={"test": "data"})
|
||||
|
||||
|
||||
def test_create_session_no_body(client, mock_repo, mock_session):
|
||||
"""Test creating a new session without a request body."""
|
||||
response = client.post("/v1/sessions")
|
||||
|
||||
assert response.status_code == 201
|
||||
assert response.json()["id"] == mock_session.id
|
||||
mock_repo.create_session.assert_called_once_with(metadata=None)
|
||||
|
|
@ -1,153 +0,0 @@
|
|||
"""
|
||||
Tests for the Spawn Agent API v1 endpoint.
|
||||
|
||||
This module contains tests for the spawn-agent API endpoint in ra_aid/server/api_v1_spawn_agent.py.
|
||||
It tests the creation of agent threads and session handling for the spawn-agent endpoint.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import threading
|
||||
from unittest.mock import MagicMock
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from ra_aid.server.api_v1_spawn_agent import router, get_repository
|
||||
from ra_aid.database.pydantic_models import SessionModel
|
||||
import datetime
|
||||
import ra_aid.server.api_v1_spawn_agent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session():
|
||||
"""Return a mock session for testing."""
|
||||
return SessionModel(
|
||||
id=123,
|
||||
created_at=datetime.datetime(2025, 1, 1, 0, 0, 0),
|
||||
updated_at=datetime.datetime(2025, 1, 1, 0, 0, 0),
|
||||
start_time=datetime.datetime(2025, 1, 1, 0, 0, 0),
|
||||
command_line="ra-aid test",
|
||||
program_version="1.0.0",
|
||||
machine_info={"agent_type": "research", "expert_enabled": True, "web_research_enabled": False}
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_thread():
|
||||
"""Create a mock thread that does nothing when started."""
|
||||
mock = MagicMock()
|
||||
mock.daemon = True
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_repository(mock_session):
|
||||
"""Create a mock repository for testing."""
|
||||
mock_repo = MagicMock()
|
||||
mock_repo.create_session.return_value = mock_session
|
||||
return mock_repo
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config_repository():
|
||||
"""Create a mock config repository for testing."""
|
||||
mock_config = MagicMock()
|
||||
mock_config.get.side_effect = lambda key, default=None: {
|
||||
"expert_enabled": True,
|
||||
"web_research_enabled": False,
|
||||
"provider": "anthropic",
|
||||
"model": "claude-3-7-sonnet-20250219",
|
||||
}.get(key, default)
|
||||
return mock_config
|
||||
|
||||
@pytest.fixture
|
||||
def client(mock_repository, mock_thread, mock_config_repository, monkeypatch):
|
||||
"""Set up a test client with mocked dependencies."""
|
||||
# Create FastAPI app with router
|
||||
app = FastAPI()
|
||||
app.include_router(router)
|
||||
|
||||
# Override the dependency to use our mock repository
|
||||
app.dependency_overrides[get_repository] = lambda: mock_repository
|
||||
|
||||
# Mock run_agent_thread to be a no-op
|
||||
monkeypatch.setattr(
|
||||
"ra_aid.server.api_v1_spawn_agent.run_agent_thread",
|
||||
lambda *args, **kwargs: None
|
||||
)
|
||||
|
||||
# Mock get_config_repository to use our mock
|
||||
monkeypatch.setattr(
|
||||
"ra_aid.server.api_v1_spawn_agent.get_config_repository",
|
||||
lambda: mock_config_repository
|
||||
)
|
||||
|
||||
# Mock threading.Thread to return our mock thread
|
||||
def mock_thread_constructor(*args, **kwargs):
|
||||
mock_thread.target = kwargs.get('target')
|
||||
mock_thread.args = kwargs.get('args')
|
||||
mock_thread.daemon = kwargs.get('daemon', False)
|
||||
return mock_thread
|
||||
|
||||
monkeypatch.setattr(
|
||||
ra_aid.server.api_v1_spawn_agent,
|
||||
"threading",
|
||||
MagicMock(Thread=mock_thread_constructor)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
# Add mocks to client for test access
|
||||
client.mock_repo = mock_repository
|
||||
client.mock_thread = mock_thread
|
||||
client.mock_config = mock_config_repository
|
||||
|
||||
yield client
|
||||
|
||||
# Clean up the dependency override
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def test_spawn_agent(client, mock_repository, mock_thread):
|
||||
"""Test spawning an agent with valid parameters."""
|
||||
# Create the request payload
|
||||
payload = {
|
||||
"message": "Test task for the agent",
|
||||
"research_only": False,
|
||||
"expert_enabled": True,
|
||||
"web_research_enabled": False
|
||||
}
|
||||
|
||||
# Send the request
|
||||
response = client.post("/v1/spawn-agent", json=payload)
|
||||
|
||||
# Verify response
|
||||
assert response.status_code == 201
|
||||
assert response.json() == {"session_id": "123"}
|
||||
|
||||
# Verify session creation
|
||||
mock_repository.create_session.assert_called_once()
|
||||
|
||||
# Verify thread was created with correct args
|
||||
assert mock_thread.args == ("Test task for the agent", "123", False)
|
||||
assert mock_thread.daemon is True
|
||||
|
||||
# Verify thread.start was called
|
||||
mock_thread.start.assert_called_once()
|
||||
|
||||
|
||||
def test_spawn_agent_missing_message(client):
|
||||
"""Test spawning an agent with missing required message parameter."""
|
||||
# Create a request payload missing the required message
|
||||
payload = {
|
||||
"research_only": False,
|
||||
"expert_enabled": True,
|
||||
"web_research_enabled": False
|
||||
}
|
||||
|
||||
# Send the request
|
||||
response = client.post("/v1/spawn-agent", json=payload)
|
||||
|
||||
# Verify response indicates validation error
|
||||
assert response.status_code == 422
|
||||
error_detail = response.json().get("detail", [])
|
||||
assert any("message" in error.get("loc", []) for error in error_detail)
|
||||
|
|
@ -1,75 +0,0 @@
|
|||
"""
|
||||
Tests for server.py FastAPI application.
|
||||
|
||||
This module tests the FastAPI application setup in server.py to ensure
|
||||
that all routers are properly mounted and middleware is configured.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from ra_aid.server.server import app
|
||||
from ra_aid.database.repositories.session_repository import session_repo_var
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
"""Return a TestClient for the FastAPI app."""
|
||||
# Mock the session repository to avoid database dependency
|
||||
mock_repo = MagicMock()
|
||||
mock_repo.get_all.return_value = ([], 0)
|
||||
|
||||
# Set the repository in the contextvar
|
||||
token = session_repo_var.set(mock_repo)
|
||||
|
||||
yield TestClient(app)
|
||||
|
||||
# Reset the contextvar after the test
|
||||
session_repo_var.reset(token)
|
||||
|
||||
|
||||
def test_config_endpoint(client):
|
||||
"""Test that the config endpoint returns server configuration."""
|
||||
response = client.get("/config")
|
||||
assert response.status_code == 200
|
||||
assert "host" in response.json()
|
||||
assert "port" in response.json()
|
||||
|
||||
|
||||
def test_api_documentation(client):
|
||||
"""Test that the OpenAPI documentation includes the sessions API."""
|
||||
response = client.get("/openapi.json")
|
||||
assert response.status_code == 200
|
||||
|
||||
openapi_spec = response.json()
|
||||
assert "paths" in openapi_spec
|
||||
|
||||
# Check that the sessions API paths are included
|
||||
assert "/v1/sessions" in openapi_spec["paths"]
|
||||
assert "/v1/sessions/{session_id}" in openapi_spec["paths"]
|
||||
|
||||
# Verify that sessions API operations are documented
|
||||
assert "get" in openapi_spec["paths"]["/v1/sessions"]
|
||||
assert "post" in openapi_spec["paths"]["/v1/sessions"]
|
||||
assert "get" in openapi_spec["paths"]["/v1/sessions/{session_id}"]
|
||||
|
||||
|
||||
@patch("ra_aid.database.repositories.session_repository.get_session_repository")
|
||||
def test_sessions_api_mounted(mock_get_repo, client):
|
||||
"""Test that the sessions API router is mounted correctly."""
|
||||
# Mock the repository for this specific test
|
||||
mock_repo = MagicMock()
|
||||
mock_repo.get_all.return_value = ([], 0)
|
||||
mock_get_repo.return_value = mock_repo
|
||||
|
||||
# Test that the sessions list endpoint is accessible
|
||||
response = client.get("/v1/sessions")
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify the response structure follows our expected format
|
||||
data = response.json()
|
||||
assert "total" in data
|
||||
assert "items" in data
|
||||
assert "limit" in data
|
||||
assert "offset" in data
|
||||
|
|
@ -1,532 +0,0 @@
|
|||
"""
|
||||
Integration tests for the Sessions API endpoints.
|
||||
|
||||
This module contains integration tests for the API endpoints defined in ra_aid/server/api_v1_sessions.py.
|
||||
It uses mocks to simulate the database interactions while testing the real API behavior.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import datetime
|
||||
from typing import Dict, Any, List, Tuple
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from ra_aid.server.server import app
|
||||
from ra_aid.database.pydantic_models import SessionModel
|
||||
from ra_aid.server.api_v1_sessions import get_repository
|
||||
|
||||
|
||||
# Mock session data for testing
|
||||
@pytest.fixture
|
||||
def mock_session():
|
||||
"""Return a mock session for testing."""
|
||||
return SessionModel(
|
||||
id=1,
|
||||
created_at=datetime.datetime(2025, 1, 1, 0, 0, 0),
|
||||
updated_at=datetime.datetime(2025, 1, 1, 0, 0, 0),
|
||||
start_time=datetime.datetime(2025, 1, 1, 0, 0, 0),
|
||||
command_line="ra-aid test",
|
||||
program_version="1.0.0",
|
||||
machine_info={"os": "test"}
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_sessions():
|
||||
"""Return a list of mock sessions for testing."""
|
||||
return [
|
||||
SessionModel(
|
||||
id=i+1,
|
||||
created_at=datetime.datetime(2025, 1, i+1, 0, 0, 0),
|
||||
updated_at=datetime.datetime(2025, 1, i+1, 0, 0, 0),
|
||||
start_time=datetime.datetime(2025, 1, i+1, 0, 0, 0),
|
||||
command_line=f"ra-aid test{i+1}",
|
||||
program_version="1.0.0",
|
||||
machine_info={"index": i}
|
||||
)
|
||||
for i in range(15)
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_repo(mock_session, mock_sessions):
|
||||
"""Create a mock repository with predefined responses."""
|
||||
repo = MagicMock()
|
||||
repo.get.return_value = mock_session
|
||||
repo.get_all.return_value = (mock_sessions[:10], len(mock_sessions))
|
||||
repo.create_session.return_value = mock_session
|
||||
|
||||
# Add behavior for custom parameters
|
||||
def get_with_id(session_id):
|
||||
if session_id == 999999:
|
||||
return None
|
||||
for session in mock_sessions:
|
||||
if session.id == session_id:
|
||||
return session
|
||||
return mock_session
|
||||
|
||||
def get_all_with_pagination(offset=0, limit=10):
|
||||
total = len(mock_sessions)
|
||||
sorted_sessions = sorted(mock_sessions, key=lambda s: s.id, reverse=True)
|
||||
return sorted_sessions[offset:offset+limit], total
|
||||
|
||||
def create_with_metadata(metadata=None):
|
||||
if metadata is None:
|
||||
return SessionModel(
|
||||
id=16,
|
||||
created_at=datetime.datetime.now(),
|
||||
updated_at=datetime.datetime.now(),
|
||||
start_time=datetime.datetime.now(),
|
||||
command_line="ra-aid test-null",
|
||||
program_version="1.0.0",
|
||||
machine_info=None
|
||||
)
|
||||
return SessionModel(
|
||||
id=16,
|
||||
created_at=datetime.datetime.now(),
|
||||
updated_at=datetime.datetime.now(),
|
||||
start_time=datetime.datetime.now(),
|
||||
command_line="ra-aid test-custom",
|
||||
program_version="1.0.0",
|
||||
machine_info=metadata
|
||||
)
|
||||
|
||||
repo.get.side_effect = get_with_id
|
||||
repo.get_all.side_effect = get_all_with_pagination
|
||||
repo.create_session.side_effect = create_with_metadata
|
||||
|
||||
return repo
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(mock_repo):
|
||||
"""Create a TestClient with the API and dependency overrides."""
|
||||
# Override the dependency to use our mock repository
|
||||
app.dependency_overrides[get_repository] = lambda: mock_repo
|
||||
|
||||
# Create a test client
|
||||
with TestClient(app) as test_client:
|
||||
yield test_client
|
||||
|
||||
# Clean up the dependency override
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_metadata() -> Dict[str, Any]:
|
||||
"""Return sample metadata for session creation."""
|
||||
return {
|
||||
"os": "Test OS",
|
||||
"version": "1.0.0",
|
||||
"environment": "test",
|
||||
"cpu_cores": 4,
|
||||
"memory_gb": 16,
|
||||
"additional_info": {
|
||||
"gpu": "Test GPU",
|
||||
"display_resolution": "1920x1080"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def test_create_session_with_metadata(client, mock_repo, sample_metadata):
|
||||
"""Test creating a session with metadata through the API endpoint."""
|
||||
# Send request to create a session with metadata
|
||||
response = client.post(
|
||||
"/v1/sessions",
|
||||
json={"metadata": sample_metadata}
|
||||
)
|
||||
|
||||
# Verify response status code and structure
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
|
||||
# Verify the session was created with the expected fields
|
||||
assert data["id"] is not None
|
||||
assert data["command_line"] is not None
|
||||
assert data["program_version"] is not None
|
||||
assert data["created_at"] is not None
|
||||
assert data["updated_at"] is not None
|
||||
assert data["start_time"] is not None
|
||||
|
||||
# Verify metadata was passed correctly to the repository
|
||||
mock_repo.create_session.assert_called_once_with(metadata=sample_metadata)
|
||||
|
||||
|
||||
def test_create_session_without_metadata(client, mock_repo):
|
||||
"""Test creating a session without metadata through the API endpoint."""
|
||||
# Send request without a body
|
||||
response = client.post("/v1/sessions")
|
||||
|
||||
# Verify response status code and structure
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
|
||||
# Verify the session was created with the expected fields
|
||||
assert data["id"] is not None
|
||||
assert data["command_line"] is not None
|
||||
assert data["program_version"] is not None
|
||||
|
||||
# Verify correct parameters were passed to the repository
|
||||
mock_repo.create_session.assert_called_once_with(metadata=None)
|
||||
|
||||
|
||||
def test_get_session_by_id(client):
|
||||
"""Test retrieving a session by ID through the API endpoint."""
|
||||
# Use a completely isolated, standalone test
|
||||
|
||||
# For this test, let's focus on verifying the core functionality:
|
||||
# 1. The API endpoint receives a request for a specific session ID
|
||||
# 2. It calls the repository with that ID
|
||||
# 3. It returns a properly formatted response
|
||||
|
||||
mock_repo = MagicMock()
|
||||
|
||||
# Create a test session with a simple machine_info to reduce serialization issues
|
||||
test_session = SessionModel(
|
||||
id=42,
|
||||
created_at=datetime.datetime(2025, 1, 1, 0, 0, 0),
|
||||
updated_at=datetime.datetime(2025, 1, 1, 0, 0, 0),
|
||||
start_time=datetime.datetime(2025, 1, 1, 0, 0, 0),
|
||||
command_line="ra-aid specific-test",
|
||||
program_version="1.0.0-test",
|
||||
machine_info=None # Use None to avoid serialization issues
|
||||
)
|
||||
|
||||
# Configure the mock
|
||||
mock_repo.get.return_value = test_session
|
||||
|
||||
# Override the dependency
|
||||
app.dependency_overrides[get_repository] = lambda: mock_repo
|
||||
|
||||
try:
|
||||
# Retrieve the session through the API
|
||||
response = client.get(f"/v1/sessions/{test_session.id}")
|
||||
|
||||
# Verify response status code
|
||||
assert response.status_code == 200
|
||||
|
||||
# Parse the response data
|
||||
data = response.json()
|
||||
|
||||
# Print for debugging
|
||||
import json
|
||||
print("Response JSON:", json.dumps(data, indent=2))
|
||||
|
||||
# Verify the returned session matches what we expected
|
||||
assert data["id"] == test_session.id
|
||||
assert data["command_line"] == test_session.command_line
|
||||
assert data["program_version"] == test_session.program_version
|
||||
assert data["machine_info"] is None
|
||||
|
||||
# Verify the repository was called with the correct ID
|
||||
mock_repo.get.assert_called_once_with(test_session.id)
|
||||
finally:
|
||||
# Clean up the override
|
||||
if get_repository in app.dependency_overrides:
|
||||
del app.dependency_overrides[get_repository]
|
||||
|
||||
|
||||
def test_get_session_not_found(client, mock_repo):
|
||||
"""Test the error handling when requesting a non-existent session."""
|
||||
# Try to get a session with a non-existent ID
|
||||
response = client.get("/v1/sessions/999999")
|
||||
|
||||
# Verify response status code and error message
|
||||
assert response.status_code == 404
|
||||
assert "not found" in response.json()["detail"]
|
||||
|
||||
# Verify the repository was called with the correct ID
|
||||
mock_repo.get.assert_called_with(999999)
|
||||
|
||||
|
||||
def test_list_sessions_empty(client, mock_repo):
|
||||
"""Test listing sessions when no sessions exist."""
|
||||
# Reset the mock first to clear any previous calls/side effects
|
||||
mock_repo.reset_mock()
|
||||
|
||||
# Configure the mock to return empty results
|
||||
mock_repo.get_all.side_effect = None # Clear any previous side effects
|
||||
mock_repo.get_all.return_value = ([], 0)
|
||||
|
||||
# Get the list of sessions
|
||||
response = client.get("/v1/sessions")
|
||||
|
||||
# Verify response status code and structure
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
# Verify the pagination response
|
||||
assert data["total"] == 0
|
||||
assert len(data["items"]) == 0
|
||||
assert data["limit"] == 10
|
||||
assert data["offset"] == 0
|
||||
|
||||
# Verify the repository was called with the correct parameters
|
||||
mock_repo.get_all.assert_called_with(offset=0, limit=10)
|
||||
|
||||
|
||||
def test_list_sessions_with_pagination(client, mock_repo, mock_sessions):
|
||||
"""Test listing sessions with pagination parameters."""
|
||||
# Set up the repository mock to return specific results for different pagination parameters
|
||||
default_result = (mock_sessions[:10], len(mock_sessions))
|
||||
limit_5_result = (mock_sessions[:5], len(mock_sessions))
|
||||
offset_10_result = (mock_sessions[10:], len(mock_sessions))
|
||||
offset_5_limit_3_result = (mock_sessions[5:8], len(mock_sessions))
|
||||
|
||||
pagination_responses = {
|
||||
(0, 10): default_result,
|
||||
(0, 5): limit_5_result,
|
||||
(10, 10): offset_10_result,
|
||||
(5, 3): offset_5_limit_3_result
|
||||
}
|
||||
|
||||
def mock_get_all(offset=0, limit=10):
|
||||
return pagination_responses.get((offset, limit), ([], 0))
|
||||
|
||||
mock_repo.get_all.side_effect = mock_get_all
|
||||
|
||||
# Test default pagination (limit=10, offset=0)
|
||||
response = client.get("/v1/sessions")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] == len(mock_sessions)
|
||||
assert len(data["items"]) == 10
|
||||
assert data["limit"] == 10
|
||||
assert data["offset"] == 0
|
||||
|
||||
# Test with custom limit
|
||||
response = client.get("/v1/sessions?limit=5")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] == len(mock_sessions)
|
||||
assert len(data["items"]) == 5
|
||||
assert data["limit"] == 5
|
||||
assert data["offset"] == 0
|
||||
|
||||
# Test with custom offset
|
||||
response = client.get("/v1/sessions?offset=10")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] == len(mock_sessions)
|
||||
assert len(data["items"]) == 5 # Only 5 items left after offset 10
|
||||
assert data["limit"] == 10
|
||||
assert data["offset"] == 10
|
||||
|
||||
# Test with both custom limit and offset
|
||||
response = client.get("/v1/sessions?limit=3&offset=5")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] == len(mock_sessions)
|
||||
assert len(data["items"]) == 3
|
||||
assert data["limit"] == 3
|
||||
assert data["offset"] == 5
|
||||
|
||||
|
||||
def test_list_sessions_invalid_parameters(client):
|
||||
"""Test error handling for invalid pagination parameters."""
|
||||
# Test with negative offset
|
||||
response = client.get("/v1/sessions?offset=-1")
|
||||
assert response.status_code == 422
|
||||
|
||||
# Test with negative limit
|
||||
response = client.get("/v1/sessions?limit=-5")
|
||||
assert response.status_code == 422
|
||||
|
||||
# Test with zero limit
|
||||
response = client.get("/v1/sessions?limit=0")
|
||||
assert response.status_code == 422
|
||||
|
||||
# Test with limit exceeding maximum
|
||||
response = client.get("/v1/sessions?limit=101")
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_metadata_validation(client, mock_repo):
|
||||
"""Test validation for different metadata formats in session creation."""
|
||||
# Create test sessions with different metadata
|
||||
null_metadata_session = SessionModel(
|
||||
id=20,
|
||||
created_at=datetime.datetime.now(),
|
||||
updated_at=datetime.datetime.now(),
|
||||
start_time=datetime.datetime.now(),
|
||||
command_line="ra-aid test-null",
|
||||
program_version="1.0.0",
|
||||
machine_info=None
|
||||
)
|
||||
|
||||
empty_dict_metadata_session = SessionModel(
|
||||
id=21,
|
||||
created_at=datetime.datetime.now(),
|
||||
updated_at=datetime.datetime.now(),
|
||||
start_time=datetime.datetime.now(),
|
||||
command_line="ra-aid test-empty",
|
||||
program_version="1.0.0",
|
||||
machine_info={}
|
||||
)
|
||||
|
||||
complex_metadata_session = SessionModel(
|
||||
id=22,
|
||||
created_at=datetime.datetime.now(),
|
||||
updated_at=datetime.datetime.now(),
|
||||
start_time=datetime.datetime.now(),
|
||||
command_line="ra-aid test-complex",
|
||||
program_version="1.0.0",
|
||||
machine_info={"level1": {"level2": {"level3": [1, 2, 3, {"key": "value"}]}}}
|
||||
)
|
||||
|
||||
# Configure mock to return different sessions based on metadata
|
||||
def create_with_specific_metadata(metadata=None):
|
||||
if metadata is None:
|
||||
return null_metadata_session
|
||||
elif metadata == {}:
|
||||
return empty_dict_metadata_session
|
||||
elif isinstance(metadata, dict) and "level1" in metadata:
|
||||
return complex_metadata_session
|
||||
return SessionModel(
|
||||
id=23,
|
||||
created_at=datetime.datetime.now(),
|
||||
updated_at=datetime.datetime.now(),
|
||||
start_time=datetime.datetime.now(),
|
||||
command_line="ra-aid test-other",
|
||||
program_version="1.0.0",
|
||||
machine_info=metadata
|
||||
)
|
||||
|
||||
mock_repo.create_session.side_effect = create_with_specific_metadata
|
||||
|
||||
# Try to create a session with null metadata
|
||||
response = client.post(
|
||||
"/v1/sessions",
|
||||
json={"metadata": None}
|
||||
)
|
||||
|
||||
# This should work fine
|
||||
assert response.status_code == 201
|
||||
mock_repo.create_session.assert_called_with(metadata=None)
|
||||
|
||||
# Try to create a session with an empty metadata dict
|
||||
response = client.post(
|
||||
"/v1/sessions",
|
||||
json={"metadata": {}}
|
||||
)
|
||||
|
||||
# This should work fine
|
||||
assert response.status_code == 201
|
||||
mock_repo.create_session.assert_called_with(metadata={})
|
||||
|
||||
# Try to create a session with a complex nested metadata
|
||||
response = client.post(
|
||||
"/v1/sessions",
|
||||
json={"metadata": {
|
||||
"level1": {
|
||||
"level2": {
|
||||
"level3": [1, 2, 3, {"key": "value"}]
|
||||
}
|
||||
}
|
||||
}}
|
||||
)
|
||||
|
||||
# Verify the complex nested structure is preserved
|
||||
assert response.status_code == 201
|
||||
complex_metadata = {
|
||||
"level1": {
|
||||
"level2": {
|
||||
"level3": [1, 2, 3, {"key": "value"}]
|
||||
}
|
||||
}
|
||||
}
|
||||
mock_repo.create_session.assert_called_with(metadata=complex_metadata)
|
||||
|
||||
|
||||
def test_integration_workflow(client, mock_repo):
|
||||
"""Test a complete workflow of creating and retrieving sessions."""
|
||||
# Set up mock sessions for the workflow
|
||||
first_session = SessionModel(
|
||||
id=30,
|
||||
created_at=datetime.datetime.now(),
|
||||
updated_at=datetime.datetime.now(),
|
||||
start_time=datetime.datetime.now(),
|
||||
command_line="ra-aid workflow-1",
|
||||
program_version="1.0.0",
|
||||
machine_info={"workflow_test": True}
|
||||
)
|
||||
|
||||
second_session = SessionModel(
|
||||
id=31,
|
||||
created_at=datetime.datetime.now(),
|
||||
updated_at=datetime.datetime.now(),
|
||||
start_time=datetime.datetime.now(),
|
||||
command_line="ra-aid workflow-2",
|
||||
program_version="1.0.0",
|
||||
machine_info={"workflow_test": False, "second": True}
|
||||
)
|
||||
|
||||
# Configure mock for create_session
|
||||
create_calls = 0
|
||||
def create_session_for_workflow(metadata=None):
|
||||
nonlocal create_calls
|
||||
create_calls += 1
|
||||
if create_calls == 1:
|
||||
return first_session
|
||||
return second_session
|
||||
|
||||
mock_repo.create_session.side_effect = create_session_for_workflow
|
||||
|
||||
# Configure mock for get
|
||||
def get_session_for_workflow(session_id):
|
||||
if session_id == first_session.id:
|
||||
return first_session
|
||||
elif session_id == second_session.id:
|
||||
return second_session
|
||||
return None
|
||||
|
||||
mock_repo.get.side_effect = get_session_for_workflow
|
||||
|
||||
# Configure mock for get_all
|
||||
def get_all_for_workflow(offset=0, limit=10):
|
||||
if create_calls == 1:
|
||||
return [first_session], 1
|
||||
return [second_session, first_session], 2
|
||||
|
||||
mock_repo.get_all.side_effect = get_all_for_workflow
|
||||
|
||||
# 1. Create a session
|
||||
create_response = client.post(
|
||||
"/v1/sessions",
|
||||
json={"metadata": {"workflow_test": True}}
|
||||
)
|
||||
assert create_response.status_code == 201
|
||||
session_id = create_response.json()["id"]
|
||||
assert session_id == first_session.id
|
||||
|
||||
# 2. Retrieve the created session
|
||||
get_response = client.get(f"/v1/sessions/{session_id}")
|
||||
assert get_response.status_code == 200
|
||||
assert get_response.json()["id"] == session_id
|
||||
|
||||
# 3. List all sessions and verify the created one is included
|
||||
list_response = client.get("/v1/sessions")
|
||||
assert list_response.status_code == 200
|
||||
items = list_response.json()["items"]
|
||||
assert len(items) == 1
|
||||
assert items[0]["id"] == session_id
|
||||
|
||||
# 4. Create a second session
|
||||
create_response2 = client.post(
|
||||
"/v1/sessions",
|
||||
json={"metadata": {"workflow_test": False, "second": True}}
|
||||
)
|
||||
assert create_response2.status_code == 201
|
||||
session_id2 = create_response2.json()["id"]
|
||||
assert session_id2 == second_session.id
|
||||
|
||||
# 5. List all sessions and verify both sessions are included
|
||||
list_response = client.get("/v1/sessions")
|
||||
assert list_response.status_code == 200
|
||||
data = list_response.json()
|
||||
assert data["total"] == 2
|
||||
items = data["items"]
|
||||
assert len(items) == 2
|
||||
assert items[0]["id"] == session_id2
|
||||
assert items[1]["id"] == session_id
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue