From c8fbd942ac3862de295e719069fe58b3131030d1 Mon Sep 17 00:00:00 2001 From: AI Christianson Date: Tue, 11 Mar 2025 20:11:14 -0400 Subject: [PATCH] session model --- ra_aid/__main__.py | 11 +- ra_aid/database/models.py | 28 +- .../repositories/session_repository.py | 243 ++++++++++++++++++ .../008_20250311_191232_add_session_model.py | 79 ++++++ ...11_191517_add_session_fk_to_human_input.py | 67 +++++ ...50311_191617_add_session_fk_to_key_fact.py | 67 +++++ ...11_191732_add_session_fk_to_key_snippet.py | 67 +++++ ..._191832_add_session_fk_to_research_note.py | 67 +++++ ...311_191701_add_session_fk_to_trajectory.py | 67 +++++ 9 files changed, 693 insertions(+), 3 deletions(-) create mode 100644 ra_aid/database/repositories/session_repository.py create mode 100644 ra_aid/migrations/008_20250311_191232_add_session_model.py create mode 100644 ra_aid/migrations/009_20250311_191517_add_session_fk_to_human_input.py create mode 100644 ra_aid/migrations/010_20250311_191617_add_session_fk_to_key_fact.py create mode 100644 ra_aid/migrations/011_20250311_191732_add_session_fk_to_key_snippet.py create mode 100644 ra_aid/migrations/012_20250311_191832_add_session_fk_to_research_note.py create mode 100644 ra_aid/migrations/013_20250311_191701_add_session_fk_to_trajectory.py diff --git a/ra_aid/__main__.py b/ra_aid/__main__.py index 75e25cb..aaf6f84 100644 --- a/ra_aid/__main__.py +++ b/ra_aid/__main__.py @@ -56,6 +56,9 @@ from ra_aid.database.repositories.research_note_repository import ( from ra_aid.database.repositories.trajectory_repository import ( TrajectoryRepositoryManager, get_trajectory_repository ) +from ra_aid.database.repositories.session_repository import ( + SessionRepositoryManager, get_session_repository +) from ra_aid.database.repositories.related_files_repository import ( RelatedFilesRepositoryManager ) @@ -526,7 +529,8 @@ def main(): env_discovery.discover() env_data = env_discovery.format_markdown() - with KeyFactRepositoryManager(db) as key_fact_repo, \ + with SessionRepositoryManager(db) as session_repo, \ + KeyFactRepositoryManager(db) as key_fact_repo, \ KeySnippetRepositoryManager(db) as key_snippet_repo, \ HumanInputRepositoryManager(db) as human_input_repo, \ ResearchNoteRepositoryManager(db) as research_note_repo, \ @@ -536,6 +540,7 @@ def main(): ConfigRepositoryManager(config) as config_repo, \ EnvInvManager(env_data) as env_inv: # This initializes all repositories and makes them available via their respective get methods + logger.debug("Initialized SessionRepository") logger.debug("Initialized KeyFactRepository") logger.debug("Initialized KeySnippetRepository") logger.debug("Initialized HumanInputRepository") @@ -545,6 +550,10 @@ def main(): logger.debug("Initialized WorkLogRepository") logger.debug("Initialized ConfigRepository") logger.debug("Initialized Environment Inventory") + + # Create a new session for this program run + logger.debug("Initializing new session") + session_repo.create_session() # Check dependencies before proceeding check_dependencies() diff --git a/ra_aid/database/models.py b/ra_aid/database/models.py index 3fc7033..7a9ec99 100644 --- a/ra_aid/database/models.py +++ b/ra_aid/database/models.py @@ -42,8 +42,8 @@ def initialize_database(): # to avoid circular imports # Note: This import needs to be here, not at the top level try: - from ra_aid.database.models import KeyFact, KeySnippet, HumanInput, ResearchNote, Trajectory - db.create_tables([KeyFact, KeySnippet, HumanInput, ResearchNote, Trajectory], safe=True) + from ra_aid.database.models import KeyFact, KeySnippet, HumanInput, ResearchNote, Trajectory, Session + db.create_tables([KeyFact, KeySnippet, HumanInput, ResearchNote, Trajectory, Session], safe=True) logger.debug("Ensured database tables exist") except Exception as e: logger.error(f"Error creating tables: {str(e)}") @@ -99,6 +99,25 @@ class BaseModel(peewee.Model): raise +class Session(BaseModel): + """ + Model representing a session stored in the database. + + Sessions track information about each program run, providing a way to group + related records like human inputs, trajectories, and key facts. + + Each session record captures details about when the program was started, + what command line arguments were used, and environment information. + """ + start_time = peewee.DateTimeField(default=datetime.datetime.now) + command_line = peewee.TextField(null=True) + program_version = peewee.TextField(null=True) + machine_info = peewee.TextField(null=True) # JSON-encoded machine information + + class Meta: + table_name = "session" + + class HumanInput(BaseModel): """ Model representing human input stored in the database. @@ -109,6 +128,7 @@ class HumanInput(BaseModel): """ content = peewee.TextField() source = peewee.TextField() # 'cli', 'chat', or 'hil' + session = peewee.ForeignKeyField(Session, backref='human_inputs', null=True) # created_at and updated_at are inherited from BaseModel class Meta: @@ -124,6 +144,7 @@ class KeyFact(BaseModel): """ content = peewee.TextField() human_input = peewee.ForeignKeyField(HumanInput, backref='key_facts', null=True) + session = peewee.ForeignKeyField(Session, backref='key_facts', null=True) # created_at and updated_at are inherited from BaseModel class Meta: @@ -143,6 +164,7 @@ class KeySnippet(BaseModel): snippet = peewee.TextField() description = peewee.TextField(null=True) human_input = peewee.ForeignKeyField(HumanInput, backref='key_snippets', null=True) + session = peewee.ForeignKeyField(Session, backref='key_snippets', null=True) # created_at and updated_at are inherited from BaseModel class Meta: @@ -159,6 +181,7 @@ class ResearchNote(BaseModel): """ content = peewee.TextField() human_input = peewee.ForeignKeyField(HumanInput, backref='research_notes', null=True) + session = peewee.ForeignKeyField(Session, backref='research_notes', null=True) # created_at and updated_at are inherited from BaseModel class Meta: @@ -193,6 +216,7 @@ class Trajectory(BaseModel): error_message = peewee.TextField(null=True) # The error message error_type = peewee.TextField(null=True) # The type/class of the error error_details = peewee.TextField(null=True) # Additional error details like stack traces or context + session = peewee.ForeignKeyField(Session, backref='trajectories', null=True) # created_at and updated_at are inherited from BaseModel class Meta: diff --git a/ra_aid/database/repositories/session_repository.py b/ra_aid/database/repositories/session_repository.py new file mode 100644 index 0000000..85509f4 --- /dev/null +++ b/ra_aid/database/repositories/session_repository.py @@ -0,0 +1,243 @@ +""" +Session repository implementation for database access. + +This module provides a repository implementation for the Session model, +following the repository pattern for data access abstraction. It handles +operations for storing and retrieving application session information. +""" + +from typing import Dict, List, Optional, Any +import contextvars +import datetime +import json +import logging +import sys + +import peewee + +from ra_aid.database.models import Session +from ra_aid.__version__ import __version__ +from ra_aid.logging_config import get_logger + +logger = get_logger(__name__) + +# Create contextvar to hold the SessionRepository instance +session_repo_var = contextvars.ContextVar("session_repo", default=None) + + +class SessionRepositoryManager: + """ + Context manager for SessionRepository. + + This class provides a context manager interface for SessionRepository, + using the contextvars approach for thread safety. + + Example: + with DatabaseManager() as db: + with SessionRepositoryManager(db) as repo: + # Use the repository + session = repo.create_session() + current_session = repo.get_current_session() + """ + + def __init__(self, db): + """ + Initialize the SessionRepositoryManager. + + Args: + db: Database connection to use (required) + """ + self.db = db + + def __enter__(self) -> 'SessionRepository': + """ + Initialize the SessionRepository and return it. + + Returns: + SessionRepository: The initialized repository + """ + repo = SessionRepository(self.db) + session_repo_var.set(repo) + return repo + + def __exit__( + self, + exc_type: Optional[type], + exc_val: Optional[Exception], + exc_tb: Optional[object], + ) -> None: + """ + Reset the repository when exiting the context. + + Args: + exc_type: The exception type if an exception was raised + exc_val: The exception value if an exception was raised + exc_tb: The traceback if an exception was raised + """ + # Reset the contextvar to None + session_repo_var.set(None) + + # Don't suppress exceptions + return False + + +def get_session_repository() -> 'SessionRepository': + """ + Get the current SessionRepository instance. + + Returns: + SessionRepository: The current repository instance + + Raises: + RuntimeError: If no repository has been initialized with SessionRepositoryManager + """ + repo = session_repo_var.get() + if repo is None: + raise RuntimeError( + "No SessionRepository available. " + "Make sure to initialize one with SessionRepositoryManager first." + ) + return repo + + +class SessionRepository: + """ + Repository for handling Session records in the database. + + This class provides methods for creating, retrieving, and managing Session records. + It abstracts away the database operations and provides a clean interface for working + with Session entities. + """ + + def __init__(self, db): + """ + Initialize the SessionRepository. + + Args: + db: Database connection to use (required) + """ + if db is None: + raise ValueError("Database connection is required for SessionRepository") + self.db = db + self.current_session = None + + def create_session(self, metadata: Optional[Dict[str, Any]] = None) -> Session: + """ + Create a new session record in the database. + + Args: + metadata: Optional dictionary of additional metadata to store with the session + + Returns: + Session: The newly created session instance + + Raises: + peewee.DatabaseError: If there's an error creating the record + """ + try: + # Get command line arguments + command_line = " ".join(sys.argv) + + # Get program version + program_version = __version__ + + # JSON encode metadata if provided + machine_info = json.dumps(metadata) if metadata is not None else None + + session = Session.create( + start_time=datetime.datetime.now(), + command_line=command_line, + program_version=program_version, + machine_info=machine_info + ) + + # Store the current session + self.current_session = session + + logger.debug(f"Created new session with ID {session.id}") + return session + except peewee.DatabaseError as e: + logger.error(f"Failed to create session record: {str(e)}") + raise + + def get_current_session(self) -> Optional[Session]: + """ + Get the current active session. + + If no session has been created in this repository instance, + retrieves the most recent session from the database. + + Returns: + Optional[Session]: The current session or None if no sessions exist + """ + if self.current_session is not None: + return self.current_session + + try: + # Find the most recent session + session = Session.select().order_by(Session.created_at.desc()).first() + if session: + self.current_session = session + return session + except peewee.DatabaseError as e: + logger.error(f"Failed to get current session: {str(e)}") + return None + + def get_current_session_id(self) -> Optional[int]: + """ + Get the ID of the current active session. + + Returns: + Optional[int]: The ID of the current session or None if no session exists + """ + session = self.get_current_session() + return session.id if session else None + + def get(self, session_id: int) -> Optional[Session]: + """ + Get a session by its ID. + + Args: + session_id: The ID of the session to retrieve + + Returns: + Optional[Session]: The session with the given ID or None if not found + """ + try: + return Session.get_or_none(Session.id == session_id) + except peewee.DatabaseError as e: + logger.error(f"Database error getting session {session_id}: {str(e)}") + return None + + def get_all(self) -> List[Session]: + """ + Get all sessions from the database. + + Returns: + List[Session]: List of all sessions + """ + try: + return list(Session.select().order_by(Session.created_at.desc())) + except peewee.DatabaseError as e: + logger.error(f"Failed to get all sessions: {str(e)}") + return [] + + def get_recent(self, limit: int = 10) -> List[Session]: + """ + Get the most recent sessions from the database. + + Args: + limit: Maximum number of sessions to return (default: 10) + + Returns: + List[Session]: List of the most recent sessions + """ + try: + return list( + Session.select() + .order_by(Session.created_at.desc()) + .limit(limit) + ) + except peewee.DatabaseError as e: + logger.error(f"Failed to get recent sessions: {str(e)}") + return [] \ No newline at end of file diff --git a/ra_aid/migrations/008_20250311_191232_add_session_model.py b/ra_aid/migrations/008_20250311_191232_add_session_model.py new file mode 100644 index 0000000..1393bd4 --- /dev/null +++ b/ra_aid/migrations/008_20250311_191232_add_session_model.py @@ -0,0 +1,79 @@ +"""Peewee migrations -- 008_20250311_191232_add_session_model.py. + +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Create the session table for storing application session information.""" + + table_exists = False + # Check if the table already exists + try: + database.execute_sql("SELECT id FROM session LIMIT 1") + # If we reach here, the table exists + table_exists = True + except pw.OperationalError: + # Table doesn't exist, safe to create + pass + + # Create the Session model - this registers it in migrator.orm as 'Session' + @migrator.create_model + class Session(pw.Model): + id = pw.AutoField() + created_at = pw.DateTimeField() + updated_at = pw.DateTimeField() + start_time = pw.DateTimeField() + command_line = pw.TextField(null=True) + program_version = pw.TextField(null=True) + machine_info = pw.TextField(null=True) + + class Meta: + table_name = "session" + + # FIX: Explicitly register the model under the lowercase table name key + # This ensures that later migrations can access it via either: + # - migrator.orm['Session'] (class name) + # - migrator.orm['session'] (table name) + if 'Session' in migrator.orm: + migrator.orm['session'] = migrator.orm['Session'] + + # Only return after model registration is complete + if table_exists: + return + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Remove the session table.""" + + migrator.remove_model('session') \ No newline at end of file diff --git a/ra_aid/migrations/009_20250311_191517_add_session_fk_to_human_input.py b/ra_aid/migrations/009_20250311_191517_add_session_fk_to_human_input.py new file mode 100644 index 0000000..f7fa9d2 --- /dev/null +++ b/ra_aid/migrations/009_20250311_191517_add_session_fk_to_human_input.py @@ -0,0 +1,67 @@ +"""Peewee migrations -- 009_20250311_191517_add_session_fk_to_human_input.py. + +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Add session foreign key to HumanInput table.""" + + # Get the Session model from migrator.orm + Session = migrator.orm['session'] + + # Check if the column already exists + try: + database.execute_sql("SELECT session_id FROM human_input LIMIT 1") + # If we reach here, the column exists + return + except pw.OperationalError: + # Column doesn't exist, safe to add + pass + + # Add the session_id foreign key column + migrator.add_fields( + 'human_input', + session=pw.ForeignKeyField( + Session, + null=True, + field='id', + on_delete='CASCADE' + ) + ) + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Remove session foreign key from HumanInput table.""" + + migrator.remove_fields('human_input', 'session') \ No newline at end of file diff --git a/ra_aid/migrations/010_20250311_191617_add_session_fk_to_key_fact.py b/ra_aid/migrations/010_20250311_191617_add_session_fk_to_key_fact.py new file mode 100644 index 0000000..ab0322d --- /dev/null +++ b/ra_aid/migrations/010_20250311_191617_add_session_fk_to_key_fact.py @@ -0,0 +1,67 @@ +"""Peewee migrations -- 010_20250311_191617_add_session_fk_to_key_fact.py. + +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Add session foreign key to KeyFact table.""" + + # Get the Session model from migrator.orm + Session = migrator.orm['session'] + + # Check if the column already exists + try: + database.execute_sql("SELECT session_id FROM key_fact LIMIT 1") + # If we reach here, the column exists + return + except pw.OperationalError: + # Column doesn't exist, safe to add + pass + + # Add the session_id foreign key column + migrator.add_fields( + 'key_fact', + session=pw.ForeignKeyField( + Session, + null=True, + field='id', + on_delete='CASCADE' + ) + ) + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Remove session foreign key from KeyFact table.""" + + migrator.remove_fields('key_fact', 'session') \ No newline at end of file diff --git a/ra_aid/migrations/011_20250311_191732_add_session_fk_to_key_snippet.py b/ra_aid/migrations/011_20250311_191732_add_session_fk_to_key_snippet.py new file mode 100644 index 0000000..dd9c49c --- /dev/null +++ b/ra_aid/migrations/011_20250311_191732_add_session_fk_to_key_snippet.py @@ -0,0 +1,67 @@ +"""Peewee migrations -- 011_20250311_191732_add_session_fk_to_key_snippet.py. + +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Add session foreign key to KeySnippet table.""" + + # Get the Session model from migrator.orm + Session = migrator.orm['session'] + + # Check if the column already exists + try: + database.execute_sql("SELECT session_id FROM key_snippet LIMIT 1") + # If we reach here, the column exists + return + except pw.OperationalError: + # Column doesn't exist, safe to add + pass + + # Add the session_id foreign key column + migrator.add_fields( + 'key_snippet', + session=pw.ForeignKeyField( + Session, + null=True, + field='id', + on_delete='CASCADE' + ) + ) + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Remove session foreign key from KeySnippet table.""" + + migrator.remove_fields('key_snippet', 'session') \ No newline at end of file diff --git a/ra_aid/migrations/012_20250311_191832_add_session_fk_to_research_note.py b/ra_aid/migrations/012_20250311_191832_add_session_fk_to_research_note.py new file mode 100644 index 0000000..23e05d8 --- /dev/null +++ b/ra_aid/migrations/012_20250311_191832_add_session_fk_to_research_note.py @@ -0,0 +1,67 @@ +"""Peewee migrations -- 012_20250311_191832_add_session_fk_to_research_note.py. + +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Add session foreign key to ResearchNote table.""" + + # Get the Session model from migrator.orm + Session = migrator.orm['session'] + + # Check if the column already exists + try: + database.execute_sql("SELECT session_id FROM research_note LIMIT 1") + # If we reach here, the column exists + return + except pw.OperationalError: + # Column doesn't exist, safe to add + pass + + # Add the session_id foreign key column + migrator.add_fields( + 'research_note', + session=pw.ForeignKeyField( + Session, + null=True, + field='id', + on_delete='CASCADE' + ) + ) + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Remove session foreign key from ResearchNote table.""" + + migrator.remove_fields('research_note', 'session') \ No newline at end of file diff --git a/ra_aid/migrations/013_20250311_191701_add_session_fk_to_trajectory.py b/ra_aid/migrations/013_20250311_191701_add_session_fk_to_trajectory.py new file mode 100644 index 0000000..b79d900 --- /dev/null +++ b/ra_aid/migrations/013_20250311_191701_add_session_fk_to_trajectory.py @@ -0,0 +1,67 @@ +"""Peewee migrations -- 013_20250311_191701_add_session_fk_to_trajectory.py. + +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Add session foreign key to Trajectory table.""" + + # Get the Session model from migrator.orm + Session = migrator.orm['session'] + + # Check if the column already exists + try: + database.execute_sql("SELECT session_id FROM trajectory LIMIT 1") + # If we reach here, the column exists + return + except pw.OperationalError: + # Column doesn't exist, safe to add + pass + + # Add the session_id foreign key column + migrator.add_fields( + 'trajectory', + session=pw.ForeignKeyField( + Session, + null=True, + field='id', + on_delete='CASCADE' + ) + ) + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Remove session foreign key from Trajectory table.""" + + migrator.remove_fields('trajectory', 'session') \ No newline at end of file