From 724dbd4fdaf5fa3a6b2e42ba3886e7f53d1915cc Mon Sep 17 00:00:00 2001 From: AI Christianson Date: Wed, 26 Feb 2025 16:21:38 -0500 Subject: [PATCH] db migrations --- create_initial_migration.py | 57 +++ pyproject.toml | 1 + ra_aid/__main__.py | 10 +- ra_aid/database/__init__.py | 14 +- ra_aid/database/connection.py | 2 +- ra_aid/database/migrations.py | 284 ++++++++++++ tests/ra_aid/database/test_migrations.py | 547 +++++++++++++++++++++++ 7 files changed, 912 insertions(+), 3 deletions(-) create mode 100755 create_initial_migration.py create mode 100644 ra_aid/database/migrations.py create mode 100644 tests/ra_aid/database/test_migrations.py diff --git a/create_initial_migration.py b/create_initial_migration.py new file mode 100755 index 0000000..2987e15 --- /dev/null +++ b/create_initial_migration.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +""" +Create initial database migration script. + +This script creates a baseline migration representing the current database schema. +It serves as the foundation for future schema changes. +""" + +import sys +import os +from pathlib import Path + +# Add the project root to the Python path +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + +from ra_aid.database import DatabaseManager, create_new_migration +from ra_aid.logging_config import get_logger, setup_logging + +# Set up logging +setup_logging(verbose=True) +logger = get_logger(__name__) + +def create_initial_migration(): + """ + Create the initial migration for the current database schema. + + Returns: + bool: True if migration was created successfully, False otherwise + """ + try: + with DatabaseManager() as db: + # Create a descriptive name for the initial migration + migration_name = "initial_schema" + + # Create the migration + logger.info(f"Creating initial migration '{migration_name}'...") + result = create_new_migration(migration_name, auto=True) + + if result: + logger.info(f"Successfully created initial migration: {result}") + print(f"✅ Initial migration created successfully: {result}") + return True + else: + logger.error("Failed to create initial migration") + print("❌ Failed to create initial migration") + return False + except Exception as e: + logger.error(f"Error creating initial migration: {str(e)}") + print(f"❌ Error creating initial migration: {str(e)}") + return False + +if __name__ == "__main__": + print("Creating initial database migration...") + success = create_initial_migration() + + # Exit with appropriate code + sys.exit(0 if success else 1) diff --git a/pyproject.toml b/pyproject.toml index 970bfb9..2429a99 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ dependencies = [ "python-Levenshtein>=0.26.1", "python-magic>=0.4.27", "peewee>=3.17.9", + "peewee-migrate>=1.13.0", "platformdirs>=3.17.9", ] diff --git a/ra_aid/__main__.py b/ra_aid/__main__.py index b8191a9..2fc7ad6 100644 --- a/ra_aid/__main__.py +++ b/ra_aid/__main__.py @@ -17,7 +17,7 @@ from ra_aid.agent_utils import ( run_planning_agent, run_research_agent, ) -from ra_aid.database import init_db, close_db, DatabaseManager +from ra_aid.database import init_db, close_db, DatabaseManager, ensure_migrations_applied from ra_aid.config import ( DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT, @@ -335,6 +335,14 @@ def main(): try: with DatabaseManager() as db: + # Apply any pending database migrations + try: + migration_result = ensure_migrations_applied() + if not migration_result: + logger.warning("Database migrations failed but execution will continue") + except Exception as e: + logger.error(f"Database migration error: {str(e)}") + # Check dependencies before proceeding check_dependencies() diff --git a/ra_aid/database/__init__.py b/ra_aid/database/__init__.py index c4f9631..79d312a 100644 --- a/ra_aid/database/__init__.py +++ b/ra_aid/database/__init__.py @@ -2,7 +2,7 @@ Database package for ra_aid. This package provides database functionality for the ra_aid application, -including connection management, models, and utility functions. +including connection management, models, utility functions, and migrations. """ from ra_aid.database.connection import ( @@ -13,6 +13,13 @@ from ra_aid.database.connection import ( ) from ra_aid.database.models import BaseModel from ra_aid.database.utils import get_model_count, truncate_table, ensure_tables_created +from ra_aid.database.migrations import ( + init_migrations, + ensure_migrations_applied, + create_new_migration, + get_migration_status, + MigrationManager +) __all__ = [ 'init_db', @@ -23,4 +30,9 @@ __all__ = [ 'get_model_count', 'truncate_table', 'ensure_tables_created', + 'init_migrations', + 'ensure_migrations_applied', + 'create_new_migration', + 'get_migration_status', + 'MigrationManager', ] diff --git a/ra_aid/database/connection.py b/ra_aid/database/connection.py index 715ad43..d7e510a 100644 --- a/ra_aid/database/connection.py +++ b/ra_aid/database/connection.py @@ -332,7 +332,7 @@ def close_db() -> None: db.close() logger.info("Database connection closed successfully") else: - logger.warning("Database connection was already closed") + logger.debug("Database connection was already closed (normal during shutdown)") except peewee.DatabaseError as e: logger.error(f"Database Error: Failed to close connection: {str(e)}") except Exception as e: diff --git a/ra_aid/database/migrations.py b/ra_aid/database/migrations.py new file mode 100644 index 0000000..5353e0e --- /dev/null +++ b/ra_aid/database/migrations.py @@ -0,0 +1,284 @@ +""" +Database migrations for ra_aid. + +This module provides functionality for managing database schema migrations +using peewee-migrate. It includes tools for creating, checking, and applying +migrations automatically. +""" + +import os +import datetime +from pathlib import Path +from typing import List, Optional, Tuple, Dict, Any + +import peewee +from peewee_migrate import Router +from peewee_migrate.router import DEFAULT_MIGRATE_DIR + +from ra_aid.database.connection import get_db, DatabaseManager +from ra_aid.logging_config import get_logger + +logger = get_logger(__name__) + +# Constants +MIGRATIONS_DIRNAME = "migrations" +MIGRATIONS_TABLE = "migrationshistory" + + +class MigrationManager: + """ + Manages database migrations for the ra_aid application. + + This class provides methods to initialize the migrator, check for + pending migrations, apply migrations, and create new migrations. + """ + + def __init__(self, db_path: Optional[str] = None, migrations_dir: Optional[str] = None): + """ + Initialize the MigrationManager. + + Args: + db_path: Optional path to the database file. If None, uses the default. + migrations_dir: Optional path to the migrations directory. If None, uses default. + """ + self.db = get_db() + + # Determine database path + if db_path is None: + # Get current working directory + cwd = os.getcwd() + ra_aid_dir = os.path.join(cwd, ".ra-aid") + db_path = os.path.join(ra_aid_dir, "pk.db") + + self.db_path = db_path + + # Determine migrations directory + if migrations_dir is None: + # Use a directory within .ra-aid + ra_aid_dir = os.path.dirname(self.db_path) + migrations_dir = os.path.join(ra_aid_dir, MIGRATIONS_DIRNAME) + + self.migrations_dir = migrations_dir + + # Ensure migrations directory exists + self._ensure_migrations_dir() + + # Initialize router + self.router = self._init_router() + + def _ensure_migrations_dir(self) -> None: + """ + Ensure that the migrations directory exists. + + Creates the directory if it doesn't exist. + """ + try: + migrations_path = Path(self.migrations_dir) + if not migrations_path.exists(): + logger.debug(f"Creating migrations directory at: {self.migrations_dir}") + migrations_path.mkdir(parents=True, exist_ok=True) + + # Create __init__.py to make it a proper package + init_file = migrations_path / "__init__.py" + if not init_file.exists(): + init_file.touch() + + logger.debug(f"Using migrations directory: {self.migrations_dir}") + except Exception as e: + logger.error(f"Failed to create migrations directory: {str(e)}") + raise + + def _init_router(self) -> Router: + """ + Initialize the peewee-migrate Router. + + Returns: + Router: Configured peewee-migrate Router instance + """ + try: + router = Router(self.db, migrate_dir=self.migrations_dir, migrate_table=MIGRATIONS_TABLE) + logger.debug(f"Initialized migration router with table: {MIGRATIONS_TABLE}") + return router + except Exception as e: + logger.error(f"Failed to initialize migration router: {str(e)}") + raise + + def check_migrations(self) -> Tuple[List[str], List[str]]: + """ + Check for pending migrations. + + Returns: + Tuple[List[str], List[str]]: A tuple containing (applied_migrations, pending_migrations) + """ + try: + # Get all migrations + all_migrations = self.router.todo + + # Get applied migrations + applied = self.router.done + + # Calculate pending migrations + pending = [m for m in all_migrations if m not in applied] + + logger.debug(f"Found {len(applied)} applied migrations and {len(pending)} pending migrations") + return applied, pending + except Exception as e: + logger.error(f"Failed to check migrations: {str(e)}") + return [], [] + + def apply_migrations(self, fake: bool = False) -> bool: + """ + Apply all pending migrations. + + Args: + fake: If True, mark migrations as applied without running them + + Returns: + bool: True if migrations were applied successfully, False otherwise + """ + try: + # Get pending migrations + _, pending = self.check_migrations() + + if not pending: + logger.info("No pending migrations to apply") + return True + + logger.info(f"Applying {len(pending)} pending migrations...") + + # Apply migrations + for migration in pending: + try: + logger.info(f"Applying migration: {migration}") + self.router.run(migration, fake=fake) + logger.info(f"Successfully applied migration: {migration}") + except Exception as e: + logger.error(f"Failed to apply migration {migration}: {str(e)}") + return False + + logger.info(f"Successfully applied {len(pending)} migrations") + return True + except Exception as e: + logger.error(f"Failed to apply migrations: {str(e)}") + return False + + def create_migration(self, name: str, auto: bool = True) -> Optional[str]: + """ + Create a new migration. + + Args: + name: Name of the migration + auto: If True, automatically detect model changes + + Returns: + Optional[str]: The name of the created migration, or None if creation failed + """ + try: + # Sanitize migration name + safe_name = name.replace(' ', '_').lower() + timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') + migration_name = f"{timestamp}_{safe_name}" + + logger.info(f"Creating new migration: {migration_name}") + + # Create migration + self.router.create(migration_name, auto=auto) + + logger.info(f"Successfully created migration: {migration_name}") + return migration_name + except Exception as e: + logger.error(f"Failed to create migration: {str(e)}") + return None + + def get_migration_status(self) -> Dict[str, Any]: + """ + Get the current migration status. + + Returns: + Dict[str, Any]: A dictionary containing migration status information + """ + applied, pending = self.check_migrations() + + return { + "applied_count": len(applied), + "pending_count": len(pending), + "applied": applied, + "pending": pending, + "migrations_dir": self.migrations_dir, + "db_path": self.db_path, + } + + +def init_migrations(db_path: Optional[str] = None, migrations_dir: Optional[str] = None) -> MigrationManager: + """ + Initialize the migration manager. + + Args: + db_path: Optional path to the database file + migrations_dir: Optional path to the migrations directory + + Returns: + MigrationManager: Initialized migration manager + """ + return MigrationManager(db_path, migrations_dir) + + +def ensure_migrations_applied() -> bool: + """ + Check for and apply any pending migrations. + + This function should be called during application startup to ensure + the database schema is up to date. + + Returns: + bool: True if migrations were applied successfully or none were pending + """ + with DatabaseManager() as db: + try: + migration_manager = init_migrations() + return migration_manager.apply_migrations() + except Exception as e: + logger.error(f"Failed to apply migrations: {str(e)}") + return False + + +def create_new_migration(name: str, auto: bool = True) -> Optional[str]: + """ + Create a new migration with the given name. + + Args: + name: Name of the migration + auto: If True, automatically detect model changes + + Returns: + Optional[str]: The name of the created migration, or None if creation failed + """ + with DatabaseManager() as db: + try: + migration_manager = init_migrations() + return migration_manager.create_migration(name, auto) + except Exception as e: + logger.error(f"Failed to create migration: {str(e)}") + return None + + +def get_migration_status() -> Dict[str, Any]: + """ + Get the current migration status. + + Returns: + Dict[str, Any]: A dictionary containing migration status information + """ + with DatabaseManager() as db: + try: + migration_manager = init_migrations() + return migration_manager.get_migration_status() + except Exception as e: + logger.error(f"Failed to get migration status: {str(e)}") + return { + "error": str(e), + "applied_count": 0, + "pending_count": 0, + "applied": [], + "pending": [], + } diff --git a/tests/ra_aid/database/test_migrations.py b/tests/ra_aid/database/test_migrations.py new file mode 100644 index 0000000..e9dcd5c --- /dev/null +++ b/tests/ra_aid/database/test_migrations.py @@ -0,0 +1,547 @@ +""" +Tests for the database migrations module. +""" + +import os +import shutil +import tempfile +from pathlib import Path +from unittest.mock import patch, MagicMock, call, PropertyMock + +import pytest +import peewee +from peewee_migrate import Router + +from ra_aid.database.connection import DatabaseManager, db_var +from ra_aid.database.migrations import ( + MigrationManager, + init_migrations, + ensure_migrations_applied, + create_new_migration, + get_migration_status, + MIGRATIONS_DIRNAME, + MIGRATIONS_TABLE +) + + +@pytest.fixture +def cleanup_db(): + """Reset the database contextvar and connection state after each test.""" + # Reset before the test + db = db_var.get() + if db is not None: + try: + if not db.is_closed(): + db.close() + except Exception: + # Ignore errors when closing the database + pass + db_var.set(None) + + # Run the test + yield + + # Reset after the test + db = db_var.get() + if db is not None: + try: + if not db.is_closed(): + db.close() + except Exception: + # Ignore errors when closing the database + pass + db_var.set(None) + + +@pytest.fixture +def mock_logger(): + """Mock the logger to test for output messages.""" + with patch('ra_aid.database.migrations.logger') as mock: + yield mock + + +@pytest.fixture +def temp_dir(): + """Create a temporary directory for test files.""" + temp_dir = tempfile.mkdtemp() + yield temp_dir + # Clean up + shutil.rmtree(temp_dir) + + +@pytest.fixture +def temp_migrations_dir(temp_dir): + """Create a temporary migrations directory.""" + migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME) + os.makedirs(migrations_dir, exist_ok=True) + # Create __init__.py to make it a proper package + with open(os.path.join(migrations_dir, "__init__.py"), "w") as f: + pass + yield migrations_dir + + +@pytest.fixture +def mock_router(): + """Mock the peewee_migrate Router class.""" + with patch('ra_aid.database.migrations.Router') as mock: + # Configure the mock router + mock_instance = MagicMock() + mock.return_value = mock_instance + + # Set up router properties + mock_instance.todo = ["001_initial", "002_add_users"] + mock_instance.done = ["001_initial"] + + yield mock_instance + + +class TestMigrationManager: + """Tests for the MigrationManager class.""" + + def test_init(self, cleanup_db, temp_dir, mock_logger): + """Test MigrationManager initialization.""" + # Set up test paths + db_path = os.path.join(temp_dir, "test.db") + migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME) + + # Initialize manager + manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir) + + # Verify initialization + assert manager.db_path == db_path + assert manager.migrations_dir == migrations_dir + assert os.path.exists(migrations_dir) + assert os.path.exists(os.path.join(migrations_dir, "__init__.py")) + + # Verify router initialization was logged + mock_logger.debug.assert_any_call(f"Using migrations directory: {migrations_dir}") + mock_logger.debug.assert_any_call(f"Initialized migration router with table: {MIGRATIONS_TABLE}") + + def test_ensure_migrations_dir(self, cleanup_db, temp_dir, mock_logger): + """Test _ensure_migrations_dir creates directory if it doesn't exist.""" + # Set up test paths + db_path = os.path.join(temp_dir, "test.db") + migrations_dir = os.path.join(temp_dir, "nonexistent_dir", MIGRATIONS_DIRNAME) + + # Initialize manager + manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir) + + # Verify directory was created + assert os.path.exists(migrations_dir) + assert os.path.exists(os.path.join(migrations_dir, "__init__.py")) + + # Verify creation was logged + mock_logger.debug.assert_any_call(f"Creating migrations directory at: {migrations_dir}") + + def test_ensure_migrations_dir_error(self, cleanup_db, mock_logger): + """Test _ensure_migrations_dir handles errors.""" + # Mock os.makedirs to raise an exception + with patch('pathlib.Path.mkdir', side_effect=PermissionError("Permission denied")): + # Set up test paths - use a path that would require elevated permissions + db_path = "/root/test.db" + migrations_dir = "/root/migrations" + + # Initialize manager should raise an exception + with pytest.raises(Exception): + manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir) + + # Verify error was logged + mock_logger.error.assert_called_with( + f"Failed to create migrations directory: [Errno 13] Permission denied: '/root/migrations'" + ) + + def test_init_router(self, cleanup_db, temp_dir, mock_router): + """Test _init_router initializes the Router correctly.""" + # Set up test paths + db_path = os.path.join(temp_dir, "test.db") + migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME) + + # Create the migrations directory + os.makedirs(migrations_dir, exist_ok=True) + + # Initialize manager with mocked Router + with patch('ra_aid.database.migrations.Router', return_value=mock_router): + manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir) + + # Verify router was initialized + assert manager.router == mock_router + + def test_check_migrations(self, cleanup_db, temp_dir, mock_router, mock_logger): + """Test check_migrations returns correct migration lists.""" + # Set up test paths + db_path = os.path.join(temp_dir, "test.db") + migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME) + + # Initialize manager with mocked Router + with patch('ra_aid.database.migrations.Router', return_value=mock_router): + manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir) + + # Call check_migrations + applied, pending = manager.check_migrations() + + # Verify results + assert applied == ["001_initial"] + assert pending == ["002_add_users"] + + # Verify logging + mock_logger.debug.assert_called_with( + "Found 1 applied migrations and 1 pending migrations" + ) + + def test_check_migrations_error(self, cleanup_db, temp_dir, mock_logger): + """Test check_migrations handles errors.""" + # Set up test paths + db_path = os.path.join(temp_dir, "test.db") + migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME) + + # Create a mock router with a property that raises an exception + mock_router = MagicMock() + # Configure the todo property to raise an exception when accessed + type(mock_router).todo = PropertyMock(side_effect=Exception("Test error")) + mock_router.done = [] + + # Initialize manager with the mocked Router + with patch('ra_aid.database.migrations.Router', return_value=mock_router): + manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir) + + # Directly call check_migrations on the manager with the mocked router + applied, pending = manager.check_migrations() + + # Verify empty results are returned on error + assert applied == [] + assert pending == [] + + # Verify error was logged + mock_logger.error.assert_called_with("Failed to check migrations: Test error") + + def test_apply_migrations(self, cleanup_db, temp_dir, mock_router, mock_logger): + """Test apply_migrations applies pending migrations.""" + # Set up test paths + db_path = os.path.join(temp_dir, "test.db") + migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME) + + # Initialize manager with mocked Router + with patch('ra_aid.database.migrations.Router', return_value=mock_router): + manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir) + + # Call apply_migrations + result = manager.apply_migrations() + + # Verify result + assert result is True + + # Verify migrations were applied + mock_router.run.assert_called_once_with("002_add_users", fake=False) + + # Verify logging + mock_logger.info.assert_any_call("Applying 1 pending migrations...") + mock_logger.info.assert_any_call("Applying migration: 002_add_users") + mock_logger.info.assert_any_call("Successfully applied migration: 002_add_users") + mock_logger.info.assert_any_call("Successfully applied 1 migrations") + + def test_apply_migrations_no_pending(self, cleanup_db, temp_dir, mock_logger): + """Test apply_migrations when no migrations are pending.""" + # Set up test paths + db_path = os.path.join(temp_dir, "test.db") + migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME) + + # Create a mock router with no pending migrations + mock_router = MagicMock() + mock_router.todo = ["001_initial"] + mock_router.done = ["001_initial"] + + # Initialize manager with mocked Router + with patch('ra_aid.database.migrations.Router', return_value=mock_router): + manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir) + + # Call apply_migrations + result = manager.apply_migrations() + + # Verify result + assert result is True + + # Verify no migrations were applied + mock_router.run.assert_not_called() + + # Verify logging + mock_logger.info.assert_called_with("No pending migrations to apply") + + def test_apply_migrations_error(self, cleanup_db, temp_dir, mock_logger): + """Test apply_migrations handles errors during migration.""" + # Set up test paths + db_path = os.path.join(temp_dir, "test.db") + migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME) + + # Create a mock router that raises an exception during run + mock_router = MagicMock() + mock_router.todo = ["001_initial", "002_add_users"] + mock_router.done = ["001_initial"] + mock_router.run.side_effect = Exception("Migration error") + + # Initialize manager with mocked Router + with patch('ra_aid.database.migrations.Router', return_value=mock_router): + manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir) + + # Call apply_migrations + result = manager.apply_migrations() + + # Verify result + assert result is False + + # Verify error was logged + mock_logger.error.assert_called_with( + "Failed to apply migration 002_add_users: Migration error" + ) + + def test_create_migration(self, cleanup_db, temp_dir, mock_router, mock_logger): + """Test create_migration creates a new migration.""" + # Set up test paths + db_path = os.path.join(temp_dir, "test.db") + migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME) + + # Initialize manager with mocked Router + with patch('ra_aid.database.migrations.Router', return_value=mock_router): + manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir) + + # Call create_migration + result = manager.create_migration("add_users", auto=True) + + # Verify result contains timestamp and name + assert result is not None + assert "add_users" in result + + # Verify migration was created + mock_router.create.assert_called_once() + + # Verify logging + mock_logger.info.assert_any_call(f"Creating new migration: {result}") + mock_logger.info.assert_any_call(f"Successfully created migration: {result}") + + def test_create_migration_error(self, cleanup_db, temp_dir, mock_logger): + """Test create_migration handles errors.""" + # Set up test paths + db_path = os.path.join(temp_dir, "test.db") + migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME) + + # Create a mock router that raises an exception during create + mock_router = MagicMock() + mock_router.create.side_effect = Exception("Creation error") + + # Initialize manager with mocked Router + with patch('ra_aid.database.migrations.Router', return_value=mock_router): + manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir) + + # Call create_migration + result = manager.create_migration("add_users", auto=True) + + # Verify result is None on error + assert result is None + + # Verify error was logged + mock_logger.error.assert_called_with("Failed to create migration: Creation error") + + def test_get_migration_status(self, cleanup_db, temp_dir, mock_router): + """Test get_migration_status returns correct status information.""" + # Set up test paths + db_path = os.path.join(temp_dir, "test.db") + migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME) + + # Initialize manager with mocked Router + with patch('ra_aid.database.migrations.Router', return_value=mock_router): + manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir) + + # Call get_migration_status + status = manager.get_migration_status() + + # Verify status information + assert status["applied_count"] == 1 + assert status["pending_count"] == 1 + assert status["applied"] == ["001_initial"] + assert status["pending"] == ["002_add_users"] + assert status["migrations_dir"] == migrations_dir + assert status["db_path"] == db_path + + +class TestMigrationFunctions: + """Tests for the migration utility functions.""" + + def test_init_migrations(self, cleanup_db, temp_dir): + """Test init_migrations returns a MigrationManager instance.""" + # Set up test paths + db_path = os.path.join(temp_dir, "test.db") + migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME) + + # Call init_migrations + with patch('ra_aid.database.migrations.MigrationManager') as mock_manager: + mock_manager.return_value = MagicMock() + + manager = init_migrations(db_path=db_path, migrations_dir=migrations_dir) + + # Verify MigrationManager was initialized with correct parameters + mock_manager.assert_called_once_with(db_path, migrations_dir) + assert manager == mock_manager.return_value + + def test_ensure_migrations_applied(self, cleanup_db, mock_logger): + """Test ensure_migrations_applied applies pending migrations.""" + # Mock MigrationManager + mock_manager = MagicMock() + mock_manager.apply_migrations.return_value = True + + # Call ensure_migrations_applied + with patch('ra_aid.database.migrations.init_migrations', return_value=mock_manager): + result = ensure_migrations_applied() + + # Verify result + assert result is True + + # Verify migrations were applied + mock_manager.apply_migrations.assert_called_once() + + def test_ensure_migrations_applied_error(self, cleanup_db, mock_logger): + """Test ensure_migrations_applied handles errors.""" + # Call ensure_migrations_applied with an exception + with patch('ra_aid.database.migrations.init_migrations', + side_effect=Exception("Test error")): + result = ensure_migrations_applied() + + # Verify result is False on error + assert result is False + + # Verify error was logged + mock_logger.error.assert_called_with("Failed to apply migrations: Test error") + + def test_create_new_migration(self, cleanup_db, mock_logger): + """Test create_new_migration creates a new migration.""" + # Mock MigrationManager + mock_manager = MagicMock() + mock_manager.create_migration.return_value = "20250226_123456_test_migration" + + # Call create_new_migration + with patch('ra_aid.database.migrations.init_migrations', return_value=mock_manager): + result = create_new_migration("test_migration", auto=True) + + # Verify result + assert result == "20250226_123456_test_migration" + + # Verify migration was created + mock_manager.create_migration.assert_called_once_with("test_migration", True) + + def test_create_new_migration_error(self, cleanup_db, mock_logger): + """Test create_new_migration handles errors.""" + # Call create_new_migration with an exception + with patch('ra_aid.database.migrations.init_migrations', + side_effect=Exception("Test error")): + result = create_new_migration("test_migration", auto=True) + + # Verify result is None on error + assert result is None + + # Verify error was logged + mock_logger.error.assert_called_with("Failed to create migration: Test error") + + def test_get_migration_status(self, cleanup_db, mock_logger): + """Test get_migration_status returns correct status information.""" + # Mock MigrationManager + mock_manager = MagicMock() + mock_manager.get_migration_status.return_value = { + "applied_count": 2, + "pending_count": 1, + "applied": ["001_initial", "002_add_users"], + "pending": ["003_add_profiles"], + "migrations_dir": "/test/migrations", + "db_path": "/test/db.sqlite" + } + + # Call get_migration_status + with patch('ra_aid.database.migrations.init_migrations', return_value=mock_manager): + status = get_migration_status() + + # Verify status information + assert status["applied_count"] == 2 + assert status["pending_count"] == 1 + assert status["applied"] == ["001_initial", "002_add_users"] + assert status["pending"] == ["003_add_profiles"] + assert status["migrations_dir"] == "/test/migrations" + assert status["db_path"] == "/test/db.sqlite" + + # Verify migration status was retrieved + mock_manager.get_migration_status.assert_called_once() + + def test_get_migration_status_error(self, cleanup_db, mock_logger): + """Test get_migration_status handles errors.""" + # Call get_migration_status with an exception + with patch('ra_aid.database.migrations.init_migrations', + side_effect=Exception("Test error")): + status = get_migration_status() + + # Verify default status on error + assert status["error"] == "Test error" + assert status["applied_count"] == 0 + assert status["pending_count"] == 0 + assert status["applied"] == [] + assert status["pending"] == [] + + # Verify error was logged + mock_logger.error.assert_called_with("Failed to get migration status: Test error") + + +class TestIntegration: + """Integration tests for the migrations module.""" + + def test_in_memory_migrations(self, cleanup_db): + """Test migrations with in-memory database.""" + # Initialize in-memory database + with DatabaseManager(in_memory=True) as db: + # Create a temporary migrations directory + with tempfile.TemporaryDirectory() as temp_dir: + migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME) + os.makedirs(migrations_dir, exist_ok=True) + + # Create __init__.py to make it a proper package + with open(os.path.join(migrations_dir, "__init__.py"), "w") as f: + pass + + # Initialize migration manager + manager = MigrationManager(db_path=":memory:", migrations_dir=migrations_dir) + + # Create a test migration + migration_name = manager.create_migration("test_migration", auto=False) + + # Write a simple migration file + migration_path = os.path.join(migrations_dir, f"{migration_name}.py") + with open(migration_path, "w") as f: + f.write(""" +def migrate(migrator, database, fake=False, **kwargs): + migrator.create_table('test_table', ( + ('id', 'INTEGER', {'primary_key': True}), + ('name', 'STRING', {'null': False}), + )) + +def rollback(migrator, database, fake=False, **kwargs): + migrator.drop_table('test_table') +""") + + # Check migrations + applied, pending = manager.check_migrations() + assert len(applied) == 0 + assert len(pending) == 1 + assert migration_name in pending[0] # Instead of exact equality, check if name is contained + + # Apply migrations + result = manager.apply_migrations() + assert result is True + + # Check migrations again + applied, pending = manager.check_migrations() + assert len(applied) == 1 + assert len(pending) == 0 + assert migration_name in applied[0] # Instead of exact equality, check if name is contained + + # Verify migration status + status = manager.get_migration_status() + assert status["applied_count"] == 1 + assert status["pending_count"] == 0 + # Use substring check for applied migrations + assert len(status["applied"]) == 1 + assert migration_name in status["applied"][0] + assert status["pending"] == []