db migrations
This commit is contained in:
parent
e6d98737a8
commit
724dbd4fda
|
|
@ -0,0 +1,57 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Create initial database migration script.
|
||||
|
||||
This script creates a baseline migration representing the current database schema.
|
||||
It serves as the foundation for future schema changes.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# Add the project root to the Python path
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from ra_aid.database import DatabaseManager, create_new_migration
|
||||
from ra_aid.logging_config import get_logger, setup_logging
|
||||
|
||||
# Set up logging
|
||||
setup_logging(verbose=True)
|
||||
logger = get_logger(__name__)
|
||||
|
||||
def create_initial_migration():
|
||||
"""
|
||||
Create the initial migration for the current database schema.
|
||||
|
||||
Returns:
|
||||
bool: True if migration was created successfully, False otherwise
|
||||
"""
|
||||
try:
|
||||
with DatabaseManager() as db:
|
||||
# Create a descriptive name for the initial migration
|
||||
migration_name = "initial_schema"
|
||||
|
||||
# Create the migration
|
||||
logger.info(f"Creating initial migration '{migration_name}'...")
|
||||
result = create_new_migration(migration_name, auto=True)
|
||||
|
||||
if result:
|
||||
logger.info(f"Successfully created initial migration: {result}")
|
||||
print(f"✅ Initial migration created successfully: {result}")
|
||||
return True
|
||||
else:
|
||||
logger.error("Failed to create initial migration")
|
||||
print("❌ Failed to create initial migration")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating initial migration: {str(e)}")
|
||||
print(f"❌ Error creating initial migration: {str(e)}")
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Creating initial database migration...")
|
||||
success = create_initial_migration()
|
||||
|
||||
# Exit with appropriate code
|
||||
sys.exit(0 if success else 1)
|
||||
|
|
@ -45,6 +45,7 @@ dependencies = [
|
|||
"python-Levenshtein>=0.26.1",
|
||||
"python-magic>=0.4.27",
|
||||
"peewee>=3.17.9",
|
||||
"peewee-migrate>=1.13.0",
|
||||
"platformdirs>=3.17.9",
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from ra_aid.agent_utils import (
|
|||
run_planning_agent,
|
||||
run_research_agent,
|
||||
)
|
||||
from ra_aid.database import init_db, close_db, DatabaseManager
|
||||
from ra_aid.database import init_db, close_db, DatabaseManager, ensure_migrations_applied
|
||||
from ra_aid.config import (
|
||||
DEFAULT_MAX_TEST_CMD_RETRIES,
|
||||
DEFAULT_RECURSION_LIMIT,
|
||||
|
|
@ -335,6 +335,14 @@ def main():
|
|||
|
||||
try:
|
||||
with DatabaseManager() as db:
|
||||
# Apply any pending database migrations
|
||||
try:
|
||||
migration_result = ensure_migrations_applied()
|
||||
if not migration_result:
|
||||
logger.warning("Database migrations failed but execution will continue")
|
||||
except Exception as e:
|
||||
logger.error(f"Database migration error: {str(e)}")
|
||||
|
||||
# Check dependencies before proceeding
|
||||
check_dependencies()
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
Database package for ra_aid.
|
||||
|
||||
This package provides database functionality for the ra_aid application,
|
||||
including connection management, models, and utility functions.
|
||||
including connection management, models, utility functions, and migrations.
|
||||
"""
|
||||
|
||||
from ra_aid.database.connection import (
|
||||
|
|
@ -13,6 +13,13 @@ from ra_aid.database.connection import (
|
|||
)
|
||||
from ra_aid.database.models import BaseModel
|
||||
from ra_aid.database.utils import get_model_count, truncate_table, ensure_tables_created
|
||||
from ra_aid.database.migrations import (
|
||||
init_migrations,
|
||||
ensure_migrations_applied,
|
||||
create_new_migration,
|
||||
get_migration_status,
|
||||
MigrationManager
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'init_db',
|
||||
|
|
@ -23,4 +30,9 @@ __all__ = [
|
|||
'get_model_count',
|
||||
'truncate_table',
|
||||
'ensure_tables_created',
|
||||
'init_migrations',
|
||||
'ensure_migrations_applied',
|
||||
'create_new_migration',
|
||||
'get_migration_status',
|
||||
'MigrationManager',
|
||||
]
|
||||
|
|
|
|||
|
|
@ -332,7 +332,7 @@ def close_db() -> None:
|
|||
db.close()
|
||||
logger.info("Database connection closed successfully")
|
||||
else:
|
||||
logger.warning("Database connection was already closed")
|
||||
logger.debug("Database connection was already closed (normal during shutdown)")
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Database Error: Failed to close connection: {str(e)}")
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,284 @@
|
|||
"""
|
||||
Database migrations for ra_aid.
|
||||
|
||||
This module provides functionality for managing database schema migrations
|
||||
using peewee-migrate. It includes tools for creating, checking, and applying
|
||||
migrations automatically.
|
||||
"""
|
||||
|
||||
import os
|
||||
import datetime
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple, Dict, Any
|
||||
|
||||
import peewee
|
||||
from peewee_migrate import Router
|
||||
from peewee_migrate.router import DEFAULT_MIGRATE_DIR
|
||||
|
||||
from ra_aid.database.connection import get_db, DatabaseManager
|
||||
from ra_aid.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Constants
|
||||
MIGRATIONS_DIRNAME = "migrations"
|
||||
MIGRATIONS_TABLE = "migrationshistory"
|
||||
|
||||
|
||||
class MigrationManager:
|
||||
"""
|
||||
Manages database migrations for the ra_aid application.
|
||||
|
||||
This class provides methods to initialize the migrator, check for
|
||||
pending migrations, apply migrations, and create new migrations.
|
||||
"""
|
||||
|
||||
def __init__(self, db_path: Optional[str] = None, migrations_dir: Optional[str] = None):
|
||||
"""
|
||||
Initialize the MigrationManager.
|
||||
|
||||
Args:
|
||||
db_path: Optional path to the database file. If None, uses the default.
|
||||
migrations_dir: Optional path to the migrations directory. If None, uses default.
|
||||
"""
|
||||
self.db = get_db()
|
||||
|
||||
# Determine database path
|
||||
if db_path is None:
|
||||
# Get current working directory
|
||||
cwd = os.getcwd()
|
||||
ra_aid_dir = os.path.join(cwd, ".ra-aid")
|
||||
db_path = os.path.join(ra_aid_dir, "pk.db")
|
||||
|
||||
self.db_path = db_path
|
||||
|
||||
# Determine migrations directory
|
||||
if migrations_dir is None:
|
||||
# Use a directory within .ra-aid
|
||||
ra_aid_dir = os.path.dirname(self.db_path)
|
||||
migrations_dir = os.path.join(ra_aid_dir, MIGRATIONS_DIRNAME)
|
||||
|
||||
self.migrations_dir = migrations_dir
|
||||
|
||||
# Ensure migrations directory exists
|
||||
self._ensure_migrations_dir()
|
||||
|
||||
# Initialize router
|
||||
self.router = self._init_router()
|
||||
|
||||
def _ensure_migrations_dir(self) -> None:
|
||||
"""
|
||||
Ensure that the migrations directory exists.
|
||||
|
||||
Creates the directory if it doesn't exist.
|
||||
"""
|
||||
try:
|
||||
migrations_path = Path(self.migrations_dir)
|
||||
if not migrations_path.exists():
|
||||
logger.debug(f"Creating migrations directory at: {self.migrations_dir}")
|
||||
migrations_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create __init__.py to make it a proper package
|
||||
init_file = migrations_path / "__init__.py"
|
||||
if not init_file.exists():
|
||||
init_file.touch()
|
||||
|
||||
logger.debug(f"Using migrations directory: {self.migrations_dir}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create migrations directory: {str(e)}")
|
||||
raise
|
||||
|
||||
def _init_router(self) -> Router:
|
||||
"""
|
||||
Initialize the peewee-migrate Router.
|
||||
|
||||
Returns:
|
||||
Router: Configured peewee-migrate Router instance
|
||||
"""
|
||||
try:
|
||||
router = Router(self.db, migrate_dir=self.migrations_dir, migrate_table=MIGRATIONS_TABLE)
|
||||
logger.debug(f"Initialized migration router with table: {MIGRATIONS_TABLE}")
|
||||
return router
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize migration router: {str(e)}")
|
||||
raise
|
||||
|
||||
def check_migrations(self) -> Tuple[List[str], List[str]]:
|
||||
"""
|
||||
Check for pending migrations.
|
||||
|
||||
Returns:
|
||||
Tuple[List[str], List[str]]: A tuple containing (applied_migrations, pending_migrations)
|
||||
"""
|
||||
try:
|
||||
# Get all migrations
|
||||
all_migrations = self.router.todo
|
||||
|
||||
# Get applied migrations
|
||||
applied = self.router.done
|
||||
|
||||
# Calculate pending migrations
|
||||
pending = [m for m in all_migrations if m not in applied]
|
||||
|
||||
logger.debug(f"Found {len(applied)} applied migrations and {len(pending)} pending migrations")
|
||||
return applied, pending
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check migrations: {str(e)}")
|
||||
return [], []
|
||||
|
||||
def apply_migrations(self, fake: bool = False) -> bool:
|
||||
"""
|
||||
Apply all pending migrations.
|
||||
|
||||
Args:
|
||||
fake: If True, mark migrations as applied without running them
|
||||
|
||||
Returns:
|
||||
bool: True if migrations were applied successfully, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Get pending migrations
|
||||
_, pending = self.check_migrations()
|
||||
|
||||
if not pending:
|
||||
logger.info("No pending migrations to apply")
|
||||
return True
|
||||
|
||||
logger.info(f"Applying {len(pending)} pending migrations...")
|
||||
|
||||
# Apply migrations
|
||||
for migration in pending:
|
||||
try:
|
||||
logger.info(f"Applying migration: {migration}")
|
||||
self.router.run(migration, fake=fake)
|
||||
logger.info(f"Successfully applied migration: {migration}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to apply migration {migration}: {str(e)}")
|
||||
return False
|
||||
|
||||
logger.info(f"Successfully applied {len(pending)} migrations")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to apply migrations: {str(e)}")
|
||||
return False
|
||||
|
||||
def create_migration(self, name: str, auto: bool = True) -> Optional[str]:
|
||||
"""
|
||||
Create a new migration.
|
||||
|
||||
Args:
|
||||
name: Name of the migration
|
||||
auto: If True, automatically detect model changes
|
||||
|
||||
Returns:
|
||||
Optional[str]: The name of the created migration, or None if creation failed
|
||||
"""
|
||||
try:
|
||||
# Sanitize migration name
|
||||
safe_name = name.replace(' ', '_').lower()
|
||||
timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
migration_name = f"{timestamp}_{safe_name}"
|
||||
|
||||
logger.info(f"Creating new migration: {migration_name}")
|
||||
|
||||
# Create migration
|
||||
self.router.create(migration_name, auto=auto)
|
||||
|
||||
logger.info(f"Successfully created migration: {migration_name}")
|
||||
return migration_name
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create migration: {str(e)}")
|
||||
return None
|
||||
|
||||
def get_migration_status(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get the current migration status.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: A dictionary containing migration status information
|
||||
"""
|
||||
applied, pending = self.check_migrations()
|
||||
|
||||
return {
|
||||
"applied_count": len(applied),
|
||||
"pending_count": len(pending),
|
||||
"applied": applied,
|
||||
"pending": pending,
|
||||
"migrations_dir": self.migrations_dir,
|
||||
"db_path": self.db_path,
|
||||
}
|
||||
|
||||
|
||||
def init_migrations(db_path: Optional[str] = None, migrations_dir: Optional[str] = None) -> MigrationManager:
|
||||
"""
|
||||
Initialize the migration manager.
|
||||
|
||||
Args:
|
||||
db_path: Optional path to the database file
|
||||
migrations_dir: Optional path to the migrations directory
|
||||
|
||||
Returns:
|
||||
MigrationManager: Initialized migration manager
|
||||
"""
|
||||
return MigrationManager(db_path, migrations_dir)
|
||||
|
||||
|
||||
def ensure_migrations_applied() -> bool:
|
||||
"""
|
||||
Check for and apply any pending migrations.
|
||||
|
||||
This function should be called during application startup to ensure
|
||||
the database schema is up to date.
|
||||
|
||||
Returns:
|
||||
bool: True if migrations were applied successfully or none were pending
|
||||
"""
|
||||
with DatabaseManager() as db:
|
||||
try:
|
||||
migration_manager = init_migrations()
|
||||
return migration_manager.apply_migrations()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to apply migrations: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
def create_new_migration(name: str, auto: bool = True) -> Optional[str]:
|
||||
"""
|
||||
Create a new migration with the given name.
|
||||
|
||||
Args:
|
||||
name: Name of the migration
|
||||
auto: If True, automatically detect model changes
|
||||
|
||||
Returns:
|
||||
Optional[str]: The name of the created migration, or None if creation failed
|
||||
"""
|
||||
with DatabaseManager() as db:
|
||||
try:
|
||||
migration_manager = init_migrations()
|
||||
return migration_manager.create_migration(name, auto)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create migration: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
def get_migration_status() -> Dict[str, Any]:
|
||||
"""
|
||||
Get the current migration status.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: A dictionary containing migration status information
|
||||
"""
|
||||
with DatabaseManager() as db:
|
||||
try:
|
||||
migration_manager = init_migrations()
|
||||
return migration_manager.get_migration_status()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get migration status: {str(e)}")
|
||||
return {
|
||||
"error": str(e),
|
||||
"applied_count": 0,
|
||||
"pending_count": 0,
|
||||
"applied": [],
|
||||
"pending": [],
|
||||
}
|
||||
|
|
@ -0,0 +1,547 @@
|
|||
"""
|
||||
Tests for the database migrations module.
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock, call, PropertyMock
|
||||
|
||||
import pytest
|
||||
import peewee
|
||||
from peewee_migrate import Router
|
||||
|
||||
from ra_aid.database.connection import DatabaseManager, db_var
|
||||
from ra_aid.database.migrations import (
|
||||
MigrationManager,
|
||||
init_migrations,
|
||||
ensure_migrations_applied,
|
||||
create_new_migration,
|
||||
get_migration_status,
|
||||
MIGRATIONS_DIRNAME,
|
||||
MIGRATIONS_TABLE
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cleanup_db():
|
||||
"""Reset the database contextvar and connection state after each test."""
|
||||
# Reset before the test
|
||||
db = db_var.get()
|
||||
if db is not None:
|
||||
try:
|
||||
if not db.is_closed():
|
||||
db.close()
|
||||
except Exception:
|
||||
# Ignore errors when closing the database
|
||||
pass
|
||||
db_var.set(None)
|
||||
|
||||
# Run the test
|
||||
yield
|
||||
|
||||
# Reset after the test
|
||||
db = db_var.get()
|
||||
if db is not None:
|
||||
try:
|
||||
if not db.is_closed():
|
||||
db.close()
|
||||
except Exception:
|
||||
# Ignore errors when closing the database
|
||||
pass
|
||||
db_var.set(None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_logger():
|
||||
"""Mock the logger to test for output messages."""
|
||||
with patch('ra_aid.database.migrations.logger') as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_dir():
|
||||
"""Create a temporary directory for test files."""
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
yield temp_dir
|
||||
# Clean up
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_migrations_dir(temp_dir):
|
||||
"""Create a temporary migrations directory."""
|
||||
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
|
||||
os.makedirs(migrations_dir, exist_ok=True)
|
||||
# Create __init__.py to make it a proper package
|
||||
with open(os.path.join(migrations_dir, "__init__.py"), "w") as f:
|
||||
pass
|
||||
yield migrations_dir
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_router():
|
||||
"""Mock the peewee_migrate Router class."""
|
||||
with patch('ra_aid.database.migrations.Router') as mock:
|
||||
# Configure the mock router
|
||||
mock_instance = MagicMock()
|
||||
mock.return_value = mock_instance
|
||||
|
||||
# Set up router properties
|
||||
mock_instance.todo = ["001_initial", "002_add_users"]
|
||||
mock_instance.done = ["001_initial"]
|
||||
|
||||
yield mock_instance
|
||||
|
||||
|
||||
class TestMigrationManager:
|
||||
"""Tests for the MigrationManager class."""
|
||||
|
||||
def test_init(self, cleanup_db, temp_dir, mock_logger):
|
||||
"""Test MigrationManager initialization."""
|
||||
# Set up test paths
|
||||
db_path = os.path.join(temp_dir, "test.db")
|
||||
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
|
||||
|
||||
# Initialize manager
|
||||
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
|
||||
|
||||
# Verify initialization
|
||||
assert manager.db_path == db_path
|
||||
assert manager.migrations_dir == migrations_dir
|
||||
assert os.path.exists(migrations_dir)
|
||||
assert os.path.exists(os.path.join(migrations_dir, "__init__.py"))
|
||||
|
||||
# Verify router initialization was logged
|
||||
mock_logger.debug.assert_any_call(f"Using migrations directory: {migrations_dir}")
|
||||
mock_logger.debug.assert_any_call(f"Initialized migration router with table: {MIGRATIONS_TABLE}")
|
||||
|
||||
def test_ensure_migrations_dir(self, cleanup_db, temp_dir, mock_logger):
|
||||
"""Test _ensure_migrations_dir creates directory if it doesn't exist."""
|
||||
# Set up test paths
|
||||
db_path = os.path.join(temp_dir, "test.db")
|
||||
migrations_dir = os.path.join(temp_dir, "nonexistent_dir", MIGRATIONS_DIRNAME)
|
||||
|
||||
# Initialize manager
|
||||
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
|
||||
|
||||
# Verify directory was created
|
||||
assert os.path.exists(migrations_dir)
|
||||
assert os.path.exists(os.path.join(migrations_dir, "__init__.py"))
|
||||
|
||||
# Verify creation was logged
|
||||
mock_logger.debug.assert_any_call(f"Creating migrations directory at: {migrations_dir}")
|
||||
|
||||
def test_ensure_migrations_dir_error(self, cleanup_db, mock_logger):
|
||||
"""Test _ensure_migrations_dir handles errors."""
|
||||
# Mock os.makedirs to raise an exception
|
||||
with patch('pathlib.Path.mkdir', side_effect=PermissionError("Permission denied")):
|
||||
# Set up test paths - use a path that would require elevated permissions
|
||||
db_path = "/root/test.db"
|
||||
migrations_dir = "/root/migrations"
|
||||
|
||||
# Initialize manager should raise an exception
|
||||
with pytest.raises(Exception):
|
||||
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
|
||||
|
||||
# Verify error was logged
|
||||
mock_logger.error.assert_called_with(
|
||||
f"Failed to create migrations directory: [Errno 13] Permission denied: '/root/migrations'"
|
||||
)
|
||||
|
||||
def test_init_router(self, cleanup_db, temp_dir, mock_router):
|
||||
"""Test _init_router initializes the Router correctly."""
|
||||
# Set up test paths
|
||||
db_path = os.path.join(temp_dir, "test.db")
|
||||
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
|
||||
|
||||
# Create the migrations directory
|
||||
os.makedirs(migrations_dir, exist_ok=True)
|
||||
|
||||
# Initialize manager with mocked Router
|
||||
with patch('ra_aid.database.migrations.Router', return_value=mock_router):
|
||||
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
|
||||
|
||||
# Verify router was initialized
|
||||
assert manager.router == mock_router
|
||||
|
||||
def test_check_migrations(self, cleanup_db, temp_dir, mock_router, mock_logger):
|
||||
"""Test check_migrations returns correct migration lists."""
|
||||
# Set up test paths
|
||||
db_path = os.path.join(temp_dir, "test.db")
|
||||
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
|
||||
|
||||
# Initialize manager with mocked Router
|
||||
with patch('ra_aid.database.migrations.Router', return_value=mock_router):
|
||||
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
|
||||
|
||||
# Call check_migrations
|
||||
applied, pending = manager.check_migrations()
|
||||
|
||||
# Verify results
|
||||
assert applied == ["001_initial"]
|
||||
assert pending == ["002_add_users"]
|
||||
|
||||
# Verify logging
|
||||
mock_logger.debug.assert_called_with(
|
||||
"Found 1 applied migrations and 1 pending migrations"
|
||||
)
|
||||
|
||||
def test_check_migrations_error(self, cleanup_db, temp_dir, mock_logger):
|
||||
"""Test check_migrations handles errors."""
|
||||
# Set up test paths
|
||||
db_path = os.path.join(temp_dir, "test.db")
|
||||
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
|
||||
|
||||
# Create a mock router with a property that raises an exception
|
||||
mock_router = MagicMock()
|
||||
# Configure the todo property to raise an exception when accessed
|
||||
type(mock_router).todo = PropertyMock(side_effect=Exception("Test error"))
|
||||
mock_router.done = []
|
||||
|
||||
# Initialize manager with the mocked Router
|
||||
with patch('ra_aid.database.migrations.Router', return_value=mock_router):
|
||||
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
|
||||
|
||||
# Directly call check_migrations on the manager with the mocked router
|
||||
applied, pending = manager.check_migrations()
|
||||
|
||||
# Verify empty results are returned on error
|
||||
assert applied == []
|
||||
assert pending == []
|
||||
|
||||
# Verify error was logged
|
||||
mock_logger.error.assert_called_with("Failed to check migrations: Test error")
|
||||
|
||||
def test_apply_migrations(self, cleanup_db, temp_dir, mock_router, mock_logger):
|
||||
"""Test apply_migrations applies pending migrations."""
|
||||
# Set up test paths
|
||||
db_path = os.path.join(temp_dir, "test.db")
|
||||
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
|
||||
|
||||
# Initialize manager with mocked Router
|
||||
with patch('ra_aid.database.migrations.Router', return_value=mock_router):
|
||||
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
|
||||
|
||||
# Call apply_migrations
|
||||
result = manager.apply_migrations()
|
||||
|
||||
# Verify result
|
||||
assert result is True
|
||||
|
||||
# Verify migrations were applied
|
||||
mock_router.run.assert_called_once_with("002_add_users", fake=False)
|
||||
|
||||
# Verify logging
|
||||
mock_logger.info.assert_any_call("Applying 1 pending migrations...")
|
||||
mock_logger.info.assert_any_call("Applying migration: 002_add_users")
|
||||
mock_logger.info.assert_any_call("Successfully applied migration: 002_add_users")
|
||||
mock_logger.info.assert_any_call("Successfully applied 1 migrations")
|
||||
|
||||
def test_apply_migrations_no_pending(self, cleanup_db, temp_dir, mock_logger):
|
||||
"""Test apply_migrations when no migrations are pending."""
|
||||
# Set up test paths
|
||||
db_path = os.path.join(temp_dir, "test.db")
|
||||
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
|
||||
|
||||
# Create a mock router with no pending migrations
|
||||
mock_router = MagicMock()
|
||||
mock_router.todo = ["001_initial"]
|
||||
mock_router.done = ["001_initial"]
|
||||
|
||||
# Initialize manager with mocked Router
|
||||
with patch('ra_aid.database.migrations.Router', return_value=mock_router):
|
||||
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
|
||||
|
||||
# Call apply_migrations
|
||||
result = manager.apply_migrations()
|
||||
|
||||
# Verify result
|
||||
assert result is True
|
||||
|
||||
# Verify no migrations were applied
|
||||
mock_router.run.assert_not_called()
|
||||
|
||||
# Verify logging
|
||||
mock_logger.info.assert_called_with("No pending migrations to apply")
|
||||
|
||||
def test_apply_migrations_error(self, cleanup_db, temp_dir, mock_logger):
|
||||
"""Test apply_migrations handles errors during migration."""
|
||||
# Set up test paths
|
||||
db_path = os.path.join(temp_dir, "test.db")
|
||||
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
|
||||
|
||||
# Create a mock router that raises an exception during run
|
||||
mock_router = MagicMock()
|
||||
mock_router.todo = ["001_initial", "002_add_users"]
|
||||
mock_router.done = ["001_initial"]
|
||||
mock_router.run.side_effect = Exception("Migration error")
|
||||
|
||||
# Initialize manager with mocked Router
|
||||
with patch('ra_aid.database.migrations.Router', return_value=mock_router):
|
||||
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
|
||||
|
||||
# Call apply_migrations
|
||||
result = manager.apply_migrations()
|
||||
|
||||
# Verify result
|
||||
assert result is False
|
||||
|
||||
# Verify error was logged
|
||||
mock_logger.error.assert_called_with(
|
||||
"Failed to apply migration 002_add_users: Migration error"
|
||||
)
|
||||
|
||||
def test_create_migration(self, cleanup_db, temp_dir, mock_router, mock_logger):
|
||||
"""Test create_migration creates a new migration."""
|
||||
# Set up test paths
|
||||
db_path = os.path.join(temp_dir, "test.db")
|
||||
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
|
||||
|
||||
# Initialize manager with mocked Router
|
||||
with patch('ra_aid.database.migrations.Router', return_value=mock_router):
|
||||
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
|
||||
|
||||
# Call create_migration
|
||||
result = manager.create_migration("add_users", auto=True)
|
||||
|
||||
# Verify result contains timestamp and name
|
||||
assert result is not None
|
||||
assert "add_users" in result
|
||||
|
||||
# Verify migration was created
|
||||
mock_router.create.assert_called_once()
|
||||
|
||||
# Verify logging
|
||||
mock_logger.info.assert_any_call(f"Creating new migration: {result}")
|
||||
mock_logger.info.assert_any_call(f"Successfully created migration: {result}")
|
||||
|
||||
def test_create_migration_error(self, cleanup_db, temp_dir, mock_logger):
|
||||
"""Test create_migration handles errors."""
|
||||
# Set up test paths
|
||||
db_path = os.path.join(temp_dir, "test.db")
|
||||
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
|
||||
|
||||
# Create a mock router that raises an exception during create
|
||||
mock_router = MagicMock()
|
||||
mock_router.create.side_effect = Exception("Creation error")
|
||||
|
||||
# Initialize manager with mocked Router
|
||||
with patch('ra_aid.database.migrations.Router', return_value=mock_router):
|
||||
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
|
||||
|
||||
# Call create_migration
|
||||
result = manager.create_migration("add_users", auto=True)
|
||||
|
||||
# Verify result is None on error
|
||||
assert result is None
|
||||
|
||||
# Verify error was logged
|
||||
mock_logger.error.assert_called_with("Failed to create migration: Creation error")
|
||||
|
||||
def test_get_migration_status(self, cleanup_db, temp_dir, mock_router):
|
||||
"""Test get_migration_status returns correct status information."""
|
||||
# Set up test paths
|
||||
db_path = os.path.join(temp_dir, "test.db")
|
||||
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
|
||||
|
||||
# Initialize manager with mocked Router
|
||||
with patch('ra_aid.database.migrations.Router', return_value=mock_router):
|
||||
manager = MigrationManager(db_path=db_path, migrations_dir=migrations_dir)
|
||||
|
||||
# Call get_migration_status
|
||||
status = manager.get_migration_status()
|
||||
|
||||
# Verify status information
|
||||
assert status["applied_count"] == 1
|
||||
assert status["pending_count"] == 1
|
||||
assert status["applied"] == ["001_initial"]
|
||||
assert status["pending"] == ["002_add_users"]
|
||||
assert status["migrations_dir"] == migrations_dir
|
||||
assert status["db_path"] == db_path
|
||||
|
||||
|
||||
class TestMigrationFunctions:
|
||||
"""Tests for the migration utility functions."""
|
||||
|
||||
def test_init_migrations(self, cleanup_db, temp_dir):
|
||||
"""Test init_migrations returns a MigrationManager instance."""
|
||||
# Set up test paths
|
||||
db_path = os.path.join(temp_dir, "test.db")
|
||||
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
|
||||
|
||||
# Call init_migrations
|
||||
with patch('ra_aid.database.migrations.MigrationManager') as mock_manager:
|
||||
mock_manager.return_value = MagicMock()
|
||||
|
||||
manager = init_migrations(db_path=db_path, migrations_dir=migrations_dir)
|
||||
|
||||
# Verify MigrationManager was initialized with correct parameters
|
||||
mock_manager.assert_called_once_with(db_path, migrations_dir)
|
||||
assert manager == mock_manager.return_value
|
||||
|
||||
def test_ensure_migrations_applied(self, cleanup_db, mock_logger):
|
||||
"""Test ensure_migrations_applied applies pending migrations."""
|
||||
# Mock MigrationManager
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.apply_migrations.return_value = True
|
||||
|
||||
# Call ensure_migrations_applied
|
||||
with patch('ra_aid.database.migrations.init_migrations', return_value=mock_manager):
|
||||
result = ensure_migrations_applied()
|
||||
|
||||
# Verify result
|
||||
assert result is True
|
||||
|
||||
# Verify migrations were applied
|
||||
mock_manager.apply_migrations.assert_called_once()
|
||||
|
||||
def test_ensure_migrations_applied_error(self, cleanup_db, mock_logger):
|
||||
"""Test ensure_migrations_applied handles errors."""
|
||||
# Call ensure_migrations_applied with an exception
|
||||
with patch('ra_aid.database.migrations.init_migrations',
|
||||
side_effect=Exception("Test error")):
|
||||
result = ensure_migrations_applied()
|
||||
|
||||
# Verify result is False on error
|
||||
assert result is False
|
||||
|
||||
# Verify error was logged
|
||||
mock_logger.error.assert_called_with("Failed to apply migrations: Test error")
|
||||
|
||||
def test_create_new_migration(self, cleanup_db, mock_logger):
|
||||
"""Test create_new_migration creates a new migration."""
|
||||
# Mock MigrationManager
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.create_migration.return_value = "20250226_123456_test_migration"
|
||||
|
||||
# Call create_new_migration
|
||||
with patch('ra_aid.database.migrations.init_migrations', return_value=mock_manager):
|
||||
result = create_new_migration("test_migration", auto=True)
|
||||
|
||||
# Verify result
|
||||
assert result == "20250226_123456_test_migration"
|
||||
|
||||
# Verify migration was created
|
||||
mock_manager.create_migration.assert_called_once_with("test_migration", True)
|
||||
|
||||
def test_create_new_migration_error(self, cleanup_db, mock_logger):
|
||||
"""Test create_new_migration handles errors."""
|
||||
# Call create_new_migration with an exception
|
||||
with patch('ra_aid.database.migrations.init_migrations',
|
||||
side_effect=Exception("Test error")):
|
||||
result = create_new_migration("test_migration", auto=True)
|
||||
|
||||
# Verify result is None on error
|
||||
assert result is None
|
||||
|
||||
# Verify error was logged
|
||||
mock_logger.error.assert_called_with("Failed to create migration: Test error")
|
||||
|
||||
def test_get_migration_status(self, cleanup_db, mock_logger):
|
||||
"""Test get_migration_status returns correct status information."""
|
||||
# Mock MigrationManager
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.get_migration_status.return_value = {
|
||||
"applied_count": 2,
|
||||
"pending_count": 1,
|
||||
"applied": ["001_initial", "002_add_users"],
|
||||
"pending": ["003_add_profiles"],
|
||||
"migrations_dir": "/test/migrations",
|
||||
"db_path": "/test/db.sqlite"
|
||||
}
|
||||
|
||||
# Call get_migration_status
|
||||
with patch('ra_aid.database.migrations.init_migrations', return_value=mock_manager):
|
||||
status = get_migration_status()
|
||||
|
||||
# Verify status information
|
||||
assert status["applied_count"] == 2
|
||||
assert status["pending_count"] == 1
|
||||
assert status["applied"] == ["001_initial", "002_add_users"]
|
||||
assert status["pending"] == ["003_add_profiles"]
|
||||
assert status["migrations_dir"] == "/test/migrations"
|
||||
assert status["db_path"] == "/test/db.sqlite"
|
||||
|
||||
# Verify migration status was retrieved
|
||||
mock_manager.get_migration_status.assert_called_once()
|
||||
|
||||
def test_get_migration_status_error(self, cleanup_db, mock_logger):
|
||||
"""Test get_migration_status handles errors."""
|
||||
# Call get_migration_status with an exception
|
||||
with patch('ra_aid.database.migrations.init_migrations',
|
||||
side_effect=Exception("Test error")):
|
||||
status = get_migration_status()
|
||||
|
||||
# Verify default status on error
|
||||
assert status["error"] == "Test error"
|
||||
assert status["applied_count"] == 0
|
||||
assert status["pending_count"] == 0
|
||||
assert status["applied"] == []
|
||||
assert status["pending"] == []
|
||||
|
||||
# Verify error was logged
|
||||
mock_logger.error.assert_called_with("Failed to get migration status: Test error")
|
||||
|
||||
|
||||
class TestIntegration:
|
||||
"""Integration tests for the migrations module."""
|
||||
|
||||
def test_in_memory_migrations(self, cleanup_db):
|
||||
"""Test migrations with in-memory database."""
|
||||
# Initialize in-memory database
|
||||
with DatabaseManager(in_memory=True) as db:
|
||||
# Create a temporary migrations directory
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
migrations_dir = os.path.join(temp_dir, MIGRATIONS_DIRNAME)
|
||||
os.makedirs(migrations_dir, exist_ok=True)
|
||||
|
||||
# Create __init__.py to make it a proper package
|
||||
with open(os.path.join(migrations_dir, "__init__.py"), "w") as f:
|
||||
pass
|
||||
|
||||
# Initialize migration manager
|
||||
manager = MigrationManager(db_path=":memory:", migrations_dir=migrations_dir)
|
||||
|
||||
# Create a test migration
|
||||
migration_name = manager.create_migration("test_migration", auto=False)
|
||||
|
||||
# Write a simple migration file
|
||||
migration_path = os.path.join(migrations_dir, f"{migration_name}.py")
|
||||
with open(migration_path, "w") as f:
|
||||
f.write("""
|
||||
def migrate(migrator, database, fake=False, **kwargs):
|
||||
migrator.create_table('test_table', (
|
||||
('id', 'INTEGER', {'primary_key': True}),
|
||||
('name', 'STRING', {'null': False}),
|
||||
))
|
||||
|
||||
def rollback(migrator, database, fake=False, **kwargs):
|
||||
migrator.drop_table('test_table')
|
||||
""")
|
||||
|
||||
# Check migrations
|
||||
applied, pending = manager.check_migrations()
|
||||
assert len(applied) == 0
|
||||
assert len(pending) == 1
|
||||
assert migration_name in pending[0] # Instead of exact equality, check if name is contained
|
||||
|
||||
# Apply migrations
|
||||
result = manager.apply_migrations()
|
||||
assert result is True
|
||||
|
||||
# Check migrations again
|
||||
applied, pending = manager.check_migrations()
|
||||
assert len(applied) == 1
|
||||
assert len(pending) == 0
|
||||
assert migration_name in applied[0] # Instead of exact equality, check if name is contained
|
||||
|
||||
# Verify migration status
|
||||
status = manager.get_migration_status()
|
||||
assert status["applied_count"] == 1
|
||||
assert status["pending_count"] == 0
|
||||
# Use substring check for applied migrations
|
||||
assert len(status["applied"]) == 1
|
||||
assert migration_name in status["applied"][0]
|
||||
assert status["pending"] == []
|
||||
Loading…
Reference in New Issue