session model
This commit is contained in:
parent
376fe18b83
commit
c8fbd942ac
|
|
@ -56,6 +56,9 @@ from ra_aid.database.repositories.research_note_repository import (
|
|||
from ra_aid.database.repositories.trajectory_repository import (
|
||||
TrajectoryRepositoryManager, get_trajectory_repository
|
||||
)
|
||||
from ra_aid.database.repositories.session_repository import (
|
||||
SessionRepositoryManager, get_session_repository
|
||||
)
|
||||
from ra_aid.database.repositories.related_files_repository import (
|
||||
RelatedFilesRepositoryManager
|
||||
)
|
||||
|
|
@ -526,7 +529,8 @@ def main():
|
|||
env_discovery.discover()
|
||||
env_data = env_discovery.format_markdown()
|
||||
|
||||
with KeyFactRepositoryManager(db) as key_fact_repo, \
|
||||
with SessionRepositoryManager(db) as session_repo, \
|
||||
KeyFactRepositoryManager(db) as key_fact_repo, \
|
||||
KeySnippetRepositoryManager(db) as key_snippet_repo, \
|
||||
HumanInputRepositoryManager(db) as human_input_repo, \
|
||||
ResearchNoteRepositoryManager(db) as research_note_repo, \
|
||||
|
|
@ -536,6 +540,7 @@ def main():
|
|||
ConfigRepositoryManager(config) as config_repo, \
|
||||
EnvInvManager(env_data) as env_inv:
|
||||
# This initializes all repositories and makes them available via their respective get methods
|
||||
logger.debug("Initialized SessionRepository")
|
||||
logger.debug("Initialized KeyFactRepository")
|
||||
logger.debug("Initialized KeySnippetRepository")
|
||||
logger.debug("Initialized HumanInputRepository")
|
||||
|
|
@ -546,6 +551,10 @@ def main():
|
|||
logger.debug("Initialized ConfigRepository")
|
||||
logger.debug("Initialized Environment Inventory")
|
||||
|
||||
# Create a new session for this program run
|
||||
logger.debug("Initializing new session")
|
||||
session_repo.create_session()
|
||||
|
||||
# Check dependencies before proceeding
|
||||
check_dependencies()
|
||||
|
||||
|
|
|
|||
|
|
@ -42,8 +42,8 @@ def initialize_database():
|
|||
# to avoid circular imports
|
||||
# Note: This import needs to be here, not at the top level
|
||||
try:
|
||||
from ra_aid.database.models import KeyFact, KeySnippet, HumanInput, ResearchNote, Trajectory
|
||||
db.create_tables([KeyFact, KeySnippet, HumanInput, ResearchNote, Trajectory], safe=True)
|
||||
from ra_aid.database.models import KeyFact, KeySnippet, HumanInput, ResearchNote, Trajectory, Session
|
||||
db.create_tables([KeyFact, KeySnippet, HumanInput, ResearchNote, Trajectory, Session], safe=True)
|
||||
logger.debug("Ensured database tables exist")
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating tables: {str(e)}")
|
||||
|
|
@ -99,6 +99,25 @@ class BaseModel(peewee.Model):
|
|||
raise
|
||||
|
||||
|
||||
class Session(BaseModel):
|
||||
"""
|
||||
Model representing a session stored in the database.
|
||||
|
||||
Sessions track information about each program run, providing a way to group
|
||||
related records like human inputs, trajectories, and key facts.
|
||||
|
||||
Each session record captures details about when the program was started,
|
||||
what command line arguments were used, and environment information.
|
||||
"""
|
||||
start_time = peewee.DateTimeField(default=datetime.datetime.now)
|
||||
command_line = peewee.TextField(null=True)
|
||||
program_version = peewee.TextField(null=True)
|
||||
machine_info = peewee.TextField(null=True) # JSON-encoded machine information
|
||||
|
||||
class Meta:
|
||||
table_name = "session"
|
||||
|
||||
|
||||
class HumanInput(BaseModel):
|
||||
"""
|
||||
Model representing human input stored in the database.
|
||||
|
|
@ -109,6 +128,7 @@ class HumanInput(BaseModel):
|
|||
"""
|
||||
content = peewee.TextField()
|
||||
source = peewee.TextField() # 'cli', 'chat', or 'hil'
|
||||
session = peewee.ForeignKeyField(Session, backref='human_inputs', null=True)
|
||||
# created_at and updated_at are inherited from BaseModel
|
||||
|
||||
class Meta:
|
||||
|
|
@ -124,6 +144,7 @@ class KeyFact(BaseModel):
|
|||
"""
|
||||
content = peewee.TextField()
|
||||
human_input = peewee.ForeignKeyField(HumanInput, backref='key_facts', null=True)
|
||||
session = peewee.ForeignKeyField(Session, backref='key_facts', null=True)
|
||||
# created_at and updated_at are inherited from BaseModel
|
||||
|
||||
class Meta:
|
||||
|
|
@ -143,6 +164,7 @@ class KeySnippet(BaseModel):
|
|||
snippet = peewee.TextField()
|
||||
description = peewee.TextField(null=True)
|
||||
human_input = peewee.ForeignKeyField(HumanInput, backref='key_snippets', null=True)
|
||||
session = peewee.ForeignKeyField(Session, backref='key_snippets', null=True)
|
||||
# created_at and updated_at are inherited from BaseModel
|
||||
|
||||
class Meta:
|
||||
|
|
@ -159,6 +181,7 @@ class ResearchNote(BaseModel):
|
|||
"""
|
||||
content = peewee.TextField()
|
||||
human_input = peewee.ForeignKeyField(HumanInput, backref='research_notes', null=True)
|
||||
session = peewee.ForeignKeyField(Session, backref='research_notes', null=True)
|
||||
# created_at and updated_at are inherited from BaseModel
|
||||
|
||||
class Meta:
|
||||
|
|
@ -193,6 +216,7 @@ class Trajectory(BaseModel):
|
|||
error_message = peewee.TextField(null=True) # The error message
|
||||
error_type = peewee.TextField(null=True) # The type/class of the error
|
||||
error_details = peewee.TextField(null=True) # Additional error details like stack traces or context
|
||||
session = peewee.ForeignKeyField(Session, backref='trajectories', null=True)
|
||||
# created_at and updated_at are inherited from BaseModel
|
||||
|
||||
class Meta:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,243 @@
|
|||
"""
|
||||
Session repository implementation for database access.
|
||||
|
||||
This module provides a repository implementation for the Session model,
|
||||
following the repository pattern for data access abstraction. It handles
|
||||
operations for storing and retrieving application session information.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Any
|
||||
import contextvars
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import peewee
|
||||
|
||||
from ra_aid.database.models import Session
|
||||
from ra_aid.__version__ import __version__
|
||||
from ra_aid.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Create contextvar to hold the SessionRepository instance
|
||||
session_repo_var = contextvars.ContextVar("session_repo", default=None)
|
||||
|
||||
|
||||
class SessionRepositoryManager:
|
||||
"""
|
||||
Context manager for SessionRepository.
|
||||
|
||||
This class provides a context manager interface for SessionRepository,
|
||||
using the contextvars approach for thread safety.
|
||||
|
||||
Example:
|
||||
with DatabaseManager() as db:
|
||||
with SessionRepositoryManager(db) as repo:
|
||||
# Use the repository
|
||||
session = repo.create_session()
|
||||
current_session = repo.get_current_session()
|
||||
"""
|
||||
|
||||
def __init__(self, db):
|
||||
"""
|
||||
Initialize the SessionRepositoryManager.
|
||||
|
||||
Args:
|
||||
db: Database connection to use (required)
|
||||
"""
|
||||
self.db = db
|
||||
|
||||
def __enter__(self) -> 'SessionRepository':
|
||||
"""
|
||||
Initialize the SessionRepository and return it.
|
||||
|
||||
Returns:
|
||||
SessionRepository: The initialized repository
|
||||
"""
|
||||
repo = SessionRepository(self.db)
|
||||
session_repo_var.set(repo)
|
||||
return repo
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[type],
|
||||
exc_val: Optional[Exception],
|
||||
exc_tb: Optional[object],
|
||||
) -> None:
|
||||
"""
|
||||
Reset the repository when exiting the context.
|
||||
|
||||
Args:
|
||||
exc_type: The exception type if an exception was raised
|
||||
exc_val: The exception value if an exception was raised
|
||||
exc_tb: The traceback if an exception was raised
|
||||
"""
|
||||
# Reset the contextvar to None
|
||||
session_repo_var.set(None)
|
||||
|
||||
# Don't suppress exceptions
|
||||
return False
|
||||
|
||||
|
||||
def get_session_repository() -> 'SessionRepository':
|
||||
"""
|
||||
Get the current SessionRepository instance.
|
||||
|
||||
Returns:
|
||||
SessionRepository: The current repository instance
|
||||
|
||||
Raises:
|
||||
RuntimeError: If no repository has been initialized with SessionRepositoryManager
|
||||
"""
|
||||
repo = session_repo_var.get()
|
||||
if repo is None:
|
||||
raise RuntimeError(
|
||||
"No SessionRepository available. "
|
||||
"Make sure to initialize one with SessionRepositoryManager first."
|
||||
)
|
||||
return repo
|
||||
|
||||
|
||||
class SessionRepository:
|
||||
"""
|
||||
Repository for handling Session records in the database.
|
||||
|
||||
This class provides methods for creating, retrieving, and managing Session records.
|
||||
It abstracts away the database operations and provides a clean interface for working
|
||||
with Session entities.
|
||||
"""
|
||||
|
||||
def __init__(self, db):
|
||||
"""
|
||||
Initialize the SessionRepository.
|
||||
|
||||
Args:
|
||||
db: Database connection to use (required)
|
||||
"""
|
||||
if db is None:
|
||||
raise ValueError("Database connection is required for SessionRepository")
|
||||
self.db = db
|
||||
self.current_session = None
|
||||
|
||||
def create_session(self, metadata: Optional[Dict[str, Any]] = None) -> Session:
|
||||
"""
|
||||
Create a new session record in the database.
|
||||
|
||||
Args:
|
||||
metadata: Optional dictionary of additional metadata to store with the session
|
||||
|
||||
Returns:
|
||||
Session: The newly created session instance
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error creating the record
|
||||
"""
|
||||
try:
|
||||
# Get command line arguments
|
||||
command_line = " ".join(sys.argv)
|
||||
|
||||
# Get program version
|
||||
program_version = __version__
|
||||
|
||||
# JSON encode metadata if provided
|
||||
machine_info = json.dumps(metadata) if metadata is not None else None
|
||||
|
||||
session = Session.create(
|
||||
start_time=datetime.datetime.now(),
|
||||
command_line=command_line,
|
||||
program_version=program_version,
|
||||
machine_info=machine_info
|
||||
)
|
||||
|
||||
# Store the current session
|
||||
self.current_session = session
|
||||
|
||||
logger.debug(f"Created new session with ID {session.id}")
|
||||
return session
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to create session record: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_current_session(self) -> Optional[Session]:
|
||||
"""
|
||||
Get the current active session.
|
||||
|
||||
If no session has been created in this repository instance,
|
||||
retrieves the most recent session from the database.
|
||||
|
||||
Returns:
|
||||
Optional[Session]: The current session or None if no sessions exist
|
||||
"""
|
||||
if self.current_session is not None:
|
||||
return self.current_session
|
||||
|
||||
try:
|
||||
# Find the most recent session
|
||||
session = Session.select().order_by(Session.created_at.desc()).first()
|
||||
if session:
|
||||
self.current_session = session
|
||||
return session
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to get current session: {str(e)}")
|
||||
return None
|
||||
|
||||
def get_current_session_id(self) -> Optional[int]:
|
||||
"""
|
||||
Get the ID of the current active session.
|
||||
|
||||
Returns:
|
||||
Optional[int]: The ID of the current session or None if no session exists
|
||||
"""
|
||||
session = self.get_current_session()
|
||||
return session.id if session else None
|
||||
|
||||
def get(self, session_id: int) -> Optional[Session]:
|
||||
"""
|
||||
Get a session by its ID.
|
||||
|
||||
Args:
|
||||
session_id: The ID of the session to retrieve
|
||||
|
||||
Returns:
|
||||
Optional[Session]: The session with the given ID or None if not found
|
||||
"""
|
||||
try:
|
||||
return Session.get_or_none(Session.id == session_id)
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Database error getting session {session_id}: {str(e)}")
|
||||
return None
|
||||
|
||||
def get_all(self) -> List[Session]:
|
||||
"""
|
||||
Get all sessions from the database.
|
||||
|
||||
Returns:
|
||||
List[Session]: List of all sessions
|
||||
"""
|
||||
try:
|
||||
return list(Session.select().order_by(Session.created_at.desc()))
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to get all sessions: {str(e)}")
|
||||
return []
|
||||
|
||||
def get_recent(self, limit: int = 10) -> List[Session]:
|
||||
"""
|
||||
Get the most recent sessions from the database.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of sessions to return (default: 10)
|
||||
|
||||
Returns:
|
||||
List[Session]: List of the most recent sessions
|
||||
"""
|
||||
try:
|
||||
return list(
|
||||
Session.select()
|
||||
.order_by(Session.created_at.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to get recent sessions: {str(e)}")
|
||||
return []
|
||||
|
|
@ -0,0 +1,79 @@
|
|||
"""Peewee migrations -- 008_20250311_191232_add_session_model.py.
|
||||
|
||||
Some examples (model - class or model name)::
|
||||
|
||||
> Model = migrator.orm['table_name'] # Return model in current state by name
|
||||
> Model = migrator.ModelClass # Return model in current state by name
|
||||
|
||||
> migrator.sql(sql) # Run custom SQL
|
||||
> migrator.run(func, *args, **kwargs) # Run python function with the given args
|
||||
> migrator.create_model(Model) # Create a model (could be used as decorator)
|
||||
> migrator.remove_model(model, cascade=True) # Remove a model
|
||||
> migrator.add_fields(model, **fields) # Add fields to a model
|
||||
> migrator.change_fields(model, **fields) # Change fields
|
||||
> migrator.remove_fields(model, *field_names, cascade=True)
|
||||
> migrator.rename_field(model, old_field_name, new_field_name)
|
||||
> migrator.rename_table(model, new_table_name)
|
||||
> migrator.add_index(model, *col_names, unique=False)
|
||||
> migrator.add_not_null(model, *field_names)
|
||||
> migrator.add_default(model, field_name, default)
|
||||
> migrator.add_constraint(model, name, sql)
|
||||
> migrator.drop_index(model, *col_names)
|
||||
> migrator.drop_not_null(model, *field_names)
|
||||
> migrator.drop_constraints(model, *constraints)
|
||||
|
||||
"""
|
||||
|
||||
from contextlib import suppress
|
||||
|
||||
import peewee as pw
|
||||
from peewee_migrate import Migrator
|
||||
|
||||
|
||||
with suppress(ImportError):
|
||||
import playhouse.postgres_ext as pw_pext
|
||||
|
||||
|
||||
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Create the session table for storing application session information."""
|
||||
|
||||
table_exists = False
|
||||
# Check if the table already exists
|
||||
try:
|
||||
database.execute_sql("SELECT id FROM session LIMIT 1")
|
||||
# If we reach here, the table exists
|
||||
table_exists = True
|
||||
except pw.OperationalError:
|
||||
# Table doesn't exist, safe to create
|
||||
pass
|
||||
|
||||
# Create the Session model - this registers it in migrator.orm as 'Session'
|
||||
@migrator.create_model
|
||||
class Session(pw.Model):
|
||||
id = pw.AutoField()
|
||||
created_at = pw.DateTimeField()
|
||||
updated_at = pw.DateTimeField()
|
||||
start_time = pw.DateTimeField()
|
||||
command_line = pw.TextField(null=True)
|
||||
program_version = pw.TextField(null=True)
|
||||
machine_info = pw.TextField(null=True)
|
||||
|
||||
class Meta:
|
||||
table_name = "session"
|
||||
|
||||
# FIX: Explicitly register the model under the lowercase table name key
|
||||
# This ensures that later migrations can access it via either:
|
||||
# - migrator.orm['Session'] (class name)
|
||||
# - migrator.orm['session'] (table name)
|
||||
if 'Session' in migrator.orm:
|
||||
migrator.orm['session'] = migrator.orm['Session']
|
||||
|
||||
# Only return after model registration is complete
|
||||
if table_exists:
|
||||
return
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Remove the session table."""
|
||||
|
||||
migrator.remove_model('session')
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
"""Peewee migrations -- 009_20250311_191517_add_session_fk_to_human_input.py.
|
||||
|
||||
Some examples (model - class or model name)::
|
||||
|
||||
> Model = migrator.orm['table_name'] # Return model in current state by name
|
||||
> Model = migrator.ModelClass # Return model in current state by name
|
||||
|
||||
> migrator.sql(sql) # Run custom SQL
|
||||
> migrator.run(func, *args, **kwargs) # Run python function with the given args
|
||||
> migrator.create_model(Model) # Create a model (could be used as decorator)
|
||||
> migrator.remove_model(model, cascade=True) # Remove a model
|
||||
> migrator.add_fields(model, **fields) # Add fields to a model
|
||||
> migrator.change_fields(model, **fields) # Change fields
|
||||
> migrator.remove_fields(model, *field_names, cascade=True)
|
||||
> migrator.rename_field(model, old_field_name, new_field_name)
|
||||
> migrator.rename_table(model, new_table_name)
|
||||
> migrator.add_index(model, *col_names, unique=False)
|
||||
> migrator.add_not_null(model, *field_names)
|
||||
> migrator.add_default(model, field_name, default)
|
||||
> migrator.add_constraint(model, name, sql)
|
||||
> migrator.drop_index(model, *col_names)
|
||||
> migrator.drop_not_null(model, *field_names)
|
||||
> migrator.drop_constraints(model, *constraints)
|
||||
|
||||
"""
|
||||
|
||||
from contextlib import suppress
|
||||
|
||||
import peewee as pw
|
||||
from peewee_migrate import Migrator
|
||||
|
||||
|
||||
with suppress(ImportError):
|
||||
import playhouse.postgres_ext as pw_pext
|
||||
|
||||
|
||||
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Add session foreign key to HumanInput table."""
|
||||
|
||||
# Get the Session model from migrator.orm
|
||||
Session = migrator.orm['session']
|
||||
|
||||
# Check if the column already exists
|
||||
try:
|
||||
database.execute_sql("SELECT session_id FROM human_input LIMIT 1")
|
||||
# If we reach here, the column exists
|
||||
return
|
||||
except pw.OperationalError:
|
||||
# Column doesn't exist, safe to add
|
||||
pass
|
||||
|
||||
# Add the session_id foreign key column
|
||||
migrator.add_fields(
|
||||
'human_input',
|
||||
session=pw.ForeignKeyField(
|
||||
Session,
|
||||
null=True,
|
||||
field='id',
|
||||
on_delete='CASCADE'
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Remove session foreign key from HumanInput table."""
|
||||
|
||||
migrator.remove_fields('human_input', 'session')
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
"""Peewee migrations -- 010_20250311_191617_add_session_fk_to_key_fact.py.
|
||||
|
||||
Some examples (model - class or model name)::
|
||||
|
||||
> Model = migrator.orm['table_name'] # Return model in current state by name
|
||||
> Model = migrator.ModelClass # Return model in current state by name
|
||||
|
||||
> migrator.sql(sql) # Run custom SQL
|
||||
> migrator.run(func, *args, **kwargs) # Run python function with the given args
|
||||
> migrator.create_model(Model) # Create a model (could be used as decorator)
|
||||
> migrator.remove_model(model, cascade=True) # Remove a model
|
||||
> migrator.add_fields(model, **fields) # Add fields to a model
|
||||
> migrator.change_fields(model, **fields) # Change fields
|
||||
> migrator.remove_fields(model, *field_names, cascade=True)
|
||||
> migrator.rename_field(model, old_field_name, new_field_name)
|
||||
> migrator.rename_table(model, new_table_name)
|
||||
> migrator.add_index(model, *col_names, unique=False)
|
||||
> migrator.add_not_null(model, *field_names)
|
||||
> migrator.add_default(model, field_name, default)
|
||||
> migrator.add_constraint(model, name, sql)
|
||||
> migrator.drop_index(model, *col_names)
|
||||
> migrator.drop_not_null(model, *field_names)
|
||||
> migrator.drop_constraints(model, *constraints)
|
||||
|
||||
"""
|
||||
|
||||
from contextlib import suppress
|
||||
|
||||
import peewee as pw
|
||||
from peewee_migrate import Migrator
|
||||
|
||||
|
||||
with suppress(ImportError):
|
||||
import playhouse.postgres_ext as pw_pext
|
||||
|
||||
|
||||
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Add session foreign key to KeyFact table."""
|
||||
|
||||
# Get the Session model from migrator.orm
|
||||
Session = migrator.orm['session']
|
||||
|
||||
# Check if the column already exists
|
||||
try:
|
||||
database.execute_sql("SELECT session_id FROM key_fact LIMIT 1")
|
||||
# If we reach here, the column exists
|
||||
return
|
||||
except pw.OperationalError:
|
||||
# Column doesn't exist, safe to add
|
||||
pass
|
||||
|
||||
# Add the session_id foreign key column
|
||||
migrator.add_fields(
|
||||
'key_fact',
|
||||
session=pw.ForeignKeyField(
|
||||
Session,
|
||||
null=True,
|
||||
field='id',
|
||||
on_delete='CASCADE'
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Remove session foreign key from KeyFact table."""
|
||||
|
||||
migrator.remove_fields('key_fact', 'session')
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
"""Peewee migrations -- 011_20250311_191732_add_session_fk_to_key_snippet.py.
|
||||
|
||||
Some examples (model - class or model name)::
|
||||
|
||||
> Model = migrator.orm['table_name'] # Return model in current state by name
|
||||
> Model = migrator.ModelClass # Return model in current state by name
|
||||
|
||||
> migrator.sql(sql) # Run custom SQL
|
||||
> migrator.run(func, *args, **kwargs) # Run python function with the given args
|
||||
> migrator.create_model(Model) # Create a model (could be used as decorator)
|
||||
> migrator.remove_model(model, cascade=True) # Remove a model
|
||||
> migrator.add_fields(model, **fields) # Add fields to a model
|
||||
> migrator.change_fields(model, **fields) # Change fields
|
||||
> migrator.remove_fields(model, *field_names, cascade=True)
|
||||
> migrator.rename_field(model, old_field_name, new_field_name)
|
||||
> migrator.rename_table(model, new_table_name)
|
||||
> migrator.add_index(model, *col_names, unique=False)
|
||||
> migrator.add_not_null(model, *field_names)
|
||||
> migrator.add_default(model, field_name, default)
|
||||
> migrator.add_constraint(model, name, sql)
|
||||
> migrator.drop_index(model, *col_names)
|
||||
> migrator.drop_not_null(model, *field_names)
|
||||
> migrator.drop_constraints(model, *constraints)
|
||||
|
||||
"""
|
||||
|
||||
from contextlib import suppress
|
||||
|
||||
import peewee as pw
|
||||
from peewee_migrate import Migrator
|
||||
|
||||
|
||||
with suppress(ImportError):
|
||||
import playhouse.postgres_ext as pw_pext
|
||||
|
||||
|
||||
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Add session foreign key to KeySnippet table."""
|
||||
|
||||
# Get the Session model from migrator.orm
|
||||
Session = migrator.orm['session']
|
||||
|
||||
# Check if the column already exists
|
||||
try:
|
||||
database.execute_sql("SELECT session_id FROM key_snippet LIMIT 1")
|
||||
# If we reach here, the column exists
|
||||
return
|
||||
except pw.OperationalError:
|
||||
# Column doesn't exist, safe to add
|
||||
pass
|
||||
|
||||
# Add the session_id foreign key column
|
||||
migrator.add_fields(
|
||||
'key_snippet',
|
||||
session=pw.ForeignKeyField(
|
||||
Session,
|
||||
null=True,
|
||||
field='id',
|
||||
on_delete='CASCADE'
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Remove session foreign key from KeySnippet table."""
|
||||
|
||||
migrator.remove_fields('key_snippet', 'session')
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
"""Peewee migrations -- 012_20250311_191832_add_session_fk_to_research_note.py.
|
||||
|
||||
Some examples (model - class or model name)::
|
||||
|
||||
> Model = migrator.orm['table_name'] # Return model in current state by name
|
||||
> Model = migrator.ModelClass # Return model in current state by name
|
||||
|
||||
> migrator.sql(sql) # Run custom SQL
|
||||
> migrator.run(func, *args, **kwargs) # Run python function with the given args
|
||||
> migrator.create_model(Model) # Create a model (could be used as decorator)
|
||||
> migrator.remove_model(model, cascade=True) # Remove a model
|
||||
> migrator.add_fields(model, **fields) # Add fields to a model
|
||||
> migrator.change_fields(model, **fields) # Change fields
|
||||
> migrator.remove_fields(model, *field_names, cascade=True)
|
||||
> migrator.rename_field(model, old_field_name, new_field_name)
|
||||
> migrator.rename_table(model, new_table_name)
|
||||
> migrator.add_index(model, *col_names, unique=False)
|
||||
> migrator.add_not_null(model, *field_names)
|
||||
> migrator.add_default(model, field_name, default)
|
||||
> migrator.add_constraint(model, name, sql)
|
||||
> migrator.drop_index(model, *col_names)
|
||||
> migrator.drop_not_null(model, *field_names)
|
||||
> migrator.drop_constraints(model, *constraints)
|
||||
|
||||
"""
|
||||
|
||||
from contextlib import suppress
|
||||
|
||||
import peewee as pw
|
||||
from peewee_migrate import Migrator
|
||||
|
||||
|
||||
with suppress(ImportError):
|
||||
import playhouse.postgres_ext as pw_pext
|
||||
|
||||
|
||||
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Add session foreign key to ResearchNote table."""
|
||||
|
||||
# Get the Session model from migrator.orm
|
||||
Session = migrator.orm['session']
|
||||
|
||||
# Check if the column already exists
|
||||
try:
|
||||
database.execute_sql("SELECT session_id FROM research_note LIMIT 1")
|
||||
# If we reach here, the column exists
|
||||
return
|
||||
except pw.OperationalError:
|
||||
# Column doesn't exist, safe to add
|
||||
pass
|
||||
|
||||
# Add the session_id foreign key column
|
||||
migrator.add_fields(
|
||||
'research_note',
|
||||
session=pw.ForeignKeyField(
|
||||
Session,
|
||||
null=True,
|
||||
field='id',
|
||||
on_delete='CASCADE'
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Remove session foreign key from ResearchNote table."""
|
||||
|
||||
migrator.remove_fields('research_note', 'session')
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
"""Peewee migrations -- 013_20250311_191701_add_session_fk_to_trajectory.py.
|
||||
|
||||
Some examples (model - class or model name)::
|
||||
|
||||
> Model = migrator.orm['table_name'] # Return model in current state by name
|
||||
> Model = migrator.ModelClass # Return model in current state by name
|
||||
|
||||
> migrator.sql(sql) # Run custom SQL
|
||||
> migrator.run(func, *args, **kwargs) # Run python function with the given args
|
||||
> migrator.create_model(Model) # Create a model (could be used as decorator)
|
||||
> migrator.remove_model(model, cascade=True) # Remove a model
|
||||
> migrator.add_fields(model, **fields) # Add fields to a model
|
||||
> migrator.change_fields(model, **fields) # Change fields
|
||||
> migrator.remove_fields(model, *field_names, cascade=True)
|
||||
> migrator.rename_field(model, old_field_name, new_field_name)
|
||||
> migrator.rename_table(model, new_table_name)
|
||||
> migrator.add_index(model, *col_names, unique=False)
|
||||
> migrator.add_not_null(model, *field_names)
|
||||
> migrator.add_default(model, field_name, default)
|
||||
> migrator.add_constraint(model, name, sql)
|
||||
> migrator.drop_index(model, *col_names)
|
||||
> migrator.drop_not_null(model, *field_names)
|
||||
> migrator.drop_constraints(model, *constraints)
|
||||
|
||||
"""
|
||||
|
||||
from contextlib import suppress
|
||||
|
||||
import peewee as pw
|
||||
from peewee_migrate import Migrator
|
||||
|
||||
|
||||
with suppress(ImportError):
|
||||
import playhouse.postgres_ext as pw_pext
|
||||
|
||||
|
||||
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Add session foreign key to Trajectory table."""
|
||||
|
||||
# Get the Session model from migrator.orm
|
||||
Session = migrator.orm['session']
|
||||
|
||||
# Check if the column already exists
|
||||
try:
|
||||
database.execute_sql("SELECT session_id FROM trajectory LIMIT 1")
|
||||
# If we reach here, the column exists
|
||||
return
|
||||
except pw.OperationalError:
|
||||
# Column doesn't exist, safe to add
|
||||
pass
|
||||
|
||||
# Add the session_id foreign key column
|
||||
migrator.add_fields(
|
||||
'trajectory',
|
||||
session=pw.ForeignKeyField(
|
||||
Session,
|
||||
null=True,
|
||||
field='id',
|
||||
on_delete='CASCADE'
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Remove session foreign key from Trajectory table."""
|
||||
|
||||
migrator.remove_fields('trajectory', 'session')
|
||||
Loading…
Reference in New Issue