217 lines
6.4 KiB
Python
217 lines
6.4 KiB
Python
"""
Tests for the database utils module.
"""
|
|
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import peewee
|
|
import pytest
|
|
|
|
from ra_aid.database.connection import db_var, init_db
|
|
from ra_aid.database.models import BaseModel
|
|
from ra_aid.database.utils import ensure_tables_created, get_model_count, truncate_table
|
|
|
|
|
|
def _close_and_clear_db():
    """Close the database held in ``db_var`` (if any) and reset it to None."""
    db = db_var.get()
    if db is not None:
        try:
            if not db.is_closed():
                db.close()
        except Exception:
            # Best-effort cleanup: errors while closing the database are
            # deliberately ignored so they cannot fail the test itself.
            pass
    db_var.set(None)


@pytest.fixture
def cleanup_db():
    """Reset the database contextvar and connection state around each test.

    The same reset runs once before the test body and once after it.
    """
    # Reset before the test
    _close_and_clear_db()

    # Run the test
    yield

    # Reset after the test
    _close_and_clear_db()
|
|
|
|
|
|
@pytest.fixture
def mock_logger():
    """Patch ``ra_aid.database.utils.logger`` and yield the replacement mock.

    Tests use the yielded mock to assert on the messages that were logged.
    """
    logger_patch = patch("ra_aid.database.utils.logger")
    with logger_patch as patched_logger:
        yield patched_logger
|
|
|
|
|
|
@pytest.fixture
def setup_test_model(cleanup_db):
    """Yield a throwaway model class whose table lives in an in-memory database.

    The table is created before the test runs and dropped again afterwards.
    """
    from ra_aid.database.models import initialize_database

    # In-memory database plus the database proxy used by BaseModel.
    database = init_db(in_memory=True)
    initialize_database()

    class TestModel(BaseModel):
        name = peewee.CharField(max_length=100)
        value = peewee.IntegerField(default=0)

    models = [TestModel]

    # Table creation happens inside a transaction.
    with database.atomic():
        database.create_tables(models, safe=True)

    # Hand the model class to the test.
    yield TestModel

    # Teardown: remove the table again.
    with database.atomic():
        database.drop_tables(models, safe=True)
|
|
|
|
|
|
def test_ensure_tables_created_with_models(cleanup_db, mock_logger):
    """ensure_tables_created builds tables for an explicit model list and logs it."""
    from ra_aid.database.models import initialize_database

    # Arrange: in-memory database and initialized database proxy.
    _db = init_db(in_memory=True)  # return value is not used by this test
    initialize_database()

    class TestModel(BaseModel):
        name = peewee.CharField(max_length=100)
        value = peewee.IntegerField(default=0)

    # Act
    ensure_tables_created([TestModel])

    # Assert: the success message was the most recent info() call ...
    mock_logger.info.assert_called_with("Successfully created tables for 1 models")

    # ... and the table is usable: insert one row and count it.
    TestModel.create(name="test", value=42)
    assert TestModel.select().count() == 1
|
|
|
|
|
|
@patch("ra_aid.database.utils.initialize_database")
def test_ensure_tables_created_database_error(
    mock_initialize_database, setup_test_model, cleanup_db, mock_logger
):
    """ensure_tables_created logs a DatabaseError from create_tables and lets it propagate."""
    model = setup_test_model

    # Fake database whose create_tables() always fails.
    failing_db = MagicMock()
    atomic_ctx = failing_db.atomic.return_value
    atomic_ctx.__enter__.return_value = None
    # A falsy __exit__ result means the context manager does not swallow errors.
    atomic_ctx.__exit__.return_value = None
    failing_db.create_tables.side_effect = peewee.DatabaseError("Test database error")

    # The patched initialize_database hands out the failing database.
    mock_initialize_database.return_value = failing_db

    with pytest.raises(peewee.DatabaseError):
        ensure_tables_created([model])

    # The failure must have been reported through the logger.
    mock_logger.error.assert_called_with(
        "Database Error: Failed to create tables: Test database error"
    )
|
|
|
|
|
|
def test_get_model_count(setup_test_model, mock_logger):
    """get_model_count reports the number of rows stored for the model."""
    model = setup_test_model

    # Start from an empty table, then insert exactly two rows.
    model.delete().execute()
    for name, value in (("test1", 1), ("test2", 2)):
        model.create(name=name, value=value)

    assert get_model_count(model) == 2
|
|
|
|
|
|
@patch("peewee.ModelSelect.count")
def test_get_model_count_database_error(mock_count, setup_test_model, mock_logger):
    """get_model_count logs a DatabaseError raised by count() and yields 0."""
    model = setup_test_model

    # Every count() call on a model select now raises.
    mock_count.side_effect = peewee.DatabaseError("Test count error")

    result = get_model_count(model)

    # The error is logged ...
    mock_logger.error.assert_called_with(
        "Database Error: Failed to count records: Test count error"
    )

    # ... and the caller receives 0 instead of an exception.
    assert result == 0
|
|
|
|
|
|
def test_truncate_table(setup_test_model, mock_logger):
    """truncate_table removes every row of the model's table and logs success."""
    TestModel = setup_test_model

    # Seed two rows and confirm they are present.
    for name, value in (("test1", 1), ("test2", 2)):
        TestModel.create(name=name, value=value)
    assert TestModel.select().count() == 2

    truncate_table(TestModel)

    # Success message was the most recent info() call.
    mock_logger.info.assert_called_with(
        f"Successfully truncated table for {TestModel.__name__}"
    )

    # No rows are left afterwards.
    assert TestModel.select().count() == 0
|
|
|
|
|
|
@patch("ra_aid.database.models.BaseModel.delete")
def test_truncate_table_database_error(mock_delete, setup_test_model, mock_logger):
    """truncate_table logs a DatabaseError from the delete query and lets it propagate."""
    model = setup_test_model

    # One row so there is something to truncate.
    model.create(name="test", value=42)

    # BaseModel.delete() now returns a query whose execute() fails.
    failing_query = MagicMock()
    failing_query.execute.side_effect = peewee.DatabaseError("Test delete error")
    mock_delete.return_value = failing_query

    with pytest.raises(peewee.DatabaseError):
        truncate_table(model)

    # The failure must have been reported through the logger.
    mock_logger.error.assert_called_with(
        "Database Error: Failed to truncate table: Test delete error"
    )
|