session model

This commit is contained in:
AI Christianson 2025-03-11 20:11:14 -04:00
parent 376fe18b83
commit c8fbd942ac
9 changed files with 693 additions and 3 deletions

View File

@ -56,6 +56,9 @@ from ra_aid.database.repositories.research_note_repository import (
from ra_aid.database.repositories.trajectory_repository import (
TrajectoryRepositoryManager, get_trajectory_repository
)
from ra_aid.database.repositories.session_repository import (
SessionRepositoryManager, get_session_repository
)
from ra_aid.database.repositories.related_files_repository import (
RelatedFilesRepositoryManager
)
@ -526,7 +529,8 @@ def main():
env_discovery.discover()
env_data = env_discovery.format_markdown()
with KeyFactRepositoryManager(db) as key_fact_repo, \
with SessionRepositoryManager(db) as session_repo, \
KeyFactRepositoryManager(db) as key_fact_repo, \
KeySnippetRepositoryManager(db) as key_snippet_repo, \
HumanInputRepositoryManager(db) as human_input_repo, \
ResearchNoteRepositoryManager(db) as research_note_repo, \
@ -536,6 +540,7 @@ def main():
ConfigRepositoryManager(config) as config_repo, \
EnvInvManager(env_data) as env_inv:
# This initializes all repositories and makes them available via their respective get methods
logger.debug("Initialized SessionRepository")
logger.debug("Initialized KeyFactRepository")
logger.debug("Initialized KeySnippetRepository")
logger.debug("Initialized HumanInputRepository")
@ -545,6 +550,10 @@ def main():
logger.debug("Initialized WorkLogRepository")
logger.debug("Initialized ConfigRepository")
logger.debug("Initialized Environment Inventory")
# Create a new session for this program run
logger.debug("Initializing new session")
session_repo.create_session()
# Check dependencies before proceeding
check_dependencies()

View File

@ -42,8 +42,8 @@ def initialize_database():
# to avoid circular imports
# Note: This import needs to be here, not at the top level
try:
from ra_aid.database.models import KeyFact, KeySnippet, HumanInput, ResearchNote, Trajectory
db.create_tables([KeyFact, KeySnippet, HumanInput, ResearchNote, Trajectory], safe=True)
from ra_aid.database.models import KeyFact, KeySnippet, HumanInput, ResearchNote, Trajectory, Session
db.create_tables([KeyFact, KeySnippet, HumanInput, ResearchNote, Trajectory, Session], safe=True)
logger.debug("Ensured database tables exist")
except Exception as e:
logger.error(f"Error creating tables: {str(e)}")
@ -99,6 +99,25 @@ class BaseModel(peewee.Model):
raise
class Session(BaseModel):
"""
Model representing a session stored in the database.
Sessions track information about each program run, providing a way to group
related records like human inputs, trajectories, and key facts.
Each session record captures details about when the program was started,
what command line arguments were used, and environment information.
"""
start_time = peewee.DateTimeField(default=datetime.datetime.now)
command_line = peewee.TextField(null=True)
program_version = peewee.TextField(null=True)
machine_info = peewee.TextField(null=True) # JSON-encoded machine information
class Meta:
table_name = "session"
class HumanInput(BaseModel):
"""
Model representing human input stored in the database.
@ -109,6 +128,7 @@ class HumanInput(BaseModel):
"""
content = peewee.TextField()
source = peewee.TextField() # 'cli', 'chat', or 'hil'
session = peewee.ForeignKeyField(Session, backref='human_inputs', null=True)
# created_at and updated_at are inherited from BaseModel
class Meta:
@ -124,6 +144,7 @@ class KeyFact(BaseModel):
"""
content = peewee.TextField()
human_input = peewee.ForeignKeyField(HumanInput, backref='key_facts', null=True)
session = peewee.ForeignKeyField(Session, backref='key_facts', null=True)
# created_at and updated_at are inherited from BaseModel
class Meta:
@ -143,6 +164,7 @@ class KeySnippet(BaseModel):
snippet = peewee.TextField()
description = peewee.TextField(null=True)
human_input = peewee.ForeignKeyField(HumanInput, backref='key_snippets', null=True)
session = peewee.ForeignKeyField(Session, backref='key_snippets', null=True)
# created_at and updated_at are inherited from BaseModel
class Meta:
@ -159,6 +181,7 @@ class ResearchNote(BaseModel):
"""
content = peewee.TextField()
human_input = peewee.ForeignKeyField(HumanInput, backref='research_notes', null=True)
session = peewee.ForeignKeyField(Session, backref='research_notes', null=True)
# created_at and updated_at are inherited from BaseModel
class Meta:
@ -193,6 +216,7 @@ class Trajectory(BaseModel):
error_message = peewee.TextField(null=True) # The error message
error_type = peewee.TextField(null=True) # The type/class of the error
error_details = peewee.TextField(null=True) # Additional error details like stack traces or context
session = peewee.ForeignKeyField(Session, backref='trajectories', null=True)
# created_at and updated_at are inherited from BaseModel
class Meta:

View File

@ -0,0 +1,243 @@
"""
Session repository implementation for database access.
This module provides a repository implementation for the Session model,
following the repository pattern for data access abstraction. It handles
operations for storing and retrieving application session information.
"""
from typing import Dict, List, Optional, Any
import contextvars
import datetime
import json
import logging
import sys
import peewee
from ra_aid.database.models import Session
from ra_aid.__version__ import __version__
from ra_aid.logging_config import get_logger
logger = get_logger(__name__)
# Create contextvar to hold the SessionRepository instance
session_repo_var = contextvars.ContextVar("session_repo", default=None)
class SessionRepositoryManager:
"""
Context manager for SessionRepository.
This class provides a context manager interface for SessionRepository,
using the contextvars approach for thread safety.
Example:
with DatabaseManager() as db:
with SessionRepositoryManager(db) as repo:
# Use the repository
session = repo.create_session()
current_session = repo.get_current_session()
"""
def __init__(self, db):
"""
Initialize the SessionRepositoryManager.
Args:
db: Database connection to use (required)
"""
self.db = db
def __enter__(self) -> 'SessionRepository':
"""
Initialize the SessionRepository and return it.
Returns:
SessionRepository: The initialized repository
"""
repo = SessionRepository(self.db)
session_repo_var.set(repo)
return repo
def __exit__(
self,
exc_type: Optional[type],
exc_val: Optional[Exception],
exc_tb: Optional[object],
) -> None:
"""
Reset the repository when exiting the context.
Args:
exc_type: The exception type if an exception was raised
exc_val: The exception value if an exception was raised
exc_tb: The traceback if an exception was raised
"""
# Reset the contextvar to None
session_repo_var.set(None)
# Don't suppress exceptions
return False
def get_session_repository() -> 'SessionRepository':
"""
Get the current SessionRepository instance.
Returns:
SessionRepository: The current repository instance
Raises:
RuntimeError: If no repository has been initialized with SessionRepositoryManager
"""
repo = session_repo_var.get()
if repo is None:
raise RuntimeError(
"No SessionRepository available. "
"Make sure to initialize one with SessionRepositoryManager first."
)
return repo
class SessionRepository:
"""
Repository for handling Session records in the database.
This class provides methods for creating, retrieving, and managing Session records.
It abstracts away the database operations and provides a clean interface for working
with Session entities.
"""
def __init__(self, db):
"""
Initialize the SessionRepository.
Args:
db: Database connection to use (required)
"""
if db is None:
raise ValueError("Database connection is required for SessionRepository")
self.db = db
self.current_session = None
def create_session(self, metadata: Optional[Dict[str, Any]] = None) -> Session:
"""
Create a new session record in the database.
Args:
metadata: Optional dictionary of additional metadata to store with the session
Returns:
Session: The newly created session instance
Raises:
peewee.DatabaseError: If there's an error creating the record
"""
try:
# Get command line arguments
command_line = " ".join(sys.argv)
# Get program version
program_version = __version__
# JSON encode metadata if provided
machine_info = json.dumps(metadata) if metadata is not None else None
session = Session.create(
start_time=datetime.datetime.now(),
command_line=command_line,
program_version=program_version,
machine_info=machine_info
)
# Store the current session
self.current_session = session
logger.debug(f"Created new session with ID {session.id}")
return session
except peewee.DatabaseError as e:
logger.error(f"Failed to create session record: {str(e)}")
raise
def get_current_session(self) -> Optional[Session]:
"""
Get the current active session.
If no session has been created in this repository instance,
retrieves the most recent session from the database.
Returns:
Optional[Session]: The current session or None if no sessions exist
"""
if self.current_session is not None:
return self.current_session
try:
# Find the most recent session
session = Session.select().order_by(Session.created_at.desc()).first()
if session:
self.current_session = session
return session
except peewee.DatabaseError as e:
logger.error(f"Failed to get current session: {str(e)}")
return None
def get_current_session_id(self) -> Optional[int]:
"""
Get the ID of the current active session.
Returns:
Optional[int]: The ID of the current session or None if no session exists
"""
session = self.get_current_session()
return session.id if session else None
def get(self, session_id: int) -> Optional[Session]:
"""
Get a session by its ID.
Args:
session_id: The ID of the session to retrieve
Returns:
Optional[Session]: The session with the given ID or None if not found
"""
try:
return Session.get_or_none(Session.id == session_id)
except peewee.DatabaseError as e:
logger.error(f"Database error getting session {session_id}: {str(e)}")
return None
def get_all(self) -> List[Session]:
"""
Get all sessions from the database.
Returns:
List[Session]: List of all sessions
"""
try:
return list(Session.select().order_by(Session.created_at.desc()))
except peewee.DatabaseError as e:
logger.error(f"Failed to get all sessions: {str(e)}")
return []
def get_recent(self, limit: int = 10) -> List[Session]:
"""
Get the most recent sessions from the database.
Args:
limit: Maximum number of sessions to return (default: 10)
Returns:
List[Session]: List of the most recent sessions
"""
try:
return list(
Session.select()
.order_by(Session.created_at.desc())
.limit(limit)
)
except peewee.DatabaseError as e:
logger.error(f"Failed to get recent sessions: {str(e)}")
return []

View File

@ -0,0 +1,79 @@
"""Peewee migrations -- 008_20250311_191232_add_session_model.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Create the session table for storing application session information."""
table_exists = False
# Check if the table already exists
try:
database.execute_sql("SELECT id FROM session LIMIT 1")
# If we reach here, the table exists
table_exists = True
except pw.OperationalError:
# Table doesn't exist, safe to create
pass
# Create the Session model - this registers it in migrator.orm as 'Session'
@migrator.create_model
class Session(pw.Model):
id = pw.AutoField()
created_at = pw.DateTimeField()
updated_at = pw.DateTimeField()
start_time = pw.DateTimeField()
command_line = pw.TextField(null=True)
program_version = pw.TextField(null=True)
machine_info = pw.TextField(null=True)
class Meta:
table_name = "session"
# FIX: Explicitly register the model under the lowercase table name key
# This ensures that later migrations can access it via either:
# - migrator.orm['Session'] (class name)
# - migrator.orm['session'] (table name)
if 'Session' in migrator.orm:
migrator.orm['session'] = migrator.orm['Session']
# Only return after model registration is complete
if table_exists:
return
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Remove the session table."""
migrator.remove_model('session')

View File

@ -0,0 +1,67 @@
"""Peewee migrations -- 009_20250311_191517_add_session_fk_to_human_input.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Add session foreign key to HumanInput table."""
# Get the Session model from migrator.orm
Session = migrator.orm['session']
# Check if the column already exists
try:
database.execute_sql("SELECT session_id FROM human_input LIMIT 1")
# If we reach here, the column exists
return
except pw.OperationalError:
# Column doesn't exist, safe to add
pass
# Add the session_id foreign key column
migrator.add_fields(
'human_input',
session=pw.ForeignKeyField(
Session,
null=True,
field='id',
on_delete='CASCADE'
)
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Remove session foreign key from HumanInput table."""
migrator.remove_fields('human_input', 'session')

View File

@ -0,0 +1,67 @@
"""Peewee migrations -- 010_20250311_191617_add_session_fk_to_key_fact.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Add session foreign key to KeyFact table."""
# Get the Session model from migrator.orm
Session = migrator.orm['session']
# Check if the column already exists
try:
database.execute_sql("SELECT session_id FROM key_fact LIMIT 1")
# If we reach here, the column exists
return
except pw.OperationalError:
# Column doesn't exist, safe to add
pass
# Add the session_id foreign key column
migrator.add_fields(
'key_fact',
session=pw.ForeignKeyField(
Session,
null=True,
field='id',
on_delete='CASCADE'
)
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Remove session foreign key from KeyFact table."""
migrator.remove_fields('key_fact', 'session')

View File

@ -0,0 +1,67 @@
"""Peewee migrations -- 011_20250311_191732_add_session_fk_to_key_snippet.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Add session foreign key to KeySnippet table."""
# Get the Session model from migrator.orm
Session = migrator.orm['session']
# Check if the column already exists
try:
database.execute_sql("SELECT session_id FROM key_snippet LIMIT 1")
# If we reach here, the column exists
return
except pw.OperationalError:
# Column doesn't exist, safe to add
pass
# Add the session_id foreign key column
migrator.add_fields(
'key_snippet',
session=pw.ForeignKeyField(
Session,
null=True,
field='id',
on_delete='CASCADE'
)
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Remove session foreign key from KeySnippet table."""
migrator.remove_fields('key_snippet', 'session')

View File

@ -0,0 +1,67 @@
"""Peewee migrations -- 012_20250311_191832_add_session_fk_to_research_note.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Add session foreign key to ResearchNote table."""
# Get the Session model from migrator.orm
Session = migrator.orm['session']
# Check if the column already exists
try:
database.execute_sql("SELECT session_id FROM research_note LIMIT 1")
# If we reach here, the column exists
return
except pw.OperationalError:
# Column doesn't exist, safe to add
pass
# Add the session_id foreign key column
migrator.add_fields(
'research_note',
session=pw.ForeignKeyField(
Session,
null=True,
field='id',
on_delete='CASCADE'
)
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Remove session foreign key from ResearchNote table."""
migrator.remove_fields('research_note', 'session')

View File

@ -0,0 +1,67 @@
"""Peewee migrations -- 013_20250311_191701_add_session_fk_to_trajectory.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Add session foreign key to Trajectory table."""
# Get the Session model from migrator.orm
Session = migrator.orm['session']
# Check if the column already exists
try:
database.execute_sql("SELECT session_id FROM trajectory LIMIT 1")
# If we reach here, the column exists
return
except pw.OperationalError:
# Column doesn't exist, safe to add
pass
# Add the session_id foreign key column
migrator.add_fields(
'trajectory',
session=pw.ForeignKeyField(
Session,
null=True,
field='id',
on_delete='CASCADE'
)
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Remove session foreign key from Trajectory table."""
migrator.remove_fields('trajectory', 'session')